🪂 Don't gather logits in SFT to avoid hanging #2890
* Don't gather logits
* Remove unused function and test
I suspect the cause is that the gathered logits have different sizes (sequence lengths) on different processes. Gathering only the number of correct tokens and the total number of tokens should therefore be a good fix.
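A framework-free sketch of that idea follows. It assumes the usual convention that masked label positions hold -100 and that predictions are already shifted and argmax-ed; the count_tokens helper and the per-rank data are hypothetical. Each rank reduces its batch to two integers, so what gets gathered has the same shape on every rank regardless of sequence length. The list comprehension only stands in for a real collective such as torch.distributed.all_gather or Accelerate's gather_for_metrics.

```python
IGNORE_INDEX = -100  # label value conventionally used to mask tokens out of the loss


def count_tokens(predictions, labels):
    """Return (correct, total) for one rank's shifted predictions and labels."""
    correct = 0
    total = 0
    for pred, label in zip(predictions, labels):
        if label == IGNORE_INDEX:
            continue  # masked positions count neither as correct nor as total
        total += 1
        correct += int(pred == label)
    return correct, total


# Hypothetical per-rank data: sequence lengths differ (5 vs 3 tokens), so
# per-rank logits would have mismatched shapes, but the counts are always two ints.
rank_batches = [
    ([4, 7, 7, 2, 9], [4, 7, 1, 2, IGNORE_INDEX]),
    ([3, 3, 8], [3, 5, 8]),
]

# Stand-in for an all-gather: every rank contributes exactly (correct, total).
gathered = [count_tokens(preds, labels) for preds, labels in rank_batches]
correct_sum = sum(correct for correct, _ in gathered)
total_sum = sum(total for _, total in gathered)
accuracy = correct_sum / total_sum
```

Summing the counts before dividing weights every token equally; averaging per-rank accuracies instead would give a short sequence the same weight as a long one.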
With trl==0.16.0 and SFTTrainer on "Qwen/Qwen2.5-1.5B-Instruct" (base_job_name = "mt5-large-full-peft"), I still encounter a similar issue when my training set is larger (~300M samples) on a p4d.24xlarge instance (8 A100 GPUs), although the error now suggests a shape mismatch:
Is it possible that the GPUs get different numbers of batches (for example, 10 batches across 8 GPUs), so that the loss/mean_token_accuracy computation runs into issues?
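To make the question concrete, here is a small sketch assuming naive round-robin sharding with no padding. That is an assumption: PyTorch's DistributedSampler, for instance, pads by repeating samples when drop_last=False and drops the tail when drop_last=True, so every rank gets the same count. Under the naive scheme, 10 batches over 8 GPUs leave two ranks with an extra step; on that step only those two ranks would call a collective such as a gather while the others never join it, which is one way a collective can hang. The batches_per_rank helper is hypothetical.

```python
def batches_per_rank(num_batches, world_size):
    """Count how many batches each rank sees under round-robin sharding without padding."""
    # Rank r is assigned batches r, r + world_size, r + 2 * world_size, ...
    return [len(range(rank, num_batches, world_size)) for rank in range(world_size)]


counts = batches_per_rank(10, 8)
# Ranks with more batches than the minimum would call the gather on a step the others skip.
extra_ranks = [rank for rank, n in enumerate(counts) if n > min(counts)]
```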
Do you still have this issue @dszhengyu with
Yes, it is 0.16.0; sorry, I had mistyped it as 0.15.0. Fixed.
OK, so yes, it's an edge case where everything is masked, so the sum is None when it should be zero...
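A minimal sketch of that guard, assuming the per-rank correct and total token counts have already been gathered into plain Python lists (the mean_token_accuracy helper name and the counts are hypothetical): when every label on every rank is masked, the total is zero, so the metric falls back to 0.0 instead of dividing by zero.

```python
def mean_token_accuracy(correct_counts, total_counts):
    """Token accuracy over all ranks from gathered per-rank (correct, total) counts."""
    total_sum = sum(total_counts)
    if total_sum == 0:
        # Every label was masked on every rank: report 0.0 rather than computing 0 / 0.
        return 0.0
    return sum(correct_counts) / total_sum


# Hypothetical gathered counts for two ranks.
fully_masked = mean_token_accuracy([0, 0], [0, 0])
one_rank_masked = mean_token_accuracy([0, 2], [0, 3])
```

One fully masked rank is harmless once the counts are summed; only the case where everything is masked needs the fallback.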
@dszhengyu I'll fix it and make a patch release.
What does this PR do?
Fixes #2879
I'm not sure why it always hangs, but it seems that at some point during training (always the same point), it can hang while trying to gather the logits. The fix consists of gathering only the number of correct tokens and the total number of tokens.
The only way was to copy the content of compute_token_accuracy into the SFTTrainer, which I think should be OK since it is only used once in the codebase.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.