
πŸ§‘β€πŸ€β€πŸ§‘ Co-Locating vLLM w/ training to for higher throughput and GPU utilization #3394


Merged

merged 22 commits into huggingface:main on May 1, 2025

Conversation

toslali-ibm
Contributor

@toslali-ibm toslali-ibm commented Apr 30, 2025

What does this PR do?

Enables colocating vLLM with training on each GPU to improve utilization and throughput.

Fixes #3064 and #3113
Addresses: #3195, #2971, #2922, #2887 etc.

Enabler:

vLLM (version >0.7.3) introduced support for an external launcher, allowing vLLM processes to run alongside other workloads on the same GPU.
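
For illustration, a minimal sketch of what this enables (not the PR's code; the model name and values are taken from the test config below): each training process builds its own in-process vLLM engine with the external launcher backend, so the engine shares the GPU and the existing torch.distributed setup instead of running as a separate server.

from vllm import LLM, SamplingParams

# assumes accelerate/torchrun has already launched the processes and set up the process group
llm = LLM(
    model="Qwen/Qwen2.5-Math-1.5B",
    distributed_executor_backend="external_launcher",  # run vLLM inside this training process
    tensor_parallel_size=2,               # shard the engine across the ranks of one TP group
    gpu_memory_utilization=0.3,           # leave the rest of the GPU memory to training
    max_model_len=2048,
)
outputs = llm.generate(["1 + 1 = ?"], SamplingParams(max_tokens=64), use_tqdm=False)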

Benefits:

  • Faster inference: speeds up GRPO training by reducing inference latency through parallel prompt processing (each vLLM instance handles its own device's batch).
  • Better GPU efficiency: frees GPU resources by removing the need for a dedicated vLLM server; multiple vLLM instances can now share GPUs with training jobs, reducing GPU idle time.
  • Supports tensor parallelism (TP) and data parallelism (DP).
  • Ray-less solution.

Testing vLLM colocation

Run it w/ the following:
VLLM_USE_V1=0 ACCELERATE_LOG_LEVEL=info CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes=8 -m open_r1.grpo --config config_tpcoloc.yaml

  • Set vllm_colocation in the config to the sharding you would like (a small sketch of the resulting groups follows this list).
  • E.g., if vllm_colocation=1, the model is not sharded; each GPU holds a full copy of the model.
  • If vllm_colocation=2, the model is sharded across pairs of GPUs, giving TP groups [0,1], [2,3], [4,5], [6,7].
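
As a sketch of the grouping rule (plain Python, not code from the PR), ranks are grouped into consecutive blocks of size vllm_colocation:

world_size = 8
vllm_colocation = 2  # TP group size

for rank in range(world_size):
    group_index = rank // vllm_colocation
    tp_group = list(range(group_index * vllm_colocation, (group_index + 1) * vllm_colocation))
    print(rank, tp_group)  # 0 -> [0, 1], 1 -> [0, 1], 2 -> [2, 3], ...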
Click to view config.yaml
# Model arguments
model_name_or_path: Qwen/Qwen2.5-Math-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
dataset_name: DigitalLearningGmbH/MATH-lighteval
dataset_config: default
dataset_prompt_column: problem
system_prompt: "You are a helpful AI Assistant, designed to provided well-reasoned and detailed responses. You FIRST think about the reasoning process as an internal monologue and then provide the user with the answer. The reasoning process MUST BE enclosed within <think> and </think> tags."

# GRPO trainer config
bf16: true
use_vllm: true
vllm_colocation: 2
vllm_gpu_memory_utilization: 0.3
vllm_max_model_len: 2048
do_eval: false
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 3.0e-06
log_completions: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: 50
num_generations: 8
num_train_epochs: 1
overwrite_output_dir: true
# per_device_eval_batch_size: 16
per_device_train_batch_size: 16
push_to_hub: false
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: steps
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
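
Note that because vLLM and the trainer share each GPU in this setup, vllm_gpu_memory_utilization has to leave headroom for the training job; the 0.3 above caps vLLM's weights and KV cache at roughly 30% of device memory.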

Sanity check

  • GRPO results for Qwen/Qwen2.5-Math-1.5B on the DigitalLearningGmbH/MATH-lighteval dataset (config above), using both plain TRL (w/ vLLM server) and colocated TRL (w/ TP=1, TP=2, and TP=4); the rewards are identical.

[Figure: reward curves for vLLM-server TRL vs. colocated TRL (TP=1/2/4)]

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

New version of #3162 (incorporated @qgallouedec 's comments)

CC @fabianlim

@toslali-ibm toslali-ibm changed the title from "Tpnosleep" to "Co-Locating vLLM w/ training for higher throughput and GPU utilization" on Apr 30, 2025
@toslali-ibm toslali-ibm marked this pull request as ready for review April 30, 2025 18:19
@qgallouedec
Member

@toslali-ibm I've updated your PR by changing the logic a bit, so I'll let you have a look, test it out, and tell me what you think.

@toslali-ibm
Contributor Author

@toslali-ibm I've updated your PR by changing the logic a bit, so I'll let you have a look, test it out, and tell me what you think.

I think it looks good. Let me run a sanity-check experiment.

# gather the prompts from every rank in this TP group so that all shards see the same batch
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
prompts_text = [p for sublist in gathered_prompts for p in sublist]

# every rank of the TP group calls generate() on the identical, concatenated prompt list
all_outputs = self.llm.generate(prompts_text, sampling_params=sampling_params, use_tqdm=False)
Member

In a multi-GPU setup (2 GPUs) with TP=2, suppose each rank is given a separate subset of prompts, e.g., rank 0 gets ["a", "b"] and rank 1 gets ["c", "d"]. Does each rank then independently call:

llm.generate(["a", "b", "c", "d"])

It seems like a duplicated call; is it coordinated such that each rank only processes its subset of prompts? In other words, if the full prompt list is passed on each rank, does vLLM handle this duplication internally to avoid redundant work?

Contributor Author

@toslali-ibm toslali-ibm May 1, 2025

Yes. When using TP along with external_launcher, we need to make sure that all participating shards receive the same prompts -- and vLLM handles the deduplication internally.

So if TP = 2 and GPUs = 2, both workers get ["a", "b", "c", "d"].

If TP = 1 and GPUs = 2, the first worker gets ["a", "b"] and the second worker gets ["c", "d"].
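
A minimal sketch of that pattern (variable names are illustrative, not the PR's): each rank gathers prompts only within its TP group, so every shard of a group passes the identical concatenated list to generate(), while different TP groups work on different prompts.

import os
import torch.distributed as dist

# assumes the process group was already initialized by accelerate/torchrun
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
tp_size = 2  # vllm_colocation

# every rank must call new_group() for every group, but only keeps its own
tp_group = None
for start in range(0, world_size, tp_size):
    ranks = list(range(start, start + tp_size))
    group = dist.new_group(ranks)
    if rank in ranks:
        tp_group = group

local_prompts = ["a", "b"] if rank % tp_size == 0 else ["c", "d"]
gathered = [None] * tp_size
dist.all_gather_object(gathered, local_prompts, group=tp_group)
tp_prompts = [p for sub in gathered for p in sub]  # identical on every rank of the group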

@toslali-ibm
Contributor Author

I am getting an error from the current version of the code

  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/workspace/open-r1/src/open_r1/grpo.py", line 179, in <module>
    main(script_args, training_args, model_args)
  File "/workspace/open-r1/src/open_r1/grpo.py", line 133, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 2238, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 2553, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 3724, in training_step
    inputs = self._prepare_inputs(inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/trl/trl/extras/profiling.py", line 87, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/trl/trl/trainer/grpo_trainer.py", line 991, in _prepare_inputs
    accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/trl/trl/trainer/grpo_trainer.py", line 1094, in _generate_and_score_completions
    completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]

@qgallouedec
Member

Any idea why?

@toslali-ibm
Contributor Author

Any idea why?

There was a mismatch between the config and the trainer (colocate vs. colocation). I fixed that; now there is another error I am debugging:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/workspace/open-r1/src/open_r1/grpo.py", line 179, in <module>
    main(script_args, training_args, model_args)
  File "/workspace/open-r1/src/open_r1/grpo.py", line 133, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 2238, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 2553, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 3724, in training_step
    inputs = self._prepare_inputs(inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/trl/trl/extras/profiling.py", line 87, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/trl/trl/trainer/grpo_trainer.py", line 991, in _prepare_inputs
    accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/trl/trl/trainer/grpo_trainer.py", line 1024, in _generate_and_score_completions
    self._move_model_to_vllm()
  File "/workspace/trl/trl/extras/profiling.py", line 87, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/trl/trl/trainer/grpo_trainer.py", line 961, in _move_model_to_vllm
    llm_model = self.llm.llm_engine.model_executor.driver_worker.model
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 624, in __getattr__
    return getattr(self.worker, attr)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Worker' object has no attribute 'model'
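
The diff later in this PR (see _sync_fsdp_params_to_vllm below) points at the fix: in this vLLM version the colocated worker exposes the model through its model_runner rather than directly as a worker attribute. A one-line sketch, assuming the same attribute path as in that diff:

# colocate mode: reach the in-process model via the driver worker's model_runner
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model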

@toslali-ibm
Contributor Author

Okay... I think it is all fixed now. I was able to run a quick training. I am now running the sanity-check experiment for TP = 1, 2, 4 and will report the rewards.

@toslali-ibm
Contributor Author

toslali-ibm commented May 1, 2025

@qgallouedec , the sanity experiment looks good - please see the figure below.
Labels: Refactored (current version of the code) vs. vLLM coloc (before Quentin's refactor) vs. vLLM server

[Figure: W&B reward chart (5/1/2025) comparing the Refactored, vLLM coloc, and vLLM server runs]

@qgallouedec
Member

Nice! Trying to run on my side.

* self.args.gradient_accumulation_steps,
max_model_len=self.max_prompt_length + self.max_completion_length,
distributed_executor_backend="external_launcher",
# Feed identical seed for tp groups to ensure ...
Member

@toslali-ibm I wasn't sure how to motivate this, can you complete this comment?

Member

btw, is os.getenv("RANK", "0") the same as self.accelerator.process_index? If so, I'd use the latter.

Member

@qgallouedec qgallouedec left a comment

Some changes not related to this PR, but that I made while merging main into your branch.

Comment on lines -512 to +527
# redirect the model.module forward to the model forward to ensure pre-forward hooks are called
self._forward_redirection = _ForwardRedirection()
if self.use_liger_loss:
if not is_liger_kernel_available():
raise ImportError(
"Liger is required to use `liger_loss` as the GRPO loss. Run `pip install liger-kernel`."
)

# Redirect the model.module forward to the model forward to ensure pre-forward hooks are called
self._forward_redirection = _ForwardRedirection()

Member

not related to this PR

Comment on lines +868 to +895
def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
    """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
    if visited is None:
        visited = set()

    for child_name, child_module in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        self._sync_fsdp_params_to_vllm(
            child_module, prefix=child_prefix, visited=visited
        )  # recurse into the child

    if isinstance(module, FSDP):
        with FSDP.summon_full_params(module, recurse=False, writeback=False):
            for param_name, param in module.named_parameters():
                full_name = f"{prefix}.{param_name}" if prefix else param_name
                for extra in ("_fsdp_wrapped_module.", "_checkpoint_wrapped_module."):
                    full_name = full_name.replace(extra, "")

                if full_name in visited:
                    continue  # skip FSDP subtrees already traversed
                visited.add(full_name)

                if self.vllm_mode == "server" and self.accelerator.is_main_process:
                    self.vllm_client.update_named_param(full_name, param.data)
                elif self.vllm_mode == "colocate":
                    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                    llm_model.load_weights([(full_name, param.data)])

Member

not related to this PR

Comment on lines -876 to +921
# With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
# adapters in a sharded manner is not supported.
# With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as
# merging adapters in a sharded manner is not supported.
# TODO: does this work with FSDP?
with gather_if_zero3(list(self.model.parameters())):
if self.is_fsdp_enabled:
self.model.merge_adapter()

# Update vLLM weights while parameters are gathered
if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext
# Update vLLM weights while parameters are gathered
# For PEFT with FSDP we need to use the memory efficient post-order traversal
self.model.merge_adapter()
post_order_fsdp_processing(self.model)
self.model.unmerge_adapter()
self._sync_fsdp_params_to_vllm(self.model)
else:
# DeepSpeed ZeRO-3 with PEFT (not using FSDP)
self.model.merge_adapter()
# DeepSpeed ZeRO-3 with PEFT
Member

not related to this PR

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@qgallouedec qgallouedec left a comment

Wonderful work @toslali-ibm
Can you please take a final look before I merge?

@qgallouedec qgallouedec changed the title from "Co-Locating vLLM w/ training for higher throughput and GPU utilization" to "🧑‍🤝‍🧑 Co-Locating vLLM w/ training for higher throughput and GPU utilization" on May 1, 2025
@toslali-ibm
Contributor Author

toslali-ibm commented May 1, 2025

Wonderful work @toslali-ibm Can you please take a final look before I merge?

Everything looks great; my sanity experiment ran successfully on the latest version. Thanks so much for the solid help, @qgallouedec ! :)

@qgallouedec qgallouedec merged commit 18596cf into huggingface:main May 1, 2025
@mayanks43

Very cool improvement!

Successfully merging this pull request may close these issues.

Enable External Launcher Support for vLLM in TRL for Efficient GRPO Training