[cpu][flash attention] fix nan issue #130014
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130014
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 37f6864 with merge base 1e27af3. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks!
@pytorchbot merge
Merge failed. Reason: This PR needs a `release notes:` label. If not, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. For more information, see the PyTorch bot wiki. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes pytorch#127055. NaNs are generated in flash attention because of the computation of `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding these calculations.
Pull Request resolved: pytorch#130014
Approved by: https://github.com/jgong5, https://github.com/drisspg
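For context, here is a minimal standalone C++ sketch of why a fully masked row breaks the lazy softmax, and the kind of guard the fix implies. It is not the actual ATen kernel; the variable names and the guarded path are illustrative assumptions.

```cpp
// Minimal, self-contained illustration (not the PyTorch kernel) of the NaN source.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  const float kNegInf = -std::numeric_limits<float>::infinity();

  // A row of attention scores that the mask has fully disabled.
  std::vector<float> scores = {kNegInf, kNegInf, kNegInf, kNegInf};

  // The lazy softmax subtracts the running row max before exponentiating.
  float row_max = kNegInf;
  for (float s : scores) row_max = std::max(row_max, s);

  // (-inf) - (-inf) is NaN, so the exp is NaN and poisons every later step.
  std::printf("unguarded: exp(%f - %f) = %f\n", scores[0], row_max,
              std::exp(scores[0] - row_max));

  // Guarded version (the spirit of the fix): skip the exp/rescale math when
  // the row max is -inf, i.e. the block contributes nothing to the softmax.
  float sum = 0.f;
  if (row_max != kNegInf) {
    for (float s : scores) sum += std::exp(s - row_max);
  }
  std::printf("guarded: sum = %f (stays 0 instead of NaN)\n", sum);
  return 0;
}
```

In the guarded path the fully masked row simply contributes nothing; how such a row is ultimately reported is up to the kernel, so the sketch only shows where the NaN would otherwise be introduced.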
@pytorchbot cherry-pick --onto release/2.4 -c critical --fixes #127055
Fixes #127055. NaNs are generated in flash attention because of the computation of `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding these calculations.
Pull Request resolved: #130014
Approved by: https://github.com/jgong5, https://github.com/drisspg
(cherry picked from commit 868d9a4)
Cherry picking #130014: the cherry-pick PR is at #133598 and it is linked with issue #127055. The related tracker issues have been updated. Details for Dev Infra team: raised by workflow job.
[cpu][flash attention] fix nan issue (#130014)
Fixes #127055. NaNs are generated in flash attention because of the computation of `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding these calculations.
Pull Request resolved: #130014
Approved by: https://github.com/jgong5, https://github.com/drisspg
(cherry picked from commit 868d9a4)
Co-authored-by: Valentine233 <xuan.liao@intel.com>
Summary: The flash attention implementation breaks the q @ k matmul into chunks along both the source seqlen and the target seqlen (k cache) dims. Masks typically have shape [q seq len, k seq len], where k seq len == kv cache size. Imagine k seq len = 700 and a mask row like

index: 0    1    2    3    ... 515 ... 575  576  ... 697  698  699  700
mask:  -inf -inf -inf -inf ... 0   ... 0    -inf ... -inf -inf -inf -inf

What this really says is that you should attend only to the middle portion. For example, when decoding position 575 you want to attend to only the previous 60 positions and nothing before that, so positions 515 to 575 in the kv cache are the ones you care about. This is how sliding window attention can be implemented.

Now comes the interesting part. Because the flash attention implementation chunks along the k seq len dim with a chunk size of 512, the first chunk of q @ k gets an attention mask that is entirely -inf. This makes the entire chunk -inf, indicating you do not want to attend to this chunk at all. (You could honestly have avoided this calculation entirely, but maybe that is for another day.) However, as a result of computing this q @ k _and_ adding the mask, you now have a value containing all -infs. This introduces a numerics issue in flash attention if it is not carefully guarded: all subsequent softmax calculations become NaNs. Why? Because of how flash attention progressively calculates attention and makes its final adjustments in the last stage; once NaNs appear, all subsequent calculations also produce NaNs.

I found this the hard way and wondered why this was not a problem in core, from which much of this code was copied. Indeed it was, and it was fixed after this code was copied, in pytorch/pytorch#130014. With better code sharing this probably could have been avoided, but the two implementations have diverged quite a bit by now, and the ugliness in both places is irreconcilable.

Differential Revision: D73640471
Reviewed By: larryliu0820
Differential Revision: D73640471
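To make the failure mode in the chunked path concrete, here is a hedged C++ sketch of an online-softmax accumulator for one query row. The struct and function names (`RowState`, `update_row`, `v_chunk`, etc.) are illustrative assumptions, not code from either repository; the point is the early return for a chunk whose running max is still -inf.

```cpp
// Sketch of a chunked (online) softmax accumulator for a single query row.
// Illustrative only: names and layout are assumptions, not the real kernel.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

struct RowState {
  float running_max = -std::numeric_limits<float>::infinity();
  float running_sum = 0.f;    // sum of exp(score - running_max) seen so far
  std::vector<float> acc;     // un-normalized weighted sum of V rows
};

// Fold one chunk of already-masked scores (and matching V rows) into the state.
void update_row(RowState& st,
                const std::vector<float>& chunk_scores,
                const std::vector<std::vector<float>>& v_chunk) {
  const float kNegInf = -std::numeric_limits<float>::infinity();

  float chunk_max = kNegInf;
  for (float s : chunk_scores) chunk_max = std::max(chunk_max, s);
  float new_max = std::max(st.running_max, chunk_max);

  if (new_max == kNegInf) {
    // Every position seen so far (including this whole chunk) is masked out.
    // Without this guard the code below would evaluate exp((-inf) - (-inf))
    // and +/-inf * 0, producing NaNs that poison all later chunks.
    return;
  }

  if (st.acc.empty() && !v_chunk.empty()) {
    st.acc.assign(v_chunk[0].size(), 0.f);
  }

  // Rescale previous partial results; exp(-inf - finite) cleanly decays to 0.
  float scale = std::exp(st.running_max - new_max);
  st.running_sum *= scale;
  for (float& a : st.acc) a *= scale;

  // Accumulate the current chunk.
  for (size_t j = 0; j < chunk_scores.size(); ++j) {
    float p = std::exp(chunk_scores[j] - new_max);
    st.running_sum += p;
    for (size_t d = 0; d < v_chunk[j].size(); ++d) {
      st.acc[d] += p * v_chunk[j][d];
    }
  }
  st.running_max = new_max;
}
```

After all chunks are folded in, the output for the row is `acc / running_sum`; a row whose `running_sum` is still 0 was fully masked and can be handled explicitly instead of dividing zero by zero.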
Fixes #127055.
NaNs are generated in flash attention because of the computation of `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding these calculations.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10