Add vLLM transformers backend to online methods #3773

Merged
merged 11 commits into huggingface:main on Jul 30, 2025

Conversation

merveenoyan
Contributor

@merveenoyan merveenoyan commented Jul 25, 2025

Add the vLLM transformers backend to the online methods GRPO and Online DPO.
This has two limitations:

  • The transformers backend for vLLM VLMs is only on vLLM main and requires a release; in the meantime, install vLLM from source with `VLLM_USE_PRECOMPILED=1 uv pip install -e .`
  • Server + eager works, but for some reason colocate doesn't. I could swear it was working before I merged some changes from main. Edit: colocate works on a single GPU; even though I merged the NCCL-related changes, there still seems to be an issue with the multi-GPU setup.

I will look into these issues. You can test with

CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 python3 examples/scripts/grpo_vlm.py \
    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
    --output_dir grpo-qwen25 \
    --learning_rate 1e-5 \
    --torch_dtype bfloat16 \
    --max_prompt_length 512 \
    --max_completion_length 512 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 2 \
    --num_generations 2 \
    --bf16 True \
    --lora_target_modules q_proj v_proj \
    --log_completions \
    --use_vllm \
    --vllm_mode colocate \
    --vllm_model_impl transformers

while serving with

CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=1 trl vllm-serve --model Qwen/Qwen2.5-VL-3B-Instruct --tensor-parallel-size 1 --port 8000 --enforce_eager --vllm_model_impl transformers
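For reference, here is a minimal Python sketch of the colocate setup above using `GRPOConfig`/`GRPOTrainer` directly. It is not the PR's example script: the dataset and reward function below are placeholders (the actual `grpo_vlm.py` uses an image-text dataset), and `vllm_model_impl` is the option added in this PR.

```python
# Minimal sketch mirroring the CLI flags above; dataset and reward are
# placeholders for illustration only.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 100 characters.
    return [-abs(100 - len(c)) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(
    output_dir="grpo-qwen25",
    learning_rate=1e-5,
    bf16=True,
    max_prompt_length=512,
    max_completion_length=512,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_generations=2,
    use_vllm=True,
    vllm_mode="colocate",            # or "server" when generating through `trl vllm-serve`
    vllm_model_impl="transformers",  # the option added in this PR
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```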

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@merveenoyan
Contributor Author

@qgallouedec @kashif this works; apparently my compiled dev build of vLLM was causing all the issues, and updating to v1 solved it!

@@ -393,6 +393,14 @@ class GRPOConfig(TrainingArguments):
"contention with training."
},
)
vllm_model_impl: str = field(
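The hunk is truncated here. For readers, a plausible reconstruction of the new field, following the `field(default=..., metadata={"help": ...})` pattern visible in the surrounding lines; the default value and help text are assumptions, not the merged code.

```python
# Sketch only: default and help text are assumptions, shown on a stand-in
# dataclass rather than the real GRPOConfig.
from dataclasses import dataclass, field

@dataclass
class GRPOConfigSketch:
    vllm_model_impl: str = field(
        default="vllm",
        metadata={
            "help": "Model implementation to use for vLLM generation: 'transformers' to use the "
            "transformers backend, or 'vllm' to use vLLM's native model implementation."
        },
    )
```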
Collaborator

Can you also add these docs to the docstrings further up in the file?

@@ -292,6 +292,14 @@ class ScriptArguments:
"'trace'."
},
)
vllm_model_impl: str = field(
Member

It's missing in the docstring above

@@ -164,6 +164,14 @@ class may differ from those in [`~transformers.TrainingArguments`].
"(`pip install vllm`)."
},
)
vllm_model_impl: str = field(
Member

same


## vLLM with Transformers Backend

vLLM now supports a transformers backend for model implementations. Passing `transformers` as `vllm_model_impl` in the configuration or through the argument parser selects the transformers backend. See the example below.
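An example along these lines (a minimal sketch; the exact snippet in the docs may differ, and the surrounding GRPO arguments are omitted):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="my-grpo-run",        # plus your usual GRPO arguments
    use_vllm=True,
    vllm_model_impl="transformers",  # use the transformers backend instead of vLLM's native one
)
```

When generating through a separate server, the same option can be passed to `trl vllm-serve` via `--vllm_model_impl transformers`.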
Member

Does it support VLMs? Any limitations?
Additionally, maybe link the blog post: https://blog.vllm.ai/2025/04/11/transformers-backend.html
(in case these points can be added in 1-2 sentences)

@merveenoyan
Contributor Author

@kashif can you merge 🙏🏻

@kashif kashif merged commit 90c7876 into huggingface:main Jul 30, 2025
10 checks passed
4 participants