
🔧 Fix GRPO sampling logic #3725


Merged: 7 commits merged from fix-grpo-sampling into main on Jul 15, 2025
Conversation

@qgallouedec (Member) commented Jul 12, 2025:

No idea how we ended up with this erroneous sampling logic, but after debugging in depth, I realized that GRPO is incorrect in several places. This PR fixes those problems.

EDIT: I was initially wrong, but there are still a couple of errors to fix

If it can help the review, here is the diagram I used: [image: Untitled-2025-07-08-1423]

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec changed the title from "Fix GRPO sampling logic" to "🔧 Fix GRPO sampling logic" on Jul 12, 2025
Review comment on the GRPOConfig docstring:

        `per_device_train_batch_size * num_processes * steps_per_generation`. In other words, there is one
        generation batch processed per optimization step. Mutually exclusive with `steps_per_generation`.
    steps_per_generation: (`int` or `None`, *optional*, defaults to `None`):
        Number of optimization steps per generation. If `None`, it defaults to `gradient_accumulation_steps`.
@qgallouedec (Member Author) suggested a fix for a typo in the parameter name:

-         steps_per_generations: (`int` or `None`, *optional*, defaults to `None`):
+         steps_per_generation: (`int` or `None`, *optional*, defaults to `None`):

Review comment on the `steps_per_generation` docstring:

        `per_device_train_batch_size * num_processes * steps_per_generation`. In other words, there is one
        generation batch processed per optimization step. Mutually exclusive with `steps_per_generation`.
    steps_per_generation: (`int` or `None`, *optional*, defaults to `None`):
        Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`. Mutually exclusive

@qgallouedec (Member Author):

It's not the number of optimization steps, but the number of steps.
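
For context, a toy calculation of that distinction, with made-up numbers (this is my reading of the docstring, not part of the PR):

    # Made-up numbers, only to illustrate "steps" vs. "optimization steps".
    gradient_accumulation_steps = 4  # micro-batch steps accumulated per optimizer update
    steps_per_generation = 4         # steps between two generation rounds

    # With the default steps_per_generation == gradient_accumulation_steps, there is
    # exactly one generation batch per optimization step, which is what the docstring's
    # "one generation batch processed per optimization step" sentence describes.
    assert steps_per_generation / gradient_accumulation_steps == 1.0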

    # Check if the effective batch size can be divided by the number of generations
    if self.num_generations < 2:
        raise ValueError(
            "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided "
            f"{self.num_generations}, which is less than the minimum required."
        )
    possible_values = [
@qgallouedec (Member Author) commented Jul 12, 2025:

The checks below aren't relevant anymore. The only check we need is:

if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
    raise ValueError(
        f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size "
        f"({self.per_device_train_batch_size * num_processes})."
    )
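
A quick sanity check of that constraint with illustrative values (the numbers below are assumptions, not from the PR):

    # Illustrative values, not from the PR.
    per_device_train_batch_size = 4
    num_processes = 2                  # e.g. two GPUs
    generation_batch_size = 24         # must be a multiple of 4 * 2 = 8

    global_batch_size = per_device_train_batch_size * num_processes
    if generation_batch_size % global_batch_size != 0:
        raise ValueError(
            f"generation_batch_size ({generation_batch_size}) must be divisible by "
            f"the global batch size ({global_batch_size})."
        )
    # 24 % 8 == 0, so this passes; generation_batch_size = 20 would raise.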

@@ -179,7 +179,7 @@ def __iter__(self):
                 yield index

     def __len__(self) -> int:
-        return self.num_samples * self.mini_repeat_count * self.repeat_count
+        return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count
@qgallouedec (Member Author) commented Jul 12, 2025:

Because of this (just above):

        #    [[2, 4, 3], [1, 0, 6], [5]]
        # -> [[2, 4, 3], [1, 0, 6]]
        indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
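
A minimal standalone sketch of that chunk-dropping behavior and the matching `__len__` arithmetic (my own reconstruction, not the trainer code):

    # Indexes are grouped into batches and any trailing incomplete batch is
    # dropped, so __len__ must round num_samples down to a multiple of batch_size.
    num_samples, batch_size = 7, 3
    indexes = list(range(num_samples))  # [0, 1, 2, 3, 4, 5, 6]
    chunks = [indexes[i : i + batch_size] for i in range(0, num_samples, batch_size)]
    chunks = [chunk for chunk in chunks if len(chunk) == batch_size]
    # -> [[0, 1, 2], [3, 4, 5]]; index 6 is dropped
    assert sum(len(c) for c in chunks) == (num_samples // batch_size) * batch_size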

Comment on lines +620 to +625

        # Keep logs sized to the generation batch to record only outputs from the latest model update.
        self._textual_logs = {
-            "prompt": deque(maxlen=maxlen),
-            "completion": deque(maxlen=maxlen),
-            "rewards": defaultdict(lambda: deque(maxlen=maxlen)),
-            "advantages": deque(maxlen=maxlen),
+            "prompt": deque(maxlen=args.generation_batch_size),
+            "completion": deque(maxlen=args.generation_batch_size),
+            "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
+            "advantages": deque(maxlen=args.generation_batch_size),
@qgallouedec (Member Author):

Equivalent, but simplified by using `generation_batch_size` directly.
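
For reference, a `deque` with `maxlen` evicts its oldest entries on append, which is what scopes the logs to the latest generation batch; a tiny sketch with an assumed maxlen of 3:

    from collections import deque

    # Appending beyond maxlen evicts from the left, so only the most recent
    # generation_batch_size entries survive in the textual logs.
    logs = deque(maxlen=3)
    for prompt in ["p0", "p1", "p2", "p3"]:
        logs.append(prompt)
    assert list(logs) == ["p1", "p2", "p3"]  # "p0" was evicted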

@@ -819,7 +817,7 @@ def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Sampler:
             data_source=dataset,
             mini_repeat_count=self.num_generations,
             batch_size=self.args.generation_batch_size // self.num_generations,
-            repeat_count=self.num_iterations * self.args.steps_per_generation,
+            repeat_count=self.num_iterations,
@qgallouedec (Member Author):

I think this is a big one. No idea why we did it in the first place; there is no need to repeat the sampling `self.args.steps_per_generation` times.
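
To see why the extra factor over-sampled, here is a toy version of the repeat pattern (a sketch under my reading of `RepeatSampler`, not the actual implementation):

    # Each batch of unique prompt indexes is yielded mini_repeat_count times per
    # pass, and the whole pass is repeated repeat_count times.
    def repeat_pattern(indexes, mini_repeat_count, repeat_count):
        out = []
        for _ in range(repeat_count):
            for idx in indexes:
                out.extend([idx] * mini_repeat_count)
        return out

    batch = [0, 1]  # prompt indexes for one generation batch
    # After the fix (repeat_count = num_iterations, here 1):
    assert repeat_pattern(batch, mini_repeat_count=2, repeat_count=1) == [0, 0, 1, 1]
    # Before the fix, repeat_count also carried steps_per_generation (say 2),
    # yielding every generation batch steps_per_generation times too often:
    assert repeat_pattern(batch, 2, 2) == [0, 0, 1, 1, 0, 0, 1, 1]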

A reviewer replied:

I think the original logic might be correct, because in `_prepare_inputs`, data is only sampled once every `generate_every` steps. For the remaining `generate_every - 1` steps, we still need to load data. Therefore, multiplying by `self.args.steps_per_generation` here actually helps to avoid skipping any data.

Reply (Collaborator):

This was what I was thinking when I made the change, although perhaps there were some details of the sampling that I misunderstood, which introduced the bug.

Comment on lines +1275 to +1281

        # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
        # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
        # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
        # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
        # old_per_token_logps to None.
        generate_every = self.args.steps_per_generation * self.num_iterations  # generation frequency
        if self.args.gradient_accumulation_steps % generate_every != 0:
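
A worked example of the alignment check with assumed hyperparameters (the values are mine, not from the PR):

    # Assumed hyperparameters, for illustration only.
    steps_per_generation = 2
    num_iterations = 2
    gradient_accumulation_steps = 3

    generate_every = steps_per_generation * num_iterations  # generate every 4 steps
    # Optimizer steps complete every 3 steps but generation happens every 4, so a
    # generation batch can be consumed across two different model versions:
    needs_old_logps = gradient_accumulation_steps % generate_every != 0
    assert needs_old_logps  # True here; False with gradient_accumulation_steps = 4 or 8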
@qgallouedec (Member Author):
Hoping the comment is enough for review

@tangyd commented Jul 18, 2025:

I'm still confused. The `generate_every` is for fixing distribution shift, and I totally understand why it's multiplied by `steps_per_generation`. But when generating new samples, we have `batch_size=self.args.generation_batch_size // self.num_generations` and `repeat_count=self.num_iterations * self.args.steps_per_generation`. According to the documentation, `generation_batch_size = per_device_train_batch_size * num_processes * steps_per_generation`. It seems both `RepeatSampler` parameters, `batch_size` and `repeat_count`, include `steps_per_generation`.
Is this what we really want? If yes, could you give a more detailed explanation in a comment? @qgallouedec
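
To make the quantities in that question concrete (illustrative numbers and the documented relation only; not an answer from the maintainers):

    # generation_batch_size = per_device_train_batch_size * num_processes * steps_per_generation
    per_device_train_batch_size = 4
    num_processes = 2
    steps_per_generation = 4
    num_generations = 8
    num_iterations = 2

    generation_batch_size = per_device_train_batch_size * num_processes * steps_per_generation  # 32
    batch_size = generation_batch_size // num_generations           # 32 // 8 = 4 unique prompts
    repeat_count_before_pr = num_iterations * steps_per_generation  # 8, what the question quotes
    repeat_count_after_pr = num_iterations                          # 2, after this PR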

@edbeeching (Collaborator) commented:

Would it be possible to write an integration test with dummy data to ensure this is now working as expected in a few different configurations?

kashif added a commit to CompN3rd/trl that referenced this pull request Jul 14, 2025
@kashif (Collaborator) commented Jul 14, 2025:

I have been using:

    def test_repeat_sampler_length_calculation(self):
        dataset = Dataset.from_list([{"text": f"sample_{i}"} for i in range(10)])

        sampler = RepeatSampler(
            data_source=dataset,
            mini_repeat_count=2,
            batch_size=3,  # 10 samples / 3 batch_size = 3 complete batches, 1 sample dropped
            repeat_count=1,
            shuffle=False,
        )

        actual_samples = list(sampler)

        # With batch_size=3 and 10 samples, only 9 samples form complete batches (1 dropped)
        # So we expect 9 * 2 (mini_repeat_count) * 1 (repeat_count) = 18 samples
        expected_actual_samples = 18

        self.assertEqual(len(actual_samples), expected_actual_samples)
        self.assertEqual(
            len(sampler),
            len(actual_samples),
            f"RepeatSampler.__len__() should match actual iterations. "
            f"Expected {len(actual_samples)}, but got {len(sampler)}",
        )

    def test_grpo_sampling_no_steps_per_generation_multiplier(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            config = GRPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=4,
                num_generations=8,
                num_iterations=2,
                steps_per_generation=4,  # Use steps_per_generation without generation_batch_size
                max_steps=1,  # Just test sampler creation
                report_to="none",
            )

            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=config,
                train_dataset=dataset,
            )

            # Get the train sampler
            sampler = trainer._get_train_sampler()

            # repeat_count = num_iterations, NOT num_iterations * steps_per_generation
            self.assertEqual(
                sampler.repeat_count,
                config.num_iterations,
                f"RepeatSampler.repeat_count should not include steps_per_generation multiplier. "
                f"Expected {config.num_iterations}, but got {sampler.repeat_count}",
            )

@qgallouedec qgallouedec merged commit 508d551 into main Jul 15, 2025
10 of 11 checks passed
@qgallouedec qgallouedec deleted the fix-grpo-sampling branch July 15, 2025 20:39
@konstantinjdobler commented:

Hey, after reading through this, I am not entirely clear on which cases were bugged before this PR. Could you give a quick overview which settings caused wrong sampling logic before this PR landed?

marcandrelarochelle pushed a commit to marcandrelarochelle/trl that referenced this pull request Jul 29, 2025