Feature request
In GRPO training it would be useful to be able to split the generations into smaller batches for the gradient calculation, similar to how `gradient_accumulation_steps` splits a batch into multiple gradient calculations.
I imagine the config working something like this:
per_device_train_batch_size = 4
num_generations = 8
gradient_accumulation_steps = 4
with the condition that `per_device_train_batch_size * gradient_accumulation_steps` is a multiple of `num_generations`.
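For concreteness, here is a sketch of what that configuration might look like with TRL's `GRPOConfig` (assuming defaults for the remaining arguments); the values mirror the example above, and the divisibility check in the snippet is the proposed relaxed condition, not TRL's current behavior:

```python
from trl import GRPOConfig

# Proposed setup: the effective batch per device (4 * 4 = 16) is a
# multiple of num_generations (8), even though per_device_train_batch_size
# (4) alone is smaller than num_generations.
config = GRPOConfig(
    per_device_train_batch_size=4,
    num_generations=8,
    gradient_accumulation_steps=4,
)

# Proposed relaxed condition (currently the per-device batch size itself
# must be a multiple of num_generations):
assert (config.per_device_train_batch_size
        * config.gradient_accumulation_steps) % config.num_generations == 0
```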
Motivation
In the GRPO algorithm, the loss calculation (ignoring the KL term) is an estimate of an expectation under the current model's distribution. This estimate has very high variance when the sample size (the number of generations) is small, giving a poor estimate of the expectation and making training less stable.
Currently `per_device_train_batch_size` must be a multiple of `num_generations`, which can severely limit how large you can make it before hitting OOM, particularly in resource-constrained environments with long context windows. This seems like an unnecessary restriction, since nothing in the algorithm prevents us from splitting the gradient calculation for a generation batch into multiple smaller batches (see the sketch below).
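To make this concrete, here is a minimal, framework-agnostic sketch (not TRL's actual trainer code) of splitting one generation batch into micro-batches whose scaled gradients are accumulated before a single optimizer step. It assumes the group-relative advantages have already been computed over the full set of generations, and `grpo_loss` is a hypothetical stand-in for the mean-reduced per-sample GRPO objective:

```python
import torch

def grpo_step(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    generations: list,
    micro_batch_size: int,
    grpo_loss,
):
    """One optimizer step over a full generation batch, computed in chunks.

    `generations` holds all samples for a prompt group (with advantages
    already attached); `grpo_loss` is a hypothetical stand-in for the
    mean-reduced per-sample GRPO objective.
    """
    assert len(generations) % micro_batch_size == 0
    num_micro = len(generations) // micro_batch_size

    optimizer.zero_grad()
    for i in range(num_micro):
        chunk = generations[i * micro_batch_size:(i + 1) * micro_batch_size]
        # Dividing by num_micro makes the accumulated gradient equal the
        # gradient of the mean loss over the whole generation batch.
        loss = grpo_loss(model, chunk) / num_micro
        loss.backward()
    optimizer.step()
```

Because gradients are linear, the accumulated sum equals the gradient of the mean loss over the full generation batch, so the estimator and its variance are unchanged; only peak activation memory is reduced.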
Your contribution
I don't think I would be able to create the PR myself, unfortunately.