
Conversation

Valentine233
Collaborator

@Valentine233 Valentine233 commented Jul 3, 2024

Fixes #127055.

NaNs are generated in flash attention because of the computation of `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding the related calculations.
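
For illustration, here is a minimal standalone C++ sketch (not the actual kernel change; names like `row_max_old` are made up) of how a fully masked block produces the NaN and how guarding the rescale factor avoids it:

```cpp
// Minimal sketch: how a fully masked block trips the online ("lazy") softmax,
// and how a guard avoids the NaN. Illustrative only, not the PR diff.
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const float neg_inf = -std::numeric_limits<float>::infinity();

  // A block whose scores are all -inf (fully masked) keeps the running max at -inf.
  float row_max_old = neg_inf;
  float row_max_new = neg_inf;

  // Unguarded rescaling: (-inf) - (-inf) = nan, std::exp(nan) = nan,
  // and the nan (or a later inf * 0) poisons the whole row.
  float bad_scale = std::exp(row_max_old - row_max_new);
  std::printf("unguarded scale = %f\n", bad_scale);  // nan

  // Guarded version: if the running max is still -inf, nothing has been
  // attended to yet, so skip the exp and use a rescale factor of 1.
  float scale = (row_max_new == neg_inf) ? 1.0f : std::exp(row_max_old - row_max_new);
  std::printf("guarded scale   = %f\n", scale);      // 1.0
  return 0;
}
```

The actual fix lives inside the fused CPU kernel; the sketch only illustrates the failure mode and the guarding idea.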

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10


pytorch-bot bot commented Jul 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130014

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 37f6864 with merge base 1e27af3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Jul 3, 2024
@Valentine233 Valentine233 requested review from jgong5 and removed request for jgong5 July 3, 2024 06:44
@Valentine233 Valentine233 marked this pull request as draft July 3, 2024 07:07
@Valentine233 Valentine233 requested a review from jgong5 July 3, 2024 13:27
@Valentine233 Valentine233 marked this pull request as ready for review July 3, 2024 13:27
@zou3519 zou3519 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 3, 2024
@Valentine233 Valentine233 requested a review from drisspg July 4, 2024 02:31
@Valentine233 Valentine233 added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 4, 2024
@Valentine233 Valentine233 requested a review from eellison July 10, 2024 01:31
Contributor

@drisspg drisspg left a comment


Thanks!

@drisspg drisspg added the topic: bug fixes topic category label Jul 10, 2024
@Valentine233
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job

@Valentine233
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
Fixes pytorch#127055.

NaNs are generated in flash attention because of the computation of `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding the related calculations.

Pull Request resolved: pytorch#130014
Approved by: https://github.com/jgong5, https://github.com/drisspg
@drisspg drisspg added this to the 2.4.1 milestone Jul 28, 2024
@atalman
Contributor

atalman commented Aug 15, 2024

@pytorchbot cherry-pick --onto release/2.4 -c critical --fixes #127055

pytorchbot pushed a commit that referenced this pull request Aug 15, 2024
Fixes #127055.

NaNs are generated in flash attention because of the computation of `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding the related calculations.

Pull Request resolved: #130014
Approved by: https://github.com/jgong5, https://github.com/drisspg

(cherry picked from commit 868d9a4)
@pytorchbot
Collaborator

Cherry picking #130014

The cherry pick PR is at #133598 and it is linked with issue #127055. The following tracker issues are updated:

Details for Dev Infra team: raised by workflow job

atalman pushed a commit that referenced this pull request Aug 21, 2024
[cpu][flash attention] fix nan issue (#130014)

Fixes #127055.

NaNs are generated in flash attention because of the computation of `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding the related calculations.

Pull Request resolved: #130014
Approved by: https://github.com/jgong5, https://github.com/drisspg

(cherry picked from commit 868d9a4)

Co-authored-by: Valentine233 <xuan.liao@intel.com>
@github-actions github-actions bot deleted the fa_nan branch September 17, 2024 01:53
kimishpatel added a commit to kimishpatel/executorch-1 that referenced this pull request Apr 25, 2025
Summary:
The flash attention implementation breaks the q @ k matmul into chunks along both the source seqlen and target seqlen (k cache) dims. When using masks, they are typically of shape [q seq len, k seq len], where k seq len == kv cache size.
Imagine you have k seq len = 700 and a mask like
0      1    2    3 ...........515........575  576.....697  698  699  700
-inf -inf -inf -inf............0..........0  -inf.... -inf -inf -inf -inf

What this mask really tells you is that you should attend only to the middle portion. For example, when you are decoding pos 575 you want to attend to only the previous 60 positions and nothing before that. In that case, positions 515 to 575 in the kv cache are what you care about.

This is how sliding window attention can be implemented.

Now comes the interesting part. The flash attention implementation chunks along the k seq len dim, with the chunk size set to 512. Thus in the first chunk of q @ k you add an attention mask of -inf. This makes your entire chunk -inf, indicating you don't want to attend to this chunk at all. You could honestly have avoided this calculation entirely, but maybe that's for another day. However, as a result of calculating this q @ k _and_ adding the mask, you now have a chunk whose values are all -inf. This introduces a numerics issue in flash attention if not carefully guarded: all subsequent softmax calculations will now be NaN. Why? Because of how flash attention progressively calculates attention and makes final adjustments in the last stage. Once we have NaNs, every subsequent calculation also results in NaN (see the sketch after this message).

I found this out the hard way and wondered why this is not a problem in core, from which much of this code is copied. Indeed it was, and it was fixed after this code was copied, in pytorch/pytorch#130014.

If we had better code sharing, this probably could have been avoided, but we have diverged quite a bit now, and the ugliness in the two places is irreconcilable.
Differential Revision: D73640471
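
To make the scenario above concrete, here is a small standalone C++ sketch (illustrative only; the numbers 700, 515..575, and 512 come from the description, and none of this is the ExecuTorch code) that builds the mask row and shows that the first 512-wide chunk is entirely -inf:

```cpp
// Sketch of the sliding-window mask described above, walked in
// flash-attention-style chunks of 512 along the k dimension.
#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  const int k_len = 700, window_lo = 515, window_hi = 575, chunk = 512;

  // Mask row for the query position discussed above: only 515..575 are allowed.
  std::vector<float> mask(k_len, neg_inf);
  for (int k = window_lo; k <= window_hi; ++k) mask[k] = 0.0f;

  // Walk the k dimension chunk by chunk and report fully masked chunks.
  for (int start = 0; start < k_len; start += chunk) {
    int end = std::min(start + chunk, k_len);
    bool all_masked = true;
    for (int k = start; k < end; ++k) {
      if (mask[k] != neg_inf) { all_masked = false; break; }
    }
    std::printf("chunk [%d, %d): %s\n", start, end,
                all_masked ? "fully masked (all -inf)" : "has valid positions");
  }
  return 0;
}
```

An unguarded online softmax over that first, fully masked chunk is exactly the `std::exp((-inf) - (-inf))` case fixed in the PyTorch PR above.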
kimishpatel added a commit to kimishpatel/executorch-1 that referenced this pull request Apr 25, 2025
Summary: same as the commit message above.

Reviewed By: larryliu0820

Differential Revision: D73640471
Labels
ciflow/trunk, Merged, module: cpu, open source, topic: bug fixes, topic: not user facing, triaged

Successfully merging this pull request may close these issues:
MultiheadAttention returns NaNs when need_weights=False for long sequences with a mask that ignores old tokens