⚡ Up to 4x faster: Data Parallel for vLLM server #3310
Conversation
trl/scripts/vllm_serve.py
Outdated
```
@@ -226,6 +236,45 @@ class ScriptArguments:
    )


def llm_worker(script_args, data_parallel_rank, connection):
```
The main change is that instead of instantiating a single `LLM` in the main process, we now spawn `dp` subprocesses, each responsible for creating its own `LLM` instance, and set up communication between the main process and each subprocess. While this approach may seem a bit more complex, it's necessary because vLLM depends heavily on environment variables and doesn't cope well with multiple `LLM` instances running inside the same process. Spawning separate subprocesses is the only reliable way to isolate and manage multiple `LLM` instances.
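A minimal sketch of this worker-per-rank pattern, for illustration only (this is not the PR's exact code; `spawn_workers` and the argument names are hypothetical stand-ins for the script arguments discussed in this thread):

```python
import multiprocessing as mp


def llm_worker(script_args, data_parallel_rank, connection):
    # Each rank builds its own LLM inside its own process, so vLLM's
    # environment-variable-driven state stays isolated per rank.
    from vllm import LLM  # imported here so the parent process never touches vLLM

    llm = LLM(
        model=script_args.model,
        tensor_parallel_size=script_args.tensor_parallel_size,
    )

    # Serve instructions from the main process until told to shut down.
    while True:
        command = connection.recv()
        if command["type"] == "shutdown":
            break
        if command["type"] == "call":
            # e.g. method="generate", kwargs={"prompts": ..., "sampling_params": ...}
            result = getattr(llm, command["method"])(**command.get("kwargs", {}))
            connection.send(result)


def spawn_workers(script_args):
    # One subprocess and one Pipe per data-parallel rank.
    connections, processes = [], []
    for rank in range(script_args.data_parallel_size):
        parent_conn, child_conn = mp.Pipe()
        proc = mp.Process(target=llm_worker, args=(script_args, rank, child_conn))
        proc.start()
        connections.append(parent_conn)
        processes.append(proc)
    return connections, processes
```

The main process then talks to each rank only through its end of the `Pipe`, which is what the dispatch code in the next hunk relies on.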
trl/scripts/vllm_serve.py
Outdated
```python
for connection, prompts in zip(connections, chunked_prompts):
    kwargs = {"prompts": prompts, "sampling_params": sampling_params}
    connection.send({"type": "call", "method": "generate", "kwargs": kwargs})

# Wait for and collect all results
all_outputs = [connection.recv() for connection in connections]
```
We can't call `.generate` directly anymore, since the `LLM` instances live in subprocesses. Instead, we send a "call" instruction over each connection and then wait for the results.
So here the outputs will not get mixed up as you go over each connection in order?
I think so. I'll check manually though, as this could lead to a silent bug or unwanted behavior if the order gets mixed up.
If the output gets mixed up, then we can use `from collections import OrderedDict` to preserve the prompt-output order!
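A tiny hypothetical sketch of that suggestion, assuming the same `connections` list as in the hunk above: keying the received outputs by data-parallel rank makes the prompt-output pairing explicit, regardless of the order in which results are read back.

```python
from collections import OrderedDict

# Hypothetical sketch: collect each rank's outputs under its rank index so the
# chunked prompts can be re-associated with their results deterministically.
outputs_by_rank = OrderedDict()
for rank, connection in enumerate(connections):
    outputs_by_rank[rank] = connection.recv()

# Flatten back into a single list in rank (and therefore prompt-chunk) order.
all_outputs = [output for outputs in outputs_by_rank.values() for output in outputs]
```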
@qgallouedec Did you get a chance to check this? Are we getting the prompt-responses aligned with the connections?
Not yet
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Very clean implementation! LGTM with some nits and a question about whether we can support CUDA graphs when DP=1 and TP>1
Hello, thank you very much for TRL's support for vLLM DP, which is exactly what I've been looking forward to and needing. It has greatly accelerated my experiments. However, I encountered an issue when running vllm_serve:

```
NCCL_DEBUG=WARN python -m trl.cli vllm-serve \
    --model /mnt/tenant-home_speed/Model/Qwen/Qwen2.5-7B-Instruct \
    --tensor_parallel_size 1 \
    --data_parallel_size 8 \
    --host 0.0.0.0 \
    --port 6004
```

This resulted in an error. After searching for the cause for a long time, I finally discovered that deleting all TRL-related code from the vllm_serve.py file allows it to work normally, specifically:

```python
# from trl import TrlParser
# from trl.import_utils import (
#     is_fastapi_available,
#     is_pydantic_available,
#     is_uvicorn_available,
#     is_vllm_available,
# )

# if is_fastapi_available():
#     from fastapi import FastAPI

# if is_pydantic_available():
#     from pydantic import BaseModel

# if is_uvicorn_available():
#     import uvicorn

# if is_vllm_available():
#     from vllm import LLM, SamplingParams
#     from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
#     from vllm.distributed.parallel_state import get_world_group
#     from vllm.distributed.utils import StatelessProcessGroup
#     from vllm.sampling_params import GuidedDecodingParams
#     from vllm.utils import get_open_port

# copy class TrlParser(HfArgumentParser): to there
...
```

and then running

```
NCCL_DEBUG=WARN python vllm_serve.py \
    --model /mnt/tenant-home_speed/Model/Qwen/Qwen2.5-7B-Instruct \
    --tensor_parallel_size 1 \
    --data_parallel_size 8 \
    --host 0.0.0.0 \
    --port 6004
```

works correctly. I wonder if the error might be that importing TRL places certain processes on the cuda:0 device, which causes this error? Could you please help look into it? Thank you for your assistance. The complete error log:
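As a hedged aside (not part of the original report): one quick way to test the import hypothesis above is to check whether merely importing trl initializes CUDA in the parent process, before any data-parallel worker is spawned.

```python
# Hedged diagnostic sketch: if the second print shows True, some import side
# effect has already touched the GPU in the parent process.
import torch

print("before importing trl:", torch.cuda.is_initialized())

import trl  # noqa: E402,F401

print("after importing trl:", torch.cuda.is_initialized())
```

If the second print returns True, an import side effect has already claimed a CUDA device in the parent process, which would be consistent with the hypothesis.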
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Unfortunately, I don't believe this feature works properly! I am not able to run anything with DP>1, as I get this weird error (log for 2 nodes with DeepSeek-R1-Distill-Qwen-7B, using trl==0.17.0 and vllm==0.8.3):

```
(EngineCore_1 pid=2353892) INFO 04-25 16:14:42 [worker_base.py:589] Injected <class 'trl.scripts.vllm_serve.WeightSyncWorkerExtension'> into <class 'vllm.v1.worker.gpu_worker.Worker'> for extended collective_rpc calls ['close_communicator', 'init_communicator', 'update_named_param']
batch-block7-00733:2353893:2353893 [0] init.cc:943 NCCL WARN Duplicate GPU detected : rank 0 and rank 1 both on CUDA device f000
batch-block7-00733:2353892:2353892 [0] init.cc:943 NCCL WARN Duplicate GPU detected : rank 1 and rank 0 both on CUDA device f000
batch-block7-00733:2353883:2353883 [0] init.cc:943 NCCL WARN Duplicate GPU detected : rank 2 and rank 0 both on CUDA device f000
ERROR: Application startup failed. Exiting.
```
What command do you use?
To run the vLLM part, I use this:

The above works fine if DP=1 and TP is set properly.
It could be related to this: Lines 100 to 106 in 29c5e05
Try to replace
Thanks @qgallouedec! Just tested it with TP=1 and DP=8 and it works!
@qgallouedec Unfortunately, the issue seems to persist despite seemingly being resolved at first. This time we use TP=1 and DP=8. I'd appreciate any insights you may have here:

It seems like we do indeed finish several completions before running into this weird error!
Are you using a modified version of GRPO?
Yes! But it works without any issues with the previous version, which is basically TP=1 and DP=1.
Try adding this line:

```diff
  self.vllm_client = VLLMClient(args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout)
+ self.vllm_client.init_communicator()
```
Thanks @qgallouedec! I kicked off another training run with TP=1 and DP=8 and haven't noticed any issues so far. Hopefully the issue is resolved. Thanks again for your amazing work!
Usage:
For the client: nothing changes:
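The concrete usage snippets are truncated above. As a hedged sketch based on the commands and `VLLMClient` calls quoted earlier in this thread (the host, port, timeout, and prompt values below are placeholder assumptions), the client-side code keeps using the same interface:

```python
# Hedged sketch of unchanged client-side usage; the server is assumed to have
# been started separately, e.g.:
#   python -m trl.cli vllm-serve --model <model> --tensor_parallel_size 1 --data_parallel_size 8
from trl.extras.vllm_client import VLLMClient

client = VLLMClient("0.0.0.0", 8000, connection_timeout=120)
client.init_communicator()  # per the fix suggested earlier in this thread

# Prompts are sent to the server exactly as before data parallelism was added.
outputs = client.generate(["The capital of France is"])
print(outputs)
```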