🔧 Fix GRPO sampling logic #3725
Conversation
    Number of optimization steps per generation. If `None`, it defaults to `gradient_accumulation_steps`.
    `per_device_train_batch_size * num_processes * steps_per_generation`. In other words, there is one
    generation batch processed per optimization step. Mutually exclusive with `steps_per_generation`.
steps_per_generation: (`int` or `None`, *optional*, defaults to `None`):
- steps_per_generations: (`int` or `None`, *optional*, defaults to `None`):
+ steps_per_generation: (`int` or `None`, *optional*, defaults to `None`):
    `per_device_train_batch_size * num_processes * steps_per_generation`. In other words, there is one
    generation batch processed per optimization step. Mutually exclusive with `steps_per_generation`.
steps_per_generation: (`int` or `None`, *optional*, defaults to `None`):
    Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`. Mutually exclusive
It's not the number of optimization steps, but the number of steps.
# Check if the effective batch size can be divided by the number of generations
if self.num_generations < 2:
    raise ValueError(
        "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided "
        f"{self.num_generations}, which is less than the minimum required."
    )
possible_values = [
The checks below aren't relevant anymore. The only check we need is:

if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
    raise ValueError(
        f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size "
        f"({self.per_device_train_batch_size * num_processes})."
    )
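For intuition, here is a tiny worked example of that check with made-up numbers (none of these values come from the PR):

# Hypothetical values, only to illustrate the proposed divisibility check.
per_device_train_batch_size, num_processes = 4, 2  # global batch size = 8
generation_batch_size = 24                          # 24 % 8 == 0 -> passes
assert generation_batch_size % (per_device_train_batch_size * num_processes) == 0
# generation_batch_size = 20 would raise, since 20 % 8 == 4 != 0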
@@ -179,7 +179,7 @@ def __iter__(self):
                        yield index

    def __len__(self) -> int:
-        return self.num_samples * self.mini_repeat_count * self.repeat_count
+        return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count
Because of this (just above):
# [[2, 4, 3], [1, 0, 6], [5]]
# -> [[2, 4, 3], [1, 0, 6]]
indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
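A self-contained sketch of why the updated `__len__` formula matches what `__iter__` actually yields once the incomplete chunk is dropped (toy code, not the TRL implementation; the numbers are arbitrary):

num_samples, batch_size, mini_repeat_count, repeat_count = 7, 3, 2, 1

indexes = list(range(num_samples))  # [0, 1, 2, 3, 4, 5, 6]
chunks = [indexes[i:i + batch_size] for i in range(0, num_samples, batch_size)]
# -> [[0, 1, 2], [3, 4, 5], [6]]
chunks = [c for c in chunks if len(c) == batch_size]  # drop the incomplete chunk
# -> [[0, 1, 2], [3, 4, 5]]

yielded = sum(len(c) for c in chunks) * mini_repeat_count * repeat_count               # 12
old_len = num_samples * mini_repeat_count * repeat_count                               # 14, overcounts
new_len = (num_samples // batch_size) * batch_size * mini_repeat_count * repeat_count  # 12
assert yielded == new_len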
        # Keep logs sized to the generation batch to record only outputs from the latest model update.
        self._textual_logs = {
-           "prompt": deque(maxlen=maxlen),
-           "completion": deque(maxlen=maxlen),
-           "rewards": defaultdict(lambda: deque(maxlen=maxlen)),
-           "advantages": deque(maxlen=maxlen),
+           "prompt": deque(maxlen=args.generation_batch_size),
+           "completion": deque(maxlen=args.generation_batch_size),
+           "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
+           "advantages": deque(maxlen=args.generation_batch_size),
Equivalent, but simplified by using `generation_batch_size` directly.
trl/trainer/grpo_trainer.py
@@ -819,7 +817,7 @@ def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Sampler:
            data_source=dataset,
            mini_repeat_count=self.num_generations,
            batch_size=self.args.generation_batch_size // self.num_generations,
-           repeat_count=self.num_iterations * self.args.steps_per_generation,
+           repeat_count=self.num_iterations,
I think this is a big one. No idea why we did it in the first place. There is no need to repeat the sampling `self.args.steps_per_generation` times.
I think the original logic might be correct, because in `_prepare_inputs`, data is only sampled once every `generate_every` steps. For the remaining `generate_every - 1` steps, we still need to load data. Therefore, multiplying by `self.args.steps_per_generation` here actually helps to avoid skipping any data.
This was what I was thinking when I made the change, although perhaps there were some details of the sampling that I misunderstood, which introduced the bug.
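To make the disagreement concrete, here is a toy approximation of the repeat pattern (this is not the actual `RepeatSampler` code, and all configuration values below are made up):

def toy_repeat_sampler(indices, mini_repeat_count, batch_size, repeat_count):
    # Split prompt indices into full batches, re-yield each batch repeat_count times,
    # and repeat every index mini_repeat_count times (one entry per generation).
    chunks = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
    chunks = [c for c in chunks if len(c) == batch_size]
    out = []
    for chunk in chunks:
        for _ in range(repeat_count):
            for idx in chunk:
                out.extend([idx] * mini_repeat_count)
    return out

prompts = list(range(4))
# Pre-PR: repeat_count = num_iterations * steps_per_generation (1 * 2 here), so each prompt
# batch is re-yielded once for every optimizer step that reuses it.
print(toy_repeat_sampler(prompts, mini_repeat_count=2, batch_size=2, repeat_count=2))
# This PR: repeat_count = num_iterations (1), so each prompt batch appears only once; any reuse
# across steps_per_generation steps would then have to come from buffering in _prepare_inputs.
print(toy_repeat_sampler(prompts, mini_repeat_count=2, batch_size=2, repeat_count=1))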
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
# samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
# for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
# old_per_token_logps to None.
generate_every = self.args.steps_per_generation * self.num_iterations  # generation frequency
if self.args.gradient_accumulation_steps % generate_every != 0:
Hoping the comment is enough for review
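As a small sketch of that condition (values invented for illustration, not taken from the PR):

def needs_importance_sampling(gradient_accumulation_steps, steps_per_generation, num_iterations):
    # Mirrors the check quoted above: misalignment means samples can come from a stale model.
    generate_every = steps_per_generation * num_iterations
    return gradient_accumulation_steps % generate_every != 0

print(needs_importance_sampling(4, 2, 1))  # False: 4 % 2 == 0, generation aligned with optimizer steps
print(needs_importance_sampling(4, 3, 1))  # True: 4 % 3 != 0, old_per_token_logps must be tracked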
I'm still confused. The `generate_every` is for fixing distribution shift, and I totally understand why it is multiplied by `steps_per_generation`. But when generating new samples, we have `batch_size=self.args.generation_batch_size // self.num_generations` and `repeat_count=self.num_iterations * self.args.steps_per_generation`. According to the documentation, `generation_batch_size = per_device_train_batch_size * num_processes * steps_per_generation`. It seems both parameters of `RepeatSampler`, `batch_size` and `repeat_count`, include `steps_per_generation`. Is this what we really want? If yes, could you give a more detailed explanation in the comment? @qgallouedec
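For reference, the arithmetic behind the question, with made-up values (none of these come from the PR):

# Hypothetical config, only to show where steps_per_generation enters each quantity.
per_device_train_batch_size, num_processes = 4, 2
steps_per_generation, num_generations, num_iterations = 2, 8, 1

generation_batch_size = per_device_train_batch_size * num_processes * steps_per_generation  # 16
sampler_batch_size = generation_batch_size // num_generations  # 2 prompts per generation batch
repeat_count_pre_pr = num_iterations * steps_per_generation     # 2: steps_per_generation counted again
repeat_count_this_pr = num_iterations                           # 1: only enters via generation_batch_size
print(generation_batch_size, sampler_batch_size, repeat_count_pre_pr, repeat_count_this_pr)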
Would it be possible to write an integration test with dummy data to ensure this is now working as expected in a few different configurations?
I have been using:

def test_repeat_sampler_length_calculation(self):
    dataset = Dataset.from_list([{"text": f"sample_{i}"} for i in range(10)])
    sampler = RepeatSampler(
        data_source=dataset,
        mini_repeat_count=2,
        batch_size=3,  # 10 samples / 3 batch_size = 3 complete batches, 1 sample dropped
        repeat_count=1,
        shuffle=False,
    )
    actual_samples = list(sampler)
    # With batch_size=3 and 10 samples, only 9 samples form complete batches (1 dropped),
    # so we expect 9 * 2 (mini_repeat_count) * 1 (repeat_count) = 18 samples
    expected_actual_samples = 18
    self.assertEqual(len(actual_samples), expected_actual_samples)
    self.assertEqual(
        len(sampler),
        len(actual_samples),
        f"RepeatSampler.__len__() should match actual iterations. "
        f"Expected {len(actual_samples)}, but got {len(sampler)}",
    )

def test_grpo_sampling_no_steps_per_generation_multiplier(self):
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
    with tempfile.TemporaryDirectory() as tmp_dir:
        config = GRPOConfig(
            output_dir=tmp_dir,
            per_device_train_batch_size=4,
            num_generations=8,
            num_iterations=2,
            steps_per_generation=4,  # Use steps_per_generation without generation_batch_size
            max_steps=1,  # Just test sampler creation
            report_to="none",
        )
        trainer = GRPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            args=config,
            train_dataset=dataset,
        )
        # Get the train sampler
        sampler = trainer._get_train_sampler()
        # repeat_count = num_iterations, NOT num_iterations * steps_per_generation
        self.assertEqual(
            sampler.repeat_count,
            config.num_iterations,
            f"RepeatSampler.repeat_count should not include steps_per_generation multiplier. "
            f"Expected {config.num_iterations}, but got {sampler.repeat_count}",
        )
Hey, after reading through this, I am not entirely clear on which cases were bugged before this PR. Could you give a quick overview of which settings caused wrong sampling logic before this PR landed?
No idea how we ended up with this erroneous sampling logic, but after debugging in depth, I realize that GRPO is incorrect in several places. This PR fixes those problems.

EDIT: I was initially wrong, but there are still a couple of errors to fix.
If it can help the review:
[image attachment]