🔭 Add support for better KL estimator (k3) in PPOTrainer #3240
Conversation
Nice, thanks @AMindToThink!
A few formatting comments, otherwise we're good to merge.
Waiting for the CI to pass.
Pleasure working with you, @qgallouedec. Expect more PRs from me in the future!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
If you can provide results showing that, in the general case, k3 is a better choice, we can change the default value (in a follow-up PR).
How would you go about showing that for the general case?
What's important here is to know whether using k3 makes it possible to converge faster and/or towards better models.
@qgallouedec I would think it should be unnecessary, because PPOTrainer is on-policy, so logprobs cannot be huge (since that would imply sampling a really unlikely generation/token), but I thought I'd check in.
I'd say, unless we encounter this issue, we can keep the code as is.
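For context, the kind of safeguard discussed here would look roughly like the sketch below; the function name and clamping threshold are hypothetical, and, as noted above, nothing like this was added in the PR:

```python
from typing import Optional

import torch

def k3_estimate(logprob: torch.Tensor, ref_logprob: torch.Tensor,
                max_logr: Optional[float] = None) -> torch.Tensor:
    # Log-ratio of the reference policy to the current policy for sampled tokens.
    logr = ref_logprob - logprob
    if max_logr is not None:
        # Hypothetical safeguard (not adopted here): cap the log-ratio so that
        # exp(logr) cannot blow up for tokens that are extremely unlikely under
        # the current policy.
        logr = logr.clamp(max=max_logr)
    return torch.expm1(logr) - logr  # k3 = (r - 1) - log r, always >= 0
```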
In this blog post, John Schulman describes three estimators for the KL divergence: k1, k2, and k3. He finds that k3 appears to be strictly better than k1 and k2 for both theoretical and empirical reasons.
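For reference, here is a minimal sketch of the three estimators from that post, written for per-token log-probabilities; the function and variable names are illustrative and not taken from the PR's code:

```python
import torch

def kl_estimators(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> dict:
    """Schulman's Monte-Carlo estimators of KL(pi || pi_ref).

    `logprob` and `ref_logprob` are log-probabilities of the *same* sampled
    tokens under the current policy and the reference policy, so
    logr = ref_logprob - logprob is log(pi_ref(x) / pi(x)) with x ~ pi.
    """
    logr = ref_logprob - logprob
    k1 = -logr                      # unbiased, but high variance and can be negative
    k2 = 0.5 * logr ** 2            # low variance, but biased
    k3 = torch.expm1(logr) - logr   # (r - 1) - log r: unbiased, low variance, >= 0
    return {"k1": k1, "k2": k2, "k3": k3}
```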
Previously, the PPOTrainer used k1 to estimate the KL divergence. This pull request adds support, through a `kl_estimator` argument in PPOConfig, for choosing between "k1" and "k3". k3 is unbiased and has lower variance than k1.
Additionally, k1 is sometimes negative, which is an undesirable property when estimating the KL divergence, a quantity that is always non-negative; k3 is always non-negative.
I did not add support for k2 because the PPOTrainer was already using that for logging purposes.
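As a usage sketch, selecting the new estimator might look like the following; apart from `kl_estimator`, the argument values shown are illustrative placeholders rather than anything prescribed by this PR:

```python
from trl import PPOConfig

config = PPOConfig(
    output_dir="ppo_k3_run",  # placeholder output path
    kl_coef=0.05,             # example KL penalty coefficient
    kl_estimator="k3",        # "k1" (previous behavior, the default) or "k3"
)
```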
Notes:
- I added a description of `kl_estimator` in the scripts, and if I'm understanding correctly, [[autodoc]] will add those descriptions to the documentation.
- I checked the change with `kl_estimator='k3'`, but did not feel the need to clutter the tests by writing another one for this one small change.

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.