Fix, or rather "port", bug fix for sdpa #10466
Merged
Summary:
The flash attention implementation breaks the q @ k matmul into chunks along both the source seqlen and the target seqlen (kv cache) dimensions. Masks are typically of
shape [q seq len, k seq len], where k seq len == kv cache size.
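For concreteness, here is a minimal sketch (in PyTorch, not the actual kernel code) of what chunking q @ k along the k seq len dimension and adding the corresponding mask slice per chunk looks like; the function and variable names are illustrative assumptions:

```python
import torch

# A minimal sketch, not the actual kernel: flash-attention-style chunking of
# q @ k.T along the k seq len (kv cache) dimension, adding the mask per chunk.
def chunked_scores(q, k, mask, chunk_size=512):
    # q: [q_seq_len, head_dim], k: [k_seq_len, head_dim]
    # mask: [q_seq_len, k_seq_len], 0 where attention is allowed, -inf elsewhere
    k_seq_len = k.shape[0]
    score_chunks = []
    for start in range(0, k_seq_len, chunk_size):
        end = min(start + chunk_size, k_seq_len)
        scores = q @ k[start:end].T            # [q_seq_len, end - start]
        scores = scores + mask[:, start:end]   # add the mask slice for this chunk
        score_chunks.append(scores)
    return score_chunks
```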
Imagine you have k seq len = 700 and a mask like this:

position:    0    1    2  ...  514   515  ...  575   576  ...  699  700
value:     -inf -inf -inf ... -inf    0   ...   0   -inf  ... -inf -inf
What this mask is really telling you is that you should attend only to the middle portion. For example, when you are decoding position 575 you want to attend only to the previous 60 positions and nothing before that, so positions 515 to 575 of the kv cache are the ones you care about. This is how sliding window attention can be implemented.
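A sketch of how such a mask might be built (the window size of 60 and decode position 575 come from the example above; the exact cache length and variable names are illustrative):

```python
import torch

# Illustrative construction of the mask described above: only kv cache
# positions 515..575 are attended to, everything else is masked with -inf.
k_seq_len = 701        # kv cache positions 0..700 (illustrative)
window = 60
pos = 575              # current decode position
mask = torch.full((1, k_seq_len), float("-inf"))
mask[0, pos - window : pos + 1] = 0.0
# Note that the entire first chunk of 512 positions (0..511) stays -inf.
```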
Now comes the interesting part. Because the flash attention implementation chunks along the k seq len dim, with the chunk size set to 512, the first chunk of q @ k gets an attention mask that is entirely -inf. This makes the whole chunk -inf, indicating you don't want to attend to this chunk at all. (Honestly, this calculation could have been skipped entirely, but maybe that's for another day.) However, as a result of computing this q @ k and adding the mask, you now have a score chunk containing all -infs. This introduces a numerics issue in flash attention if not carefully guarded: all subsequent softmax calculations become NaNs. Why? Because flash attention progressively calculates attention, subtracting a running max before exponentiating and making final adjustments in the last stage; when the running max of a fully masked chunk is -inf, exp(-inf - (-inf)) evaluates to NaN, and once NaNs appear, all subsequent calculations also result in NaNs.
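Here is a minimal sketch of the failure mode and of a guard in the spirit of the upstream fix (not a copy of the kernel code; tensor shapes are illustrative):

```python
import torch

scores = torch.full((1, 512), float("-inf"))        # chunk 0 after adding the mask

row_max = scores.max(dim=-1, keepdim=True).values   # -inf
p = torch.exp(scores - row_max)                     # -inf - (-inf) = nan, exp(nan) = nan
print(p.isnan().all())                              # tensor(True): NaNs poison the running sums

# Guard in the spirit of pytorch/pytorch#130014: do not let the running max
# stay at -inf, so a fully masked chunk contributes exp(-inf - 0) = 0, not NaN.
safe_max = torch.where(row_max == float("-inf"), torch.zeros_like(row_max), row_max)
p = torch.exp(scores - safe_max)                    # all zeros
print(p.isnan().any())                              # tensor(False)
```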
I found this out the hard way and wondered why this is not a problem in core, from which much of this code is copied. Well indeed, it was, and it was fixed after this code was copied, in pytorch/pytorch#130014.
If we had better code sharing this probably could have been avoided, but we have diverged quite a bit by now, and the ugliness in both places is irreconcilable.
Differential Revision: D73640471