Conversation

kimishpatel
Contributor

Summary:
The flash attention implementation breaks the q @ k matmul into chunks along both the source seqlen and target seqlen (kv cache) dims. When using masks, the masks are typically of shape [q seq len, k seq len], where k seq len == kv cache size.
Imagine you have k seq len = 700 and a mask like

pos:     0    1    2    3  ...  515 ...  575   576  ...  698  699
mask:  -inf -inf -inf -inf ...   0  ...   0   -inf  ... -inf -inf

What this is really telling you is that you should attend only to the middle portion. For example, when decoding position 575 you want to attend only to the previous 60 positions and nothing before that. In that case, positions 515 to 575 in the kv cache are what you care about.

This is how sliding window attention can be implemented.
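
For illustration, here is a minimal PyTorch sketch of such a mask (not the code in this PR; the window of 60, cache size of 700, and decode position 575 are just the numbers from the example above):

```python
import torch

# Numbers from the example above: kv cache of 700 positions,
# sliding window of 60 previous positions, currently decoding position 575.
kv_len, window, pos, head_dim = 700, 60, 575, 64

# Mask row for this decode step: 0 inside the window, -inf everywhere else.
mask = torch.full((1, kv_len), float("-inf"))
mask[:, pos - window : pos + 1] = 0.0  # attend only to positions [515, 575]

q = torch.randn(1, head_dim)
k = torch.randn(kv_len, head_dim)

# Adding the mask to the logits before softmax zeroes out attention
# everywhere outside the window.
scores = (q @ k.T) / head_dim**0.5 + mask
attn = torch.softmax(scores, dim=-1)
print(attn[0, :515].sum(), attn[0, 576:].sum())  # both exactly 0
```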

Now comes the interesting part. Because the flash attention implementation chunks along the k seq len dim, with the chunk size set to 512, the first chunk of q @ k has an attention mask of -inf added across the whole chunk. This makes your entire chunk -inf, indicating you don't want to attend to this chunk at all. (You could honestly have avoided this calculation entirely, but maybe that's for another day.) However, as a result of calculating this q @ k and adding the mask, you now have a chunk whose values are all -inf. This introduces a numerics issue in flash attention if not carefully guarded: all subsequent softmax calculations become nans. Why? Because flash attention progressively calculates attention, tracking a running row max and making final adjustments in the last stage; when an entire chunk is -inf, the running max is also -inf, and exp(-inf - (-inf)) is nan. Once you have nans, all subsequent calculations also result in nans.
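
To make the failure mode concrete, here is a minimal sketch of the online-softmax bookkeeping that flash attention style chunking uses, with the 512 chunk size and the mask from the example above. This is not the kernel touched by this PR, just an illustration of the numerics:

```python
import torch

torch.manual_seed(0)
head_dim, kv_len, chunk = 64, 700, 512

q = torch.randn(1, head_dim)
k = torch.randn(kv_len, head_dim)
v = torch.randn(kv_len, head_dim)

mask = torch.full((1, kv_len), float("-inf"))
mask[:, 515:576] = 0.0  # sliding window: attend only to [515, 575]

m = torch.full((1, 1), float("-inf"))  # running row max
l = torch.zeros(1, 1)                  # running softmax denominator
acc = torch.zeros(1, head_dim)         # running weighted sum over v

for start in range(0, kv_len, chunk):
    end = min(start + chunk, kv_len)
    s = (q @ k[start:end].T) / head_dim**0.5 + mask[:, start:end]

    m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
    # The first chunk (positions 0..511) is fully masked, so m_new stays -inf
    # and exp(s - m_new) = exp(-inf - (-inf)) = nan, which then sticks in
    # l and acc. A guard here, e.g. clamping a -inf m_new to 0 before the
    # exp, would make this chunk contribute 0 instead of nan.
    p = torch.exp(s - m_new)
    scale = torch.exp(m - m_new)
    l = l * scale + p.sum(dim=-1, keepdim=True)
    acc = acc * scale + p @ v[start:end]
    m = m_new

print((acc / l).isnan().any())  # tensor(True): the output is poisoned with nans
```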

I found this the hard way and thought: wait, why is this not a problem in core, from which much of this code is copied? Well, indeed it was, and it was fixed after this code was copied, in pytorch/pytorch#130014.

If we had better code sharing, this probably could have been avoided, but we have diverged quite a bit by now, and the ugliness in both places is irreconcilable.

Differential Revision: D73640471


pytorch-bot bot commented Apr 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/10466

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 376d293 with merge base 8321a4a:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 25, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D73640471

@kimishpatel kimishpatel added the release notes: examples Changes to any of our example LLMs integrations, such as Llama3 and Llava label Apr 25, 2025
Reviewed By: larryliu0820

Differential Revision: D73640471
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D73640471

@facebook-github-bot facebook-github-bot merged commit f3e8972 into pytorch:main Apr 25, 2025
85 of 86 checks passed