💔 [GRPO] Decouple gradient accumulation from the number of minibatches generated #3388
Conversation
LGTM with some nits!
As discussed offline, it would be interesting to check that the …
trl/trainer/grpo_config.py (outdated)

```diff
@@ -61,6 +61,8 @@ class GRPOConfig(TrainingArguments):
         with vLLM generation.
     shuffle_dataset (`bool`, *optional*, defaults to `True`):
         Whether to shuffle the training dataset.
+    num_mini_batches (`int`, *optional*, defaults to `None`):
```
I would find it clearer to call this parameter `generate_every` or something similar. I can't figure out what a mini-batch is in this case. Named like that, and with the doc ("split"), I expect there to be a division by that number later, like `mini_batch_size = batch_size // num_mini_batches`.
For context, the mini-batch terminology used here is the same as in e.g. DAPO.
In other words, @edbeeching is partitioning the effective / generation batch into `num_mini_batches` subsets to compute the loss. It takes the model slightly off-policy, but allows one to generate a large batch and optimise smaller chunks without going OOM.
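To make the partitioning concrete, here is a minimal illustrative sketch (not TRL's actual implementation; the helper name and tensor shapes are made up for the example):

```python
import torch

def iter_mini_batches(generation_batch: torch.Tensor, num_mini_batches: int):
    # Partition the generation batch into `num_mini_batches` equal slices;
    # the idea is that one optimisation step is taken per slice.
    mini_batch_size = generation_batch.shape[0] // num_mini_batches
    for i in range(num_mini_batches):
        yield generation_batch[i * mini_batch_size : (i + 1) * mini_batch_size]

# e.g. DAPO's generation batch of 512 prompts x 16 completions = 8192 samples
batch = torch.randn(8192, 4)
for mini_batch in iter_mini_batches(batch, num_mini_batches=16):
    assert mini_batch.shape[0] == 512  # 16 slices of 512 samples each
```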
I see. Although this is technically correct and corresponds to the paper, I have the impression that it's a bit misleading: as a user, I imagine that by increasing `num_mini_batches`, I'm consequently decreasing the size of each mini-batch and increasing the number of gradient updates per rollout step.
The sentence in the paper is also constructed like this: "the mini-batch size is set to 512, i.e., 16 gradient updates for each rollout step."
But that's not really what happens here. If you don't find it misleading, though, we can leave it like that and just document it better.
To clarify the DAPO example above, they have 512 unique prompts per batch and then sample 16 completions per prompt to obtain a generation batch size of 512 x 16 = 8,192.
They then partition the generation batch into mini-batches of size 512, hence the 16 gradient updates per rollout step.
Now suppose I have 8 GPUs and `per_device_train_batch_size=16`; then I believe the corresponding setting in `trl` would be:
- `gradient_accumulation_steps = 64` (to get 8,192 = 8 x 16 x 64)
- `num_mini_batches = 16` (i.e. 16 gradient updates with slices of 512 samples)

If @edbeeching agrees with my logic, then indeed it would make sense to document (a) how the number of unique prompts is computed and (b) how mini-batching and gradient accumulation can be tuned to obtain large generation batch sizes.
@lewtun in general I agree, but I am not sure `gradient_accumulation_steps = 64` holds in the case of the DAPO example detailed above.
To summarize how I see it, for each generation step:
- There will be 16 optimization steps (gradient updates) in total.
- Each step is with 512 prompt-completion pairs.
- The total generation batch size is 512 * 16 = 8192.
- `per_device_train_batch_size` is unknown, but let's assume it is 16 and `num_gpus = 8`.
- Which would mean `per_device_train_batch_size=16` and `gradient_accumulation_steps=4`, as 16 * 8 * 4 = 512.

In this case the `num_mini_batches` would be `generation_batch_size / (num_gpus * per_device_train_batch_size) = 64`.
This is also equal to `num_optimization_steps * grad_acc_steps`.
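A quick arithmetic check of these numbers (plain Python; the variable names are just for illustration):

```python
# Per-gradient-update batch, using the numbers from the comment above.
num_gpus = 8
per_device_train_batch_size = 16
gradient_accumulation_steps = 4
num_optimization_steps = 16

step_batch_size = num_gpus * per_device_train_batch_size * gradient_accumulation_steps
assert step_batch_size == 512  # 512 prompt-completion pairs per gradient update

generation_batch_size = step_batch_size * num_optimization_steps
assert generation_batch_size == 8192  # 512 * 16

num_mini_batches = generation_batch_size // (num_gpus * per_device_train_batch_size)
assert num_mini_batches == 64  # == num_optimization_steps * gradient_accumulation_steps
```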
Ah yes, indeed I made a mistake in my above reasoning; I agree with @edbeeching.
We discussed offline and also agree it is confusing to have `num_mini_batches` affect the generation batch size.
LGTM! Another cool optimization!
For the record, it may be more intuitive for the user to set `steps_per_generation` directly (rather than setting `generation_batch_size` and then calculating `steps_per_generation`), but you're the one who's been playing with the code over the last few days, so you know better what's more intuitive. I'm happy with both, so feel free to merge either.
I personally prefer …
I think both are useful, so I will make it so the user can set either option, but not both.
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This PR decouples the size of the effective batch from the gradient accumulation steps with a new argument `num_mini_batches`. By default, `num_mini_batches = gradient_accumulation_steps`.
A few caveats:
- `accumulated_local_batch` should be shuffled before adding it to the buffer. See PR 🎲 [GRPO] Shuffle mini batches #3391.
- When `num_iterations > 1`, the buffer indices should be permuted / reshuffled in order to ensure that samples are taken randomly.
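A hedged usage sketch of how this might look from the user side, assuming the merged API exposes `steps_per_generation` / `generation_batch_size` as discussed above (exact argument names and defaults may differ from the final code):

```python
from trl import GRPOConfig

# Hypothetical configuration: generate one large batch per rollout, then
# optimise it in smaller accumulated chunks without going OOM.
config = GRPOConfig(
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    steps_per_generation=16,  # or generation_batch_size=...; set one, not both
)
```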