☝️ [GRPO] Generate once per effective batch #3283
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
generate_every = self.args.gradient_accumulation_steps * self.num_iterations
if self._step % generate_every == 0 or self._buffered_inputs is None:
    # self._buffered_inputs=None can occur when resuming from a checkpoint
    accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
```
I think using `accumulated_local_batch` directly might lead to OOM when calling `_get_per_token_logps`; maybe splitting `_generate_and_score_completions` into two parts would be better.
Good point. I haven't encountered such an OOM yet. In fact, here we do the forward pass with no grad, so maybe that prevents the problem? I'd have to test it with a large number of gradient accumulation steps to confirm.
I tested it and it OOMs when `max_completion_length` is large.
Using the same setting, you don't get OOM with main?
I trained a 1.5B model on two GPUs with GAS=8, BS=16 and `max_completion_length=2048`. Everything is fine with `main`.
> I think using `accumulated_local_batch` directly might lead to OOM when calling `_get_per_token_logps`; maybe splitting `_generate_and_score_completions` into two parts would be better.

In the first part, use vLLM to generate all the `completion_ids` and calculate the `advantages` of all these completions. In the second part, compute `old_per_token_logps` and `ref_per_token_logps` with the original `per_device_batchsize` (perhaps `old_per_token_logps` can be obtained directly in the vLLM generation phase).

One benefit of this is that instead of requiring `(per_device_batchsize * num_processes) % num_generations == 0`, we would only need `(per_device_batchsize * num_processes * gradient_accumulation_steps) % num_generations == 0`. This solves #3017 and #3288 without introducing additional parameters in the config; just adjust `per_device_batchsize` and `gradient_accumulation_steps` to adapt to various scenarios.
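As a rough illustration of the proposed split, here is a minimal sketch (the helper name and signature are hypothetical; TRL's actual method is `_get_per_token_logps` and may differ): generation and advantage computation would run once on the whole accumulated batch, while the no-grad log-prob passes are chunked to the original per-device batch size.

```python
import torch

def chunked_per_token_logps(model, input_ids, attention_mask, logits_to_keep, chunk_size):
    # Score the accumulated batch in per-device-sized chunks so the no-grad
    # forward pass never materialises logits for the whole effective batch.
    all_logps = []
    with torch.no_grad():
        for start in range(0, input_ids.size(0), chunk_size):
            ids = input_ids[start : start + chunk_size]
            mask = attention_mask[start : start + chunk_size]
            logits = model(input_ids=ids, attention_mask=mask).logits  # (b, L, V)
            # Logits at position t predict token t + 1, so the completion tokens
            # (the last `logits_to_keep` tokens of the sequence) are predicted by:
            logits = logits[:, -logits_to_keep - 1 : -1, :]
            targets = ids[:, -logits_to_keep:]
            logps = torch.log_softmax(logits.float(), dim=-1)
            all_logps.append(torch.gather(logps, 2, targets.unsqueeze(-1)).squeeze(-1))
    return torch.cat(all_logps, dim=0)  # (accumulated_batch_size, logits_to_keep)
```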
A good model to test would be any of the DeepSeek distilled ones: they easily produce 32k tokens per prompt, so you'll know quickly whether the improvements to generation/scoring still OOM or not.
That's a very good point @I-l-l-I, I hadn't even realised that with this PR we're relaxing the requirement `(per_device_batchsize * num_processes) % num_generations == 0`.
This should avoid OOM: eedaab5
Now the memory usage is the same as with main:
```python
import random

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train[:200]")


def reward_random(completions, **kwargs):
    # Dummy reward: the memory pressure comes from generation and scoring, not the reward
    return [random.random() for _ in range(len(completions))]


trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-1.5B",
    reward_funcs=reward_random,
    train_dataset=dataset,
    args=GRPOConfig(
        per_device_train_batch_size=16,
        gradient_accumulation_steps=8,
        max_completion_length=2048,
        use_vllm=True,
        gradient_checkpointing=True,
        bf16=True,
    ),
)
trainer.train()
```
Launched on 2 GPUs with:

```sh
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --num_processes 2 3283.py
```
Side note: with this script, for example, training time is reduced from 30 min to 10 min. 🤯
```diff
@@ -631,7 +695,7 @@ def _get_train_sampler(self) -> Sampler:
     data_source=self.train_dataset,
     mini_repeat_count=self.num_generations,
     batch_size=effective_batch_size // self.num_generations,
-    repeat_count=self.num_iterations,
+    repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
```
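To make the multiplied `repeat_count` concrete, here is a simplified, self-contained sketch of how such a repeat sampler behaves (an approximation for illustration, not the trainer's exact sampler implementation):

```python
import random
from itertools import islice

def repeat_sampler(n, mini_repeat_count, batch_size, repeat_count, seed=0):
    # Shuffle the dataset, cut it into chunks of `batch_size` unique prompts,
    # emit each prompt `mini_repeat_count` times (one per generation), and
    # replay each chunk `repeat_count` times so the same prompts are seen
    # once per iteration and per gradient accumulation step.
    order = list(range(n))
    random.Random(seed).shuffle(order)
    for start in range(0, n - batch_size + 1, batch_size):
        chunk = order[start : start + batch_size]
        for _ in range(repeat_count):
            for idx in chunk:
                yield from [idx] * mini_repeat_count

# With repeat_count = num_iterations * gradient_accumulation_steps, generation can
# run once per effective batch and its outputs be reused for every micro-batch.
print(list(islice(repeat_sampler(n=6, mini_repeat_count=2, batch_size=3, repeat_count=2), 24)))
```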
Probably not for this PR, but would it be worth having a `per_device_mega_batch_size` and completely decoupling the gradient accumulation from the size of the big generation batch? Some batch shuffling would be required to ensure that the minibatches vary from one iteration to the next, but I don't think it would be a huge change.
@edbeeching Do you mean something like this, in simple terms?
```python
mega_batch_size = 524288  # ~0.5M, in number of tokens if T = 1024
B = 64  # micro batch size
T = 1024
assert mega_batch_size % (B * T * ddp_world_size) == 0
grad_accum_steps = mega_batch_size // (B * T * ddp_world_size)
```
where we later do:
```python
for micro_step in range(grad_accum_steps):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    logits, loss = model(x, y)
    loss = loss / grad_accum_steps
    loss_accum += loss.detach()
    loss.backward()
```
I think what Ed is suggesting is to partially align the `GRPOTrainer` with our other RL trainers, where one defines a large effective batch size of `N_prompts x N_generations` and then performs K steps of mini-batch optimisation (i.e. the `num_minibatches` arg in the `PPOTrainer`). This is also what other RL frameworks like `verl` do.

In that context, gradient accumulation is not involved in defining the large effective batch size, but I realise it is not easy to force the `transformers.Trainer` logic to do mini-batch optimisation, and it should be done in a separate PR (if at all).
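For reference, a rough sketch of the mini-batch scheme described above (illustrative only: `compute_loss_fn` is a placeholder and this is not how `GRPOTrainer` currently works):

```python
import torch

def minibatch_updates(rollout, model, optimizer, compute_loss_fn, num_minibatches, num_epochs=1):
    # `rollout` holds one large effective batch of N_prompts x N_generations samples.
    batch_size = rollout["input_ids"].size(0)
    assert batch_size % num_minibatches == 0
    mb_size = batch_size // num_minibatches
    for _ in range(num_epochs):
        perm = torch.randperm(batch_size)  # reshuffle so minibatches differ across epochs
        for start in range(0, batch_size, mb_size):
            idx = perm[start : start + mb_size]
            minibatch = {k: v[idx] for k, v in rollout.items()}
            loss = compute_loss_fn(model, minibatch)  # placeholder loss function
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```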
```diff
@@ -836,13 +932,14 @@ def _generate_and_score_completions(
     attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

     logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+    batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
```
Generally for training (independent of this PR), `per_device_eval_batch_size` is `2 * per_device_train_batch_size`. As we are in `no_grad` mode here, you can probably set the batch size to `2 * self.args.per_device_train_batch_size`, as there will be no memory used for the storage of activations etc.

It would be worth validating in a scenario with significant memory pressure, so do not worry about it if you want to get the PR merged.
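In code, the suggestion might look roughly like this (a sketch under the stated assumption that the scoring pass runs under `torch.no_grad()`; whether doubling is safe in practice still needs the memory-pressure validation mentioned above):

```python
def scoring_batch_size(args, mode: str) -> int:
    # Hypothetical helper: the per-token log-prob pass stores no activations for
    # backward, so a larger chunk (doubled here for illustration) can likely be
    # used than the training micro-batch.
    if mode == "train":
        return 2 * args.per_device_train_batch_size
    return args.per_device_eval_batch_size
```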
Very nice PR with clearly documented logic; this was a pleasure to read! Looking forward to these speed-ups!
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
From the TRL logging, I noticed that sometimes the generations are not correctly paired with their prompts. It could be a logging issue in multi-GPU training, or the order could have changed during the single call per batch to the vLLM server. I will try to find more traces and post them here.
What does this PR do?
Summary
This PR optimizes the generation process in GRPO when using gradient accumulation by ensuring that generation happens only once per effective batch, rather than once per micro-batch. This significantly improves performance—particularly with large gradient accumulation values—by leveraging the efficiency of vLLM with large batch sizes.
Context
Currently, when using GRPO with gradient accumulation, generation is triggered at every micro-batch. For instance, with a gradient accumulation factor of 16, generation is launched 16 times before a single network update. This is inefficient, especially given that vLLM performs better with larger batches.
What this PR changes
This PR modifies the logic so that generation occurs only once per effective batch (i.e., after all gradient accumulation steps). To enable this, the train sampler now repeats each batch GAS times (where GAS is the number of gradient accumulation steps), and the generated completions are buffered and consumed one slice per micro-batch.
Why it matters
This leads to a speedup in training, particularly with higher accumulation values, by making better use of vLLM’s batched generation capabilities.
Review Notes
This change involves more complex indexing logic, which may make the diff harder to review. However, the code has been heavily commented for clarity and maintainability.
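For reviewers, the gist of that indexing logic can be summarised by the following sketch (hypothetical helper `split_into_micro_batches`; the actual code in the diff is more involved):

```python
def prepare_inputs(self, accumulated_local_batch):
    # Regenerate only once per effective batch (all accumulation steps x iterations).
    generate_every = self.args.gradient_accumulation_steps * self.num_iterations
    if self._step % generate_every == 0 or self._buffered_inputs is None:
        processed = self._generate_and_score_completions(accumulated_local_batch)
        # Slice the big processed batch into one piece per micro-batch.
        self._buffered_inputs = split_into_micro_batches(
            processed, self.args.gradient_accumulation_steps
        )
    inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
    self._step += 1
    return inputs
```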
Benchmark
Training time for 9 configs:
Note that for GAS=1, no speedup is expected.
Code used:
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.