🚀 Scaling GRPO to 70B+ Models and Multi-Node Training with vLLM Server & NCCL Communication #3094
Conversation
@binary-husky thank you very much for this work. It gave us a better understanding of how to achieve this. I wanted to take a more ambitious approach and decided to refactor it further. Since this was more than I could reasonably ask of an external contributor, I took the liberty of committing the changes directly to your branch. I hope that’s okay with you! |
# 3094.py
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = GRPOConfig(output_dir="3094", use_vllm=True, bf16=True, gradient_checkpointing=True, logging_steps=10)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-7B",
args=training_args,
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
|
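For anyone reproducing the snippet above: it assumes a vLLM server has already been started somewhere the trainer can reach (for example with the trl vllm-serve command described in the TRL docs). Below is a minimal sketch of the client-side configuration; the host and port values are placeholders, and the vllm_server_host / vllm_server_port field names are taken from this PR's commit log, so treat them as assumptions for your TRL version.
# sketch_client.py (hypothetical)
# Same training setup as 3094.py above, but pointing the trainer at a vLLM
# server running on another machine. Host/port values are placeholders.
from trl import GRPOConfig
training_args = GRPOConfig(
    output_dir="3094",
    use_vllm=True,
    vllm_server_host="192.168.1.10",  # placeholder: IP of the node running `trl vllm-serve`
    vllm_server_port=8000,            # placeholder: HTTP port the server was started with
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=10,
)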
Sorry, I still haven't made this work. How do I use 4 GPUs on machine 1 for vLLM, and the remaining 4 plus the whole of machine 2 for training? As stated here: 2 machines | 1 for training, 1 for vLLM | using NCCL to deliver param updates. (1) start MAIN TRAINING script:
|
Ignore the PR description; it's an old version. Please refer to the doc. |
The doc uses SLURM, and it only shows how to use a whole node for vLLM. Can we still do something like: |
@Andcircle You can refer to my personal notebook below for training 32B Qwen. It is ugly and not general, but it may convey some basic ideas:
|
@binary-husky awesome! really appreciated!! |
I'm trying to use the GPUs as efficiently as possible. In your solution above, on machine 1, GPUs 0,1,2,3 are used for vLLM, which means 4,5,6,7 can't be used for training anymore. But when I actually try to use them, it doesn't work: the vLLM client update from machine 3 fails with the following error. Any hints on how I should make this setup work?
|
Maybe the easiest is to use 4 machines? (1 node for training, 1 for vLLM)x2 |
@binary-husky Great job. |
4 GPUs are more than enough for vLLM, which means the remaining 4 are wasted. |
Two vLLMs? There are two ports you need to consider; you probably forgot the other one. Please check for a port conflict ~ |
Yeah, I set it to a different port through GRPOConfig. |
@Andcircle Sorry, but group port is not exposed to |
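To make the two-ports point concrete, here is a rough sketch of running two trainings against two vLLM servers on the same host, each with its own HTTP port. As the reply above notes, the NCCL group port used for weight sync is not exposed through GRPOConfig, so the two runs can still collide on that port; the field names below are assumptions based on this PR.
# sketch_two_servers.py (hypothetical)
from trl import GRPOConfig
# Run A talks to the vLLM server started on HTTP port 8000.
args_run_a = GRPOConfig(
    output_dir="run_a",
    use_vllm=True,
    vllm_server_host="127.0.0.1",  # placeholder host
    vllm_server_port=8000,         # placeholder: first server's HTTP port
)
# Run B talks to a second server on HTTP port 8001. The separate NCCL group
# port is NOT configurable here, which is the conflict discussed above.
args_run_b = GRPOConfig(
    output_dir="run_b",
    use_vllm=True,
    vllm_server_host="127.0.0.1",
    vllm_server_port=8001,
)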
A 32B model with ZeRO-3 and sync_ref_model = true will raise an OOM in SyncRefModelCallback::sync_target_model(). Error stack:
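For context on why this OOMs: sync_ref_model periodically performs a TR-DPO-style soft update of the reference model from the policy, and under ZeRO-3 the callback has to gather the full parameters to apply it. A rough single-process sketch of the update rule follows; alpha corresponds to the ref_model_mixup_alpha setting, and the exact names are assumptions.
# sketch_soft_update.py (hypothetical, single process)
import torch

@torch.no_grad()
def soft_update(ref_model, policy_model, alpha: float = 0.6):
    # ref <- alpha * policy + (1 - alpha) * ref, applied parameter by parameter.
    # Under ZeRO-3 every sharded parameter must be gathered first, which is
    # where the OOM for a 32B model can come from.
    for ref_param, policy_param in zip(ref_model.parameters(), policy_model.parameters()):
        ref_param.data.mul_(1.0 - alpha).add_(policy_param.data, alpha=alpha)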
…r & NCCL Communication (huggingface#3094)
* 🚀 allow GRPO to connect to VLLM in remote/local node with NCCL communication
* Update trl/extras/remote_vllm_helper.py (Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>)
* use argparse for options
* add imports for remote vllm helper
* formatting
* fix arguments
* use cli options
* vllm serve
* clean server
* better naming
* client
* style
* new params in generate
* this method is the new default
* update config
* do not use asserts
* update config
* separate host and post
* proper deprectation
* deprecated arg in the vllm server
* simplify moving
* document host and port
* style
* update trainer
* new generate args
* update doc
* Fix for zero3
* Better naming
* Remove remote_vllm_helper
* remove grpo_with_remote_vllm
* remove cloudpickle from deps
* Some consistency
* Update docs/source/grpo_trainer.md (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update setup.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* add revision argument to vllm server
* Update docs/source/grpo_trainer.md (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update docs/source/grpo_trainer.md (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Reset the prefix cache after updating weights
* Update vllm_client.py
* Update vllm_client.py
* Update vllm_serve.py
* Add health check endpoint to vLLM server
* connection timeout
* style
* fix doc langauge hint
* move reset_prefix_cache to its own endpoint
* async
* merge peft adaptor to send to vllm
* Looks simple. Wasn't.
* Peft compatibility
* Update docs/source/speeding_up_training.md (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update docs/source/speeding_up_training.md (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* Update trl/extras/vllm_client.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* GatheredParameters can be disabled
* gather and ungather peft weights within the same deepseed context
* use is_vllm_available
* minor consistency fixes
* fix error when deepspeed is not installed
* fix deepspeed import when not peft
* simpler
* multinode doc
* minor code and comments changes
* style
* optional deps
* vllm_server_timeout as arg
* small refinement in doc
* update deps
* Fix VLLMClient argument in grpo_trainer; Add zero3+peft vllm transfer solution
* Revert "Fix VLLMClient argument in grpo_trainer; Add zero3+peft vllm transfer solution" (This reverts commit d759c9c.)
* log num_tokens
* disable vllm test (in the future we'll add a mock for vllm server for them)
* style
* fix ds3_gather_for_generation
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
@binary-husky, thanks for this. I'm trying to fine-tune Llama 405B, and it uses 16 H100s (2 nodes) for vLLM and 8 nodes for training. Can you provide a similar commands/config setup that uses 2 nodes for vLLM and the rest for training? Thanks in advance. |
Use this one as a workaround: #3094 (comment) @tingkuanpei |
@vamshi-rvk Sorry, currently I'm unable to allocate that many machines. |
@binary-husky Hello. Following what you shared, I used the first four cards of a single H100 machine to start the vLLM service, while the other two H100s are used for training. However, I encountered the following error. Do you know how to solve this issue? [Rank12] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=8834, OpType=_ALLGATHER_BASE, NumelIn=1638400, NumelOut=26214400, Timeout(ms)=1800000) ran for 1800055 milliseconds before timing out.
... |
@tongtong0613 I have seen |
|
For PEFT we need to merge the adapter before moving the model to vLLM, but there is currently no way to do that in a distributed manner. So the answer is that, for now, there is no solution. |
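For reference, this is what merging the adapter looks like in the simple single-process case (standard PEFT usage, not the distributed ZeRO-3 path discussed above); the adapter path is a placeholder.
# sketch_merge_adapter.py (hypothetical, single process)
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B")
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")  # placeholder adapter path

merged = model.merge_and_unload()            # fold the LoRA weights into the base model
merged.save_pretrained("merged-checkpoint")  # the result can be served by vLLM as a plain model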
Warning
The following description is outdated; please refer to the TRL docs.
What does this PR do?
This PR isolates vLLM from the main GRPO training process(es), using only HTTP and NCCL to communicate with a vLLM instance.
By achieving this isolation:
(1) We can easily address almost all issues related to vLLM GPU device arrangement, simply by setting CUDA_VISIBLE_DEVICES (see the sketch after this list), such as:
(2) We can scale to models of any size without worrying about vLLM (we are free to place vLLM on any machine as long as the training process can reach it over TCP), addressing issues such as:
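Here is the sketch referenced in point (1): pinning the vLLM server to a subset of GPUs with CUDA_VISIBLE_DEVICES and leaving the rest for training. The trl vllm-serve command and its flags are assumptions based on the TRL docs, and comments in this thread report NCCL issues when vLLM and training share a node, so treat this as an illustration rather than a recommendation.
# sketch_gpu_split.py (hypothetical)
import os
import subprocess

# Start the vLLM server on GPUs 0-3 of this node only.
server_env = dict(os.environ, CUDA_VISIBLE_DEVICES="0,1,2,3")
subprocess.Popen(
    ["trl", "vllm-serve", "--model", "Qwen/Qwen2.5-7B", "--tensor_parallel_size", "4"],
    env=server_env,
)

# The training job would then be launched separately (e.g. via accelerate) with
# CUDA_VISIBLE_DEVICES="4,5,6,7" and GRPOConfig(use_vllm=True, vllm_server_host=..., vllm_server_port=...).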
By the way, I initially came from the open-r1 repo, but obviously the problem cannot be resolved from there.
I have run tests on 32B models (2 accelerate nodes + 1 vLLM node); so far so good.
Current limitation
_move_model_to_remote_vllm at https://github.com/binary-husky/trl/blob/765891c5d39d4e59dca9c3f7c2da0faeeba8f7c7/trl/trainer/grpo_trainer.py#L724