🔭 Add support for better KL estimator (k3) in PPOTrainer #3240

Merged: 11 commits merged into huggingface:main on Apr 6, 2025

Conversation

@AMindToThink (Contributor) commented on Apr 5, 2025

In this blog post, John Schulman describes three estimators for KL divergence: k1, k2, and k3. He finds that k3 appears to be strictly better than k1 and k2, for both theoretical and empirical reasons.

Previously, PPOTrainer used k1 to estimate the KL divergence. This pull request adds a kl_estimator argument to PPOConfig that lets users choose between "k1" and "k3".
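
For illustration, a minimal usage sketch (the output_dir value is just a placeholder; every other PPOConfig field keeps its default):

```python
from trl import PPOConfig

# Opt into the k3 estimator; "k1" remains the default, so existing scripts
# that don't pass kl_estimator keep their previous behavior.
config = PPOConfig(
    output_dir="ppo-k3",  # placeholder output directory
    kl_estimator="k3",
)
```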

k3 is unbiased and has lower variance than k1.
Additionally, k1 is sometimes negative, which is an undesirable property for an estimator of KL divergence, a quantity that is always non-negative; k3 is always non-negative.

I did not add support for k2 because the PPOTrainer was already using that for logging purposes.
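
For reference, here is a minimal sketch of the two estimators as they apply here (the function and variable names are illustrative, not necessarily the ones used in PPOTrainer; logprobs and ref_logprobs are assumed to be per-token log-probabilities of the sampled tokens under the current and reference policies):

```python
import torch


def kl_penalty(logprobs: torch.Tensor, ref_logprobs: torch.Tensor, kl_estimator: str = "k1") -> torch.Tensor:
    # log r = log(pi_ref(x) / pi(x)) for tokens x sampled from the current policy pi
    logr = ref_logprobs - logprobs
    if kl_estimator == "k1":
        # k1 = -log r: unbiased, but higher variance and can be negative
        return -logr
    if kl_estimator == "k3":
        # k3 = (r - 1) - log r: unbiased, lower variance, and always >= 0
        # (since log r <= r - 1 for all r > 0); one exp and two subtractions
        return logr.exp() - 1 - logr
    raise ValueError(f"Unknown kl_estimator: {kl_estimator}")
```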

Notes:

  • I wrote descriptions of the new kl_estimator variable in the scripts, and if I understand correctly, [[autodoc]] will add those descriptions to the documentation.
  • I tested by running test_ppo_trainer.py with both the default and kl_estimator='k3', but did not feel the need to clutter the tests with a new one for this small change.
  • For historical reasons, I set "k1" to be the default, so everyone's scripts will run the same way as before after the update. We should consider setting "k3" to be the default because it is lower variance, unbiased, and trivial to compute (it just takes an exp and two subtractions).

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

AMindToThink and others added 2 commits April 5, 2025 22:04
@qgallouedec (Member) left a comment

Nice, thanks @AMindToThink!
A few formatting comments; otherwise we're good to merge.

AMindToThink and others added 3 commits April 6, 2025 00:04
@qgallouedec (Member) commented

Waiting for the CI to pass

@AMindToThink (Contributor, Author) commented

Pleasure working with you, @qgallouedec. Expect more PRs from me in the future!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec (Member) commented

For historical reasons, I set "k1" to be the default, so everyone's scripts will run the same way as before after the update. We should consider setting "k3" to be the default because it is lower variance, unbiased, and trivial to compute (it just takes an exp and two subtractions).

If you can provide results showing that k3 is a better choice in the general case, we can change the default value (in a follow-up PR).

@AMindToThink (Contributor, Author) commented

How would you go about showing that for the general case?
For one, which experiments would I run? I could find three projects that use PPOTrainer and rerun their code with both estimators, but that would hardly show that one estimator is better in the general case.
And how would I evaluate the results? KL divergence is the quantity of interest here, and I couldn't just use one of the KL estimators for the evaluation, because that would be unfair to the other (unless both estimators agreed that the same estimator results in lower KL divergence).

@qgallouedec (Member) commented

What's important here is to know whether using k3 makes it possible to converge faster and/or towards better models.
My advice would be to take 2 or 3 different configurations (datasets/models) and run the training, then evaluate the checkpoints (the type of evaluation depends on the dataset).

@qgallouedec changed the title from "Add support for better KL estimator (k3) in PPOTrainer" to "🔭 Add support for better KL estimator (k3) in PPOTrainer" on Apr 6, 2025
@qgallouedec merged commit 4bfb8eb into huggingface:main on Apr 6, 2025
8 of 9 checks passed
@AMindToThink deleted the kl-estimator-ppo branch on April 6, 2025 05:37
@AMindToThink (Contributor, Author) commented

@qgallouedec
I notice that other trainers, such as DPOTrainer, use cap_exp instead of exp as a protection against exponentiating extreme values. Should we use that here, too?

I would think it should be unnecessary: PPOTrainer is on-policy, so the sampled tokens' log-probabilities under the current policy cannot be extremely negative (that would imply sampling a really unlikely generation/token), and the exponentiated log-ratio therefore stays bounded. But I thought I'd check in.
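
For context, a minimal sketch of the kind of guard being discussed, assuming a simple clamp on the log-ratio before exponentiating (illustrative only, not TRL's actual cap_exp implementation; the cap value is arbitrary):

```python
import torch


def k3_with_cap(logprobs: torch.Tensor, ref_logprobs: torch.Tensor, cap: float = 20.0) -> torch.Tensor:
    # Clamp the log-ratio so exp() cannot overflow if an extreme value slips through.
    logr = torch.clamp(ref_logprobs - logprobs, max=cap)
    return logr.exp() - 1 - logr
```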

@qgallouedec (Member) commented

I'd say, unless we encounter this issue, we can keep the code as is

yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025