
☝️ [GRPO] Generate once per effective batch #3283


Merged
12 commits merged into main on Apr 17, 2025

Conversation

@qgallouedec (Member) commented Apr 12, 2025

What does this PR do?

Summary

This PR optimizes the generation process in GRPO when using gradient accumulation by ensuring that generation happens only once per effective batch, rather than once per micro-batch. This significantly improves performance—particularly with large gradient accumulation values—by leveraging the efficiency of vLLM with large batch sizes.

Context

Currently, when using GRPO with gradient accumulation, generation is triggered at every micro-batch. For instance, with a gradient accumulation factor of 16, generation is launched 16 times before a single network update. This is inefficient, especially given that vLLM performs better with larger batches.

What this PR changes

This PR modifies the logic so that generation occurs only once per effective batch (i.e., after all gradient accumulation steps). To enable this:

  • The dataloader is adjusted to repeat the same batch GAS times (where GAS is the number of gradient accumulation steps).
  • Generation is performed only during the first micro-batch.
  • The actual sampled batch is only a slice of the loaded data (a sketch of the repeated-batch sampling appears below).
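A minimal sketch of the repeated-batch sampling idea (illustrative only: simplified names, no distributed or num_generations handling; this is not TRL's actual sampler):

import random
from torch.utils.data import Sampler


class RepeatBatchSampler(Sampler):
    """Yield each batch of indices `repeat_count` times in a row, so the trainer
    can generate once per effective batch and reuse the data for every micro-step."""

    def __init__(self, data_source, batch_size, repeat_count, seed=0):
        self.data_source = data_source
        self.batch_size = batch_size      # indices per effective batch
        self.repeat_count = repeat_count  # e.g. num_iterations * gradient_accumulation_steps
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)
        indices = list(range(len(self.data_source)))
        rng.shuffle(indices)
        for start in range(0, len(indices), self.batch_size):
            batch = indices[start:start + self.batch_size]
            for _ in range(self.repeat_count):  # same indices, repeated back to back
                yield from batch

    def __len__(self):
        return len(self.data_source) * self.repeat_count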

Why it matters

This leads to a speedup in training, particularly with higher accumulation values, by making better use of vLLM’s batched generation capabilities.

Note: This is a preparatory change. The full performance gains will be realized once the TRL vLLM server supports data parallelism (DP).

Review Notes

This change involves more complex indexing logic, which may make the diff harder to review. However, the code has been heavily commented for clarity and maintainability.
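For readers skimming the diff, here is a minimal, self-contained sketch of the buffering-and-slicing pattern (illustrative only; the helper names and plain-list data are assumptions, not the trainer's actual code):

def run_micro_steps(batches, generate_and_score, grad_accum_steps, num_iterations):
    """`batches` yields the same accumulated batch grad_accum_steps * num_iterations
    times in a row (see the sampler sketch above); generation happens only on the
    first occurrence, and each micro-step consumes one slice of the buffered data."""
    buffered = None
    for step, accumulated_batch in enumerate(batches):
        generate_every = grad_accum_steps * num_iterations
        if step % generate_every == 0 or buffered is None:
            scored = generate_and_score(accumulated_batch)  # one large vLLM call
            # split the scored effective batch into one slice per micro-step
            n = len(scored) // grad_accum_steps
            buffered = [scored[i * n:(i + 1) * n] for i in range(grad_accum_steps)]
        yield buffered[step % grad_accum_steps]  # micro-batch used for the loss/backward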

Benchmark

Training time for 9 configs:
Note that for GAS=1, no speedup is expected

[Chart: training time comparison for the 9 configurations]

Code used:

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
import argparse
import subprocess

def get_git_commit_hash():
    try:
        commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('utf-8').strip()[:7]
        return commit_hash
    except subprocess.CalledProcessError:
        return None

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GRPO Trainer")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--num_iterations", type=int, default=1)
    args = parser.parse_args()

    dataset = load_dataset("trl-lib/tldr", split="train[:2048]")

    # Dummy reward function: count the number of unique characters in the completions
    def reward_num_unique_chars(completions, **kwargs):
        return [len(set(c)) for c in completions]

    git_commit = get_git_commit_hash()
    run_name = f"GRPO_7B_{git_commit}_GAS_{args.gradient_accumulation_steps}_BS_{args.per_device_train_batch_size}_MU_{args.num_iterations}"

    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-7B",
        reward_funcs=reward_num_unique_chars,
        train_dataset=dataset,
        args=GRPOConfig(
            run_name=run_name,
            logging_steps=10,
            bf16=True,
            max_completion_length=128,
            gradient_checkpointing=True,
            use_vllm=True,
            num_iterations=args.num_iterations,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            per_device_train_batch_size=args.per_device_train_batch_size,
            vllm_server_timeout=360,
        ),
    )
    trainer.train()

Slurm script used to launch the runs:

#!/bin/bash
#SBATCH --output=/fsx/qgallouedec/logs/%x-%j.out
#SBATCH --error=/fsx/qgallouedec/logs/%x-%j.err
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:8
#SBATCH --time=10:00:00
#SBATCH --qos=normal


# Run vLLM server
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor_parallel_size 4 &

# Run training
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch \
     --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
     --num_processes 4 \
     sandbox/3283.py \
     --gradient_accumulation_steps 16 \
     --per_device_train_batch_size 8 \
     --num_iterations 4

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qgallouedec changed the title from "Generate once per optimization steps" to "Generate once per optimization step" on Apr 12, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

generate_every = self.args.gradient_accumulation_steps * self.num_iterations
if self._step % generate_every == 0 or self._buffered_inputs is None:
    # self._buffered_inputs=None can occur when resuming from a checkpoint
    accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
Contributor

I think using accumulated_local_batch directly might lead to OOM when calling _get_per_token_logps, maybe splitting _generate_and_score_completions into two parts would be better.

Member Author

Good point. I haven't encountered such an OOM yet. In fact, the forward pass here is done under no_grad, so maybe that prevents the problem? I'd have to test with a large gradient accumulation value to confirm.

Contributor

I tested it, and it OOMs when max_completion_length is large.

Member Author

Using the same setting, you don't get OOM with main?

Contributor @I-l-l-I commented Apr 13, 2025

I trained a 1.5B model on two GPUs with GAS=8, BS=16 and max_completion_length=2048. Everything is fine with main.

Contributor

I think using accumulated_local_batch directly might lead to OOM when calling _get_per_token_logps, maybe splitting _generate_and_score_completions into two parts would be better.

In the first part, use vLLM to generate all the completion_ids and calculate the advantages of all these completions. In the second part, compute old_per_token_logps and ref_per_token_logps with the original per_device_batchsize (perhaps old_per_token_logps can even be obtained directly from the vLLM generation phase).

One benefit of this is that instead of requiring (per_device_batchsize * num_processes) % num_generations == 0, we only need (per_device_batchsize * num_processes * gradient_accumulation_steps) % num_generations == 0. This solves #3017 and #3288 without introducing additional config parameters; you just adjust per_device_batchsize and gradient_accumulation_steps to adapt to various scenarios.
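As a concrete illustration of the relaxed constraint mentioned above (a standalone sketch with illustrative names, not the actual TRL check):

def check_divisibility(per_device_batch_size, num_processes,
                       gradient_accumulation_steps, num_generations):
    global_batch = per_device_batch_size * num_processes * gradient_accumulation_steps
    if global_batch % num_generations != 0:
        raise ValueError(
            f"global batch ({global_batch}) must be divisible by "
            f"num_generations ({num_generations})"
        )

# Example: per-device batch 3 on 2 GPUs fails the old per-micro-batch check for
# num_generations=8 (3 * 2 = 6), but passes once gradient_accumulation_steps=4
# is included (3 * 2 * 4 = 24).
check_divisibility(per_device_batch_size=3, num_processes=2,
                   gradient_accumulation_steps=4, num_generations=8)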

Member

A good model to test would be any of the DeepSeek distilled ones: they easily produce 32k tokens per prompt, so you'll know quickly whether the improvements to generation/scoring still OOM or not.

Member Author

That's a very good point @I-l-l-I, I hadn't even realised that with this PR we're relaxing the requirement (per_device_batchsize * num_processes) % num_generations == 0.

Member Author

This should avoid OOM: eedaab5
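For context, here is a rough sketch of a chunked no-grad log-prob computation that avoids this kind of OOM (illustrative only; not the code in the linked commit):

import torch


def chunked_per_token_logps(model, input_ids, attention_mask, logits_to_keep, chunk_size):
    """Compute per-token log-probs for an accumulated batch in small chunks, under
    no_grad, so peak memory stays close to a single micro-batch's footprint."""
    chunks = []
    with torch.no_grad():
        for start in range(0, input_ids.size(0), chunk_size):
            ids = input_ids[start:start + chunk_size]
            mask = attention_mask[start:start + chunk_size]
            logits = model(input_ids=ids, attention_mask=mask).logits  # (b, L, V)
            logits = logits[:, :-1, :][:, -logits_to_keep:, :]  # logits predicting the completion tokens
            targets = ids[:, -logits_to_keep:]                   # (b, logits_to_keep)
            logps = torch.log_softmax(logits, dim=-1)
            chunks.append(torch.gather(logps, 2, targets.unsqueeze(-1)).squeeze(-1))
    return torch.cat(chunks, dim=0)  # (B, logits_to_keep)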

Member Author @qgallouedec commented Apr 16, 2025

Now the memory usage is the same as with main:

[W&B chart: memory usage with this PR vs. main]

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
import random

dataset = load_dataset("trl-lib/tldr", split="train[:200]")


def reward_random(completions, **kwargs):
    return [random.random() for _ in range(len(completions))]

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-1.5B",
    reward_funcs=reward_random,
    train_dataset=dataset,
    args=GRPOConfig(
        per_device_train_batch_size=16,
        gradient_accumulation_steps=8,
        max_completion_length=2048,
        use_vllm=True,
        gradient_checkpointing=True,
        bf16=True,
    ),
)
trainer.train()

2 GPUs

CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --num_processes 2 3283.py 

Side note: with this script, for example, training time is reduced from 30 min to 10 min. 🤯

@qgallouedec (Member Author)

Regression

The following experiments aim to ensure that the results match when the same configuration is used. One graph is provided per configuration (see the outline below).

Outline:

  • gradient_accumulation_steps=1 and num_iterations=1
  • gradient_accumulation_steps=1 and num_iterations=2
  • gradient_accumulation_steps=1 and num_iterations=4
  • gradient_accumulation_steps=4 and num_iterations=1
  • gradient_accumulation_steps=4 and num_iterations=2
  • gradient_accumulation_steps=4 and num_iterations=4
  • gradient_accumulation_steps=16 and num_iterations=1
  • gradient_accumulation_steps=16 and num_iterations=2
  • gradient_accumulation_steps=16 and num_iterations=4

[Regression plots: one comparison graph per configuration listed above.]

@qgallouedec changed the title from "Generate once per optimization step" to "☝️ Generate once per optimization step" on Apr 12, 2025
@qgallouedec changed the title from "☝️ Generate once per optimization step" to "☝️ Generate once per effective batch" on Apr 12, 2025
@qgallouedec changed the title from "☝️ Generate once per effective batch" to "☝️ [GRPO] Generate once per effective batch" on Apr 12, 2025
@@ -631,7 +695,7 @@ def _get_train_sampler(self) -> Sampler:
             data_source=self.train_dataset,
             mini_repeat_count=self.num_generations,
             batch_size=effective_batch_size // self.num_generations,
-            repeat_count=self.num_iterations,
+            repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
Collaborator

Probably not for this PR, but would it be worth having a per_device_mega_batch_size and completely decoupling the gradient accumulation from the size of the big generation batch? Some batch shuffling would be required to ensure that the minibatches vary from one iteration to the next, but I don't think it would be a huge change.

Member

@edbeeching Do you mean something like this, in simple terms?

mega_batch_size = 524288  # ~0.5M tokens if T = 1024
B = 64  # micro batch size
T = 1024
assert mega_batch_size % (B * T * ddp_world_size) == 0
grad_accum_steps = mega_batch_size // (B * T * ddp_world_size)

where we later do:

loss_accum = 0.0
for micro_step in range(grad_accum_steps):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    logits, loss = model(x, y)
    loss = loss / grad_accum_steps
    loss_accum += loss.detach()
    loss.backward()

Member

I think what Ed is suggesting is to align the GRPOTrainer partially with our other RL trainers, where one defines a large effective batch size of N_prompts x N_generations and then performs K steps of mini-batch optimisation (i.e. the num_minibatches arg in the PPOTrainer). This is also what other RL frameworks like verl do.

In that context, gradient accumulation is not involved in defining the large effective batch size, but I realise it is not easy to force the transformers.Trainer logic to do mini-batch optimisation, so it should be done in a separate PR (if at all).
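For what it's worth, a minimal sketch of that PPO-style pattern (illustrative names only; not a proposal for the GRPOTrainer code):

import random


def minibatch_optimisation(effective_batch, num_minibatches, num_epochs, update_fn):
    """One large generated/scored batch, then several passes of mini-batch updates;
    `update_fn` performs a single optimiser step on one mini-batch."""
    minibatch_size = len(effective_batch) // num_minibatches
    for _ in range(num_epochs):
        order = list(range(len(effective_batch)))
        random.shuffle(order)  # reshuffle so mini-batches differ across passes
        for i in range(num_minibatches):
            idx = order[i * minibatch_size:(i + 1) * minibatch_size]
            update_fn([effective_batch[j] for j in idx])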

@@ -836,13 +932,14 @@ def _generate_and_score_completions(
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
Collaborator

Generally for training (independent of this PR), per_device_eval_batch_size is 2*per_device_train_batch_size.

As we are in no_grad mode, you can probably set the batch size here to 2*self.args.per_device_train_batch_size, since there will be no memory used for the storage of activations etc.

It would be worth validating in a scenario with significant memory pressure, so do not worry about it if you want to get the PR merged.

Member @lewtun left a comment

Very nice PR with clearly documented logic; this was a pleasure to read! Looking forward to these speed-ups!

qgallouedec and others added 2 commits April 17, 2025 16:19
@qgallouedec qgallouedec merged commit 294f35b into main Apr 17, 2025
10 checks passed
@qgallouedec qgallouedec deleted the generate-once-per-step branch April 17, 2025 23:36
@zhiqihuang

From the TRL logging, I noticed that sometimes the generations are not correctly paired with their prompts. It could be a logging issue in multi-GPU training, or the order could have changed during the single generation call per batch to the vLLM server. I will try to find more traces and post them here.
