Conversation

@I-l-l-I (Contributor) commented Apr 10, 2025

What does this PR do?

Fixes #3142 and #3157 (comment). Supports the vLLM V1 Engine for faster generation.
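
For context, a minimal sketch of what "V1 Engine" refers to here, assuming vLLM ≥ 0.8 (where V1 is the default engine and can be toggled with the VLLM_USE_V1 environment variable); the model name is illustrative:

```python
import os

# The V1 engine is the default in vLLM >= 0.8; setting this before importing
# vllm makes the choice explicit ("0" would select the legacy V0 engine).
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```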

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qgallouedec (Member) commented

Thanks @I-l-l-I!
What version of vllm do you use? I get this error when trying to initialise the client:

Traceback (most recent call last):
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/uvicorn/protocols/http/httptools_impl.py", line 409, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/fastapi/applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/applications.py", line 112, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/middleware/errors.py", line 187, in __call__
    raise exc
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/middleware/errors.py", line 165, in __call__
    await self.app(scope, receive, _send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    raise exc
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    await app(scope, receive, sender)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/routing.py", line 714, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/routing.py", line 734, in app
    await route.handle(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/routing.py", line 288, in handle
    await self.app(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/routing.py", line 76, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    raise exc
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    await app(scope, receive, sender)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/routing.py", line 74, in app
    await response(scope, receive, send)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/responses.py", line 160, in __call__
    await self.background()
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/background.py", line 41, in __call__
    await task()
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/background.py", line 28, in __call__
    await run_in_threadpool(self.func, *self.args, **self.kwargs)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/starlette/concurrency.py", line 37, in run_in_threadpool
    return await anyio.to_thread.run_sync(func)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/anyio/_backends/_asyncio.py", line 2470, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/anyio/_backends/_asyncio.py", line 967, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 496, in collective_rpc
    return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/vllm/v1/engine/llm_engine.py", line 291, in collective_rpc
    return self.engine_core.collective_rpc(method, timeout, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 555, in collective_rpc
    return self.call_utility("collective_rpc", method, timeout, args,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 508, in call_utility
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
Exception: Call to collective_rpc method failed: 'Worker' object has no attribute 'pynccl_comm'

@I-l-l-I (Contributor, Author) commented Apr 10, 2025

@qgallouedec I use vllm 0.8.3. The error seems to occur because pynccl_comm is not initialized, but I've added a guard to each method in WeightSyncWorkerExtension, so this shouldn't happen. Everything works fine when I use it for training.
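
For reference, a minimal sketch of the guard pattern described above; the pynccl_comm attribute name comes from the traceback, while the method names are illustrative and the real extension differs in detail:

```python
from typing import Optional

class WeightSyncWorkerExtension:
    # None until init_communicator() runs; every method checks this first,
    # so an early call raises a clear error instead of an AttributeError.
    pynccl_comm: Optional[object] = None

    def update_named_param(self, name: str, dtype, shape) -> None:
        if self.pynccl_comm is None:
            raise RuntimeError("Communicator not initialized. Call init_communicator first.")
        # ... broadcast the updated weight over self.pynccl_comm ...

    def close_communicator(self) -> None:
        # Guarded so closing before init (or closing twice) is a safe no-op.
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
```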

@qgallouedec (Member) commented

It's now working on my side 👍 no idea what I did wrong in the first place

@qgallouedec (Member) left a comment

If you agree with https://github.com/huggingface/trl/pull/3276/files#r2037895897 then here are the modifications needed

@qgallouedec (Member) commented

Massive speed-up!

[Plot: vllm_tp]

@qgallouedec (Member) left a comment

Thanks a lot @I-l-l-I, massive improvement!

@qgallouedec qgallouedec changed the title Fix vLLM server to support V1 Engine ⏱️ Fix vLLM server to support V1 Engine Apr 11, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@I-l-l-I (Contributor, Author) commented Apr 11, 2025

You're welcome @qgallouedec. By the way, your test results show that it's clearly faster to generate a large number of answers at once, yet GRPOTrainer currently has to generate gradient_accumulation_steps times before each optimizer step. Why not merge these generation tasks and generate everything up front? I think this could greatly improve speed when gradient accumulation is needed.
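
To illustrate the suggestion, a toy sketch of merging the per-accumulation-step generations into one vLLM call and splitting the results back; generate_merged is a hypothetical helper, not GRPOTrainer code:

```python
from vllm import LLM, SamplingParams

def generate_merged(llm: LLM, micro_batches: list[list[str]],
                    sampling_params: SamplingParams) -> list[list[str]]:
    # One large call instead of len(micro_batches) small ones.
    flat_prompts = [p for batch in micro_batches for p in batch]
    outputs = llm.generate(flat_prompts, sampling_params)
    texts = [o.outputs[0].text for o in outputs]
    # Split the completions back into the original micro-batch layout.
    merged, i = [], 0
    for batch in micro_batches:
        merged.append(texts[i:i + len(batch)])
        i += len(batch)
    return merged
```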

@qgallouedec (Member) commented

It makes sense, but two remarks on why it's not currently done:

  • It's not necessarily faster, see the plot: using a mini-batch size of 64 is mostly equivalent to 256.
  • The implementation is rather tricky because we rely on the transformers Trainer, and its sampling logic doesn't natively allow this. You'd have to hack the sampler/batch size or the dataloader (a toy sketch follows this list). I'm not sure how to do that at this point, and it would undoubtedly introduce additional complexity into the code.
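
A toy sketch of the "hack the dataloader" route, purely hypothetical: request one oversized batch per optimizer step, generate for it in a single call, then loop over micro-batches:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

per_device_batch, accum_steps = 4, 8
dataset = TensorDataset(torch.arange(64))

# Oversized batches: one per optimizer step instead of one per micro-step.
loader = DataLoader(dataset, batch_size=per_device_batch * accum_steps)

for (merged,) in loader:
    # 1) generate completions for the whole merged batch here (one vLLM call)
    for micro in merged.split(per_device_batch):
        pass  # 2) forward/backward on each micro-batch, then optimizer.step()
```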

@qgallouedec qgallouedec merged commit d625c55 into huggingface:main Apr 11, 2025
9 checks passed
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025
Successfully merging this pull request may close these issues:

TRL vllm-serve fails to load certain models