
GRPO split generations into multiple training batches #3017

@JamesBowerXanda


Feature request

In GRPO training, it would be useful to be able to split the generations into smaller batches for the gradient calculation, similar to how gradient_accumulation_steps splits a batch into multiple gradient calculations.

I imagine the config would work something like this:

```
per_device_train_batch_size = 4
num_generations = 8
gradient_accumulation_steps = 4
```

with the condition that per_device_train_batch_size * gradient_accumulation_steps is a multiple of num_generations.
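As a rough sketch, this is how the proposed setup might look with trl's GRPOConfig. The parameter names exist today (GRPOConfig inherits per_device_train_batch_size and gradient_accumulation_steps from transformers.TrainingArguments); the relaxed divisibility condition is the proposal, not current trl behavior:

```python
# Sketch of the proposed configuration, assuming trl's GRPOConfig.
# The relaxed check below is the proposal, not current trl behavior.
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="grpo-checkpoints",
    per_device_train_batch_size=4,   # micro-batch per gradient calculation
    num_generations=8,               # completions sampled per prompt
    gradient_accumulation_steps=4,   # micro-batches per optimizer step
)

# Proposed condition: the effective batch must cover whole generation
# groups, but a single micro-batch no longer has to.
effective_batch = config.per_device_train_batch_size * config.gradient_accumulation_steps
assert effective_batch % config.num_generations == 0
```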

Motivation

In the GRPO algorithm, the loss calculation (ignoring the KL term) is an estimate of an expectation under the current model's distribution. This estimate has very high variance when the sample size (the number of generations) is small, giving a poor approximation of the expectation and therefore making training less stable.
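To make the variance argument concrete, here is a schematic Monte Carlo view (generic notation, not the exact clipped GRPO objective): the per-prompt loss averages a per-generation term over G independent completions, so its variance falls off as 1/G.

```latex
% Schematic Monte Carlo view of the per-prompt loss (generic f, not the
% exact clipped GRPO objective): G generations estimate an expectation,
\mathcal{L} \;\approx\; \frac{1}{G}\sum_{i=1}^{G} f(o_i),
\qquad o_i \sim \pi_\theta(\cdot \mid q),
% and, since the o_i are sampled independently, the estimator's
% variance shrinks linearly in G:
\operatorname{Var}\!\left[\frac{1}{G}\sum_{i=1}^{G} f(o_i)\right]
  \;=\; \frac{\operatorname{Var}_{o \sim \pi_\theta}\!\left[f(o)\right]}{G}.
```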

Currently per_device_train_batch_size must be a multiple of num_generations, which severely limits how large num_generations can be before hitting OOM, particularly in resource-constrained environments with long context windows. This seems like an unnecessary restriction, since nothing in the algorithm prevents splitting the gradient calculation for a generation batch into multiple smaller batches.
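As a minimal sketch of why the split is safe (plain PyTorch with a hypothetical stand-in loss, not trl's actual GRPO loss): because differentiation is linear, accumulating gradients over micro-batches of a sum-style loss reproduces the full-group gradient exactly.

```python
# Minimal sketch (hypothetical stand-in loss): accumulating per-micro-batch
# gradients reproduces the full-group gradient, by linearity of the gradient.
import torch

torch.manual_seed(0)
theta = torch.randn(3, requires_grad=True)
group = torch.randn(8, 3)  # stand-in for one group of 8 generations

def loss_fn(batch, theta):
    # Hypothetical per-generation loss terms, summed over the batch.
    return ((batch @ theta) ** 2).sum()

# Full-group gradient in one pass.
loss_fn(group, theta).backward()
full_grad = theta.grad.clone()
theta.grad = None

# Same group split into micro-batches of 4, gradients accumulated.
for micro in group.split(4):
    loss_fn(micro, theta).backward()  # .backward() accumulates into theta.grad

assert torch.allclose(full_grad, theta.grad)
```

The only bookkeeping needed is the usual gradient-accumulation scaling when the loss averages over generations rather than summing them.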

Your contribution

Unfortunately, I don't think I would be able to create the PR myself.
