fix grpo generation_kwargs #3634
Conversation
Signed-off-by: ahatamizadeh <ahatamizadeh@nvidia.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Pull Request Overview

This PR prevents a `NoneType` error by only updating the default `generation_kwargs` when user overrides are provided.

- Adds a guard to skip the update when `self.args.generation_kwargs` is `None`.
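The guard described above can be sketched as follows. This is a minimal, self-contained stand-in for the trainer's merge of default and user-supplied sampling parameters; the helper name and default values are illustrative, not the actual TRL code:

```python
# Build default sampling parameters, then merge user overrides only when
# they were actually provided. Calling dict.update(None) raises TypeError,
# which is the failure this guard avoids.
def merge_generation_kwargs(user_kwargs):
    generation_kwargs = {
        "max_tokens": 256,   # illustrative defaults, not TRL's actual values
        "temperature": 1.0,
    }
    if user_kwargs is not None:  # the guard added by this PR
        generation_kwargs.update(user_kwargs)
    return generation_kwargs
```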
Comments suppressed due to low confidence (2)

trl/trainer/grpo_trainer.py:1127

- Consider adding a test case to verify that omitting `generation_kwargs` (i.e., leaving it as `None`) does not raise an error and correctly falls back to default values.

`if self.args.generation_kwargs is not None:`
trl/trainer/grpo_trainer.py:1117

- [nitpick] It might help to update the method docstring to note that `generation_kwargs` can be `None` and to describe the default behavior when no overrides are provided.

`generation_kwargs = {`
Hi @ahatamiz Ali jan, thanks for your contribution!
I've tested your PR. For future reference, I tested with:
```python
import os

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Set required environment variables for vLLM
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"

dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="pr3634",
    use_vllm=True,
    vllm_mode="colocate",
    gradient_checkpointing=True,
    num_generations=4,
    per_device_train_batch_size=4,
    vllm_gpu_memory_utilization=0.10,
    vllm_tensor_parallel_size=1,
    max_prompt_length=512,
    max_completion_length=1024,
    max_steps=2,
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```
then ran with:

```shell
accelerate launch --config_file examples/accelerate_configs/single_gpu.yaml PR_3634.py
```
Thank you @shirinyamani for the review and also this pointer. Sure, I will keep this in mind. Kind regards,
Signed-off-by: ahatamizadeh <ahatamizadeh@nvidia.com> Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
What does this PR do?

Currently, if the user does not pass `generation_kwargs`, they should not face any issues (i.e., the default values should apply). Yet a `NoneType` error is raised if you specify `vllm_mode="colocate"`. This PR simply fixes that!

Fixes #3633
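For context, the underlying failure mode is a plain `dict.update(None)` call; a minimal reproduction, independent of TRL:

```python
# Merging overrides without a None check reproduces the bug:
# dict.update(None) raises TypeError, because None is not iterable.
defaults = {"temperature": 1.0}
try:
    defaults.update(None)  # what happened before the guard was added
    raised = False
except TypeError:
    raised = True
```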
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
@qgallouedec
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.