👎 [GRPO] Adds option to disable dropout #3234


Merged: 5 commits merged into main from grpo-disable-dropout on Apr 9, 2025

Conversation

@edbeeching edbeeching (Collaborator) commented Apr 4, 2025

What does this PR do?

Adds an option to disable dropout.

The RLOOTrainer already disables dropout in the policy, ref_model, and reward model. This PR adds the same option to the GRPOTrainer, which may improve training stability.

The paper The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization provides more insight into why this option may improve training stability:

We disable the dropout layers during training, similar to the settings in Ziegler et al. (2019); Huang
et al. (2024). This is important for PPO training, especially because with dropout activated, the log
probabilities of tokens will not be reproducible, making calculating the KL penalty unreliable while
also causing the ratios of the PPO to be not 1s during the first epoch, causing PPO optimization
problems. For consistency, we also disable dropout for SFT and RM training.
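
For background, disabling dropout in trl is done by walking the module tree and zeroing each dropout layer's probability. A minimal sketch of what the disable_dropout_in_model helper (from trl.trainer.utils, used in the diffs below) amounts to, paraphrased rather than copied verbatim:

import torch.nn as nn

def disable_dropout_in_model(model: nn.Module) -> None:
    # Zero the dropout probability of every nn.Dropout submodule, making
    # forward passes deterministic and token log-probs reproducible.
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0.0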

@edbeeching edbeeching requested review from qgallouedec and lewtun April 4, 2025 11:15

@lewtun lewtun (Member) left a comment

Nice feature - LGTM.

@@ -359,6 +366,10 @@ def __init__(
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
    reward_func, num_labels=1, **model_init_kwargs
)
if args.disable_dropout:
    if isinstance(reward_funcs[i], nn.Module):
        disable_dropout_in_model(reward_funcs[i])

Member

I think AutoModelForSequenceClassification is loaded in eval mode by default, so technically we don't need this here (happy to keep it though if we want to be safe)

Member

Indeed!

>>> from transformers import AutoModelForSequenceClassification
>>> model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
>>> model.training
False
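
A caveat worth spelling out (an observation about the generic training loop, not something stated in this thread): eval mode only holds until the trainer switches the model back to training mode, whereas zeroing the dropout probability survives that switch.

>>> _ = model.train()  # training loops flip the model back to train mode
>>> model.training
True
>>> # with disable_dropout_in_model applied, every Dropout has p=0, so it is a
>>> # no-op in either mode -- hence keeping the explicit call is the safe choice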

@@ -101,6 +101,9 @@ class GRPOConfig(TrainingArguments):
    speed, but may be numerically unstable for long training runs.
num_iterations (`int`, *optional*, defaults to `1`):
    Number of iterations per batch (denoted as μ in the algorithm).
disable_dropout (`bool`, *optional*, defaults to `False`):
Member

In the other trainers this is set to True, maybe we should do the same here?
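
For reference, a minimal usage sketch once this option lands; disable_dropout is the field from the diff above, and output_dir is a placeholder:

from trl import GRPOConfig

training_args = GRPOConfig(output_dir="grpo-out", disable_dropout=True)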

Comment on lines 351 to 356
if args.disable_dropout:
    if isinstance(model, nn.Module):
        disable_dropout_in_model(model)
    if self.ref_model is not None and isinstance(self.ref_model, nn.Module):
        disable_dropout_in_model(self.ref_model)

Member

we only support PreTrainedModel, which is an nn.Module (what else could it be?)

Suggested change
-if args.disable_dropout:
-    if isinstance(model, nn.Module):
-        disable_dropout_in_model(model)
-    if self.ref_model is not None and isinstance(self.ref_model, nn.Module):
-        disable_dropout_in_model(self.ref_model)
+if args.disable_dropout:
+    disable_dropout_in_model(model)
+    if self.ref_model is not None:
+        disable_dropout_in_model(self.ref_model)

@qgallouedec qgallouedec (Member) left a comment

LGTM, in the future we could use a default of True (let's see if it improves stability)

@qgallouedec qgallouedec changed the title [GRPO] Adds option to disable dropout 👎 [GRPO] Adds option to disable dropout Apr 9, 2025
@qgallouedec qgallouedec merged commit 47b9515 into main Apr 9, 2025
8 of 10 checks passed
@qgallouedec qgallouedec deleted the grpo-disable-dropout branch April 9, 2025 16:59