Add entropy based filtering inside the GRPOTrainer. #3563
Conversation
You should be able to calculate the entropy directly from log probs as […], which means we don't have to modify […].
Also just realized that the temperature parameter won't affect the ranking order of entropy since all positions are affected by the same temperature, so the temp param shouldn't matter here.
Actually this isn't true, the […]
yes of course, my bad. not really a fan of the proposed refactor of […]
Nice! Thanks! Another recommendation: one argument is probably enough.
- if self.filter_on_entropy:
+ if self.token_entropy_percentile_threshold < 1.0:
Also add that the recommended value is 0.2 in the documentation.
Cool, let me look into making the entropy calculation less memory intensive.
Updated the code to make sure that only a mini-batch of logits is materialized at any given point in time, and entropies for those mini-batches of logits are optionally calculated. The […]
nice work. left a few minor comments
Thanks for the review, Leon! Made the suggested changes.
@qgallouedec please take another look at the PR when you have the time.
I spent some time benchmarking different entropy calculations, scripts here. Long story short, I recommend: […]
This is basically what you already had, but importantly we sum inside the loop, which avoids materializing the [S, V] tensor. I observed no latency improvements when increasing […]
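For concreteness, here is a minimal sketch of a chunked entropy computation in that spirit (the function and argument names are illustrative, not the exact code in the PR): it iterates over sequence chunks and reduces over the vocabulary inside the loop, so the full [S, V] probability tensor is never materialized.

```python
import torch

def entropy_from_logits_chunked(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
    """Per-position entropy of [S, V] logits, H = logsumexp(z) - sum(softmax(z) * z),
    computed chunk by chunk along the sequence dimension."""
    entropies = []
    for chunk in logits.split(chunk_size, dim=0):  # chunk: [<=chunk_size, V]
        logsumexp = torch.logsumexp(chunk, dim=-1)
        # Reduce over the vocabulary inside the loop so only one chunk of
        # probabilities is alive at a time.
        expected_logit = (torch.softmax(chunk, dim=-1) * chunk).sum(dim=-1)
        entropies.append(logsumexp - expected_logit)
    return torch.cat(entropies)
```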
@qgallouedec @LeonEricsson this PR has some overlap with my implementation of an entropy regularization loss in #3628. We need to sync on both, given that the two address entropy in different directions. The entropy regularization loss is proposed in an issue and is also officially implemented in verl.
@1485840691 I think that despite the overlap the two purposes are complementary, and once this PR is merged, including the entropy loss in the final loss should be fairly simple.
@LeonEricsson thanks for those cool benchmarks! I've updated the code to compute the sum inside the for loop as you've suggested. |
I have a comment on entropy from logits. Given that verl already provides an implementation (https://github.com/volcengine/verl/blob/9b7bb69ea3165b691cc908d7f3f2f14c4a65a59e/verl/utils/torch_functional.py#L150), why do we not re-use that? Sorry, I tried benchmarking it; indeed verl's implementation does not have better runtime performance. But given that the entropy-from-logits function is embedded inside the chunk loop of get_per_token_logps, why do we need to support chunking here?
And there is another question regarding the interaction between the entropy loss and the entropy mask: do we need to consider the entropy mask when computing the entropy loss? Currently the entropy loss is computed using the completion mask. @qgallouedec @LeonEricsson
I'm assuming you mean why we don't simply import the function from verl […]
I don't think the entropy loss should take the entropy mask into consideration. In the entropy masking paper linked above, their intention was to affect only the policy loss. So, similar to how the KL-div doesn't consider the entropy mask, I believe the entropy loss shouldn't either, especially if we want to reproduce the behavior of veRL.
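To make that interaction concrete, here is a rough sketch under those assumptions (all names are illustrative and not the actual GRPOTrainer code): the entropy-percentile filter gates only the policy term, while the entropy regularization term from #3628 uses the plain completion mask, mirroring how the KL penalty ignores the filter mask.

```python
import torch

def combine_losses(per_token_policy_loss: torch.Tensor,
                   per_token_entropy: torch.Tensor,
                   completion_mask: torch.Tensor,
                   entropy_filter_mask: torch.Tensor,
                   entropy_coeff: float = 0.0) -> torch.Tensor:
    """Illustrative only: the entropy-percentile filter gates the policy term,
    while the entropy regularization term averages over all completion tokens."""
    denom = completion_mask.sum().clamp(min=1.0)
    policy_loss = (per_token_policy_loss * completion_mask * entropy_filter_mask).sum() / denom
    entropy_bonus = (per_token_entropy * completion_mask).sum() / denom
    # Subtracting the bonus encourages higher entropy (exploration), as in the
    # entropy regularization discussed in #3628.
    return policy_loss - entropy_coeff * entropy_bonus
```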
I think we might follow verl and support both entropy_from_logits (https://github.com/volcengine/verl/blob/9b7bb69ea3165b691cc908d7f3f2f14c4a65a59e/verl/utils/torch_functional.py#L143) and entropy_from_logits_chunking (https://github.com/volcengine/verl/blob/9b7bb69ea3165b691cc908d7f3f2f14c4a65a59e/verl/utils/torch_functional.py#L150). I do not mean importing the verl lib for such a small util function, just copying the code and adding a comment. That said, I have benchmarked it, and verl's implementation does not have better performance.
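For reference, the non-chunked formulation being compared is essentially the following (a sketch modeled on the linked verl util; the exact code there may differ slightly):

```python
import torch

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Non-chunked variant: same identity H = logsumexp(z) - sum(softmax(z) * z),
    applied to the full logits tensor at once (higher peak memory, same result)."""
    pd = torch.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - (pd * logits).sum(dim=-1)
```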
Sorry, I missed this comment. I think the objective here is to reduce peak memory usage rather than optimize run-time, based on the comments from Leon and Quentin.
I agree that chunking isn't necessary here, but since it's defined as a general utility function and could be useful in future scenarios, it's nice to have; it provides a convenient way to trade memory for latency.
Yeah, […]
Some final comments, then I'm happy to merge.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
What does this PR do?
This PR relates to #3555, which proposes masking out the policy loss coming from completion tokens whose entropy falls below a given percentile threshold (i.e., keeping only the highest-entropy positions).
This idea was proposed by the Qwen team in their accompanying paper, Beyond the 80/20 Rule.
Key Proposals of the paper that guided the implementation
The key difference is the added indicator term, which zeroes out the policy loss for tokens whose entropy falls below the batch-level percentile threshold.
From the paper: […]
Entropy is calculated as usual via the formula $H_t = -\sum_{v} p_t(v) \log p_t(v)$, where $p_t = \mathrm{softmax}(z_t / T)$ is the model's token distribution at position $t$ under sampling temperature $T$.
The paper applies the entropy mask to the DAPO loss function in their experiments, but I think we can leave it to the user to decide which loss (i.e., GRPO, Dr. GRPO, or DAPO) to apply it to.
The paper finds that the best threshold is to keep the top-20% of tokens based on their entropy.
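A minimal sketch of how such a percentile-based mask could be built, assuming (as the recommended value of 0.2 and the top-20% finding suggest) that the threshold denotes the fraction of highest-entropy completion tokens to keep; the function and variable names are illustrative and the actual integration into GRPOTrainer may differ:

```python
import torch

def top_entropy_mask(entropies: torch.Tensor, completion_mask: torch.Tensor,
                     threshold: float = 0.2) -> torch.Tensor:
    """Keep only the `threshold` fraction of highest-entropy completion tokens.
    threshold=1.0 keeps everything, i.e. the filter is disabled."""
    if threshold >= 1.0:
        return completion_mask
    # Compute the quantile only over real (non-padded) completion tokens.
    valid_entropies = entropies[completion_mask.bool()]
    cutoff = torch.quantile(valid_entropies.float(), 1.0 - threshold)
    return (entropies >= cutoff).to(completion_mask.dtype) * completion_mask

# e.g. per_token_loss = per_token_loss * top_entropy_mask(entropies, completion_mask, 0.2)
```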
I didn't run the vLLM tests inside test_grpo_trainer.py since my machine/VM didn't have access to a GPU.
Fixes #3555
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.