[CUDA] Enable full cudagraph for FlashMLA #18581
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run the full CI to test the changes comprehensively before merging.
Could you give some details on the speedup associated with this modification?
This pull request has merge conflicts that must be resolved before it can be merged.
I haven't profiled this specifically, but it's meant to enable the double-batch-overlap optimization (prototyped in #18415).
This pull request has merge conflicts that must be resolved before it can be merged.
Hi, any further progress on this PR?
Almost ready for review! |
        self._num_prefill_tokens = 0
        return self.build(0, m)

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> M:
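To make the snippet easier to follow, here is a minimal, self-contained sketch of the builder interface it belongs to. Everything beyond the names visible in the diff (the `FlashMLAStyleBuilder` class, the stubbed `CommonAttentionMetadata`, the zero-prefix default) is an assumption for illustration, not the actual vLLM implementation.

```python
# A minimal sketch (not the actual vLLM code) of the builder interface implied
# by the diff above; CommonAttentionMetadata is stubbed out for brevity.
from typing import Any, Generic, TypeVar

M = TypeVar("M")
CommonAttentionMetadata = Any  # stand-in for the real vLLM dataclass


class AttentionMetadataBuilder(Generic[M]):
    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> M:
        raise NotImplementedError

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        # Default behaviour referenced later in the review: delegate to
        # build(), assuming a zero common prefix.
        return self.build(0, common_attn_metadata)


class FlashMLAStyleBuilder(AttentionMetadataBuilder[dict]):
    """Illustrative override only; the real FlashMLA builder tracks more state."""

    def __init__(self) -> None:
        self._num_prefill_tokens = 0

    def build(self, common_prefix_len: int,
              m: CommonAttentionMetadata) -> dict:
        return {"common_prefix_len": common_prefix_len, "metadata": m}

    def build_for_cudagraph_capture(self, m: CommonAttentionMetadata) -> dict:
        # Capture assumes a pure-decode batch, so zero out the prefill
        # bookkeeping before building (mirrors the two lines in the diff).
        self._num_prefill_tokens = 0
        return self.build(0, m)
```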
I think we can put `common_prefix_len` in `CommonAttentionMetadata` too and just default it to 0.
It's currently calculated per-backend though. I guess it should always be the same value?
Hmm, maybe not for this PR, but I think `common_prefix_len` should always be the same regardless of `use_cascade_attention`, and the backend can just choose to ignore it if `use_cascade_attention` is false; in that case it would belong in `CommonAttentionMetadata`.
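A rough sketch of what this suggestion could look like, assuming a dataclass-style `CommonAttentionMetadata`; the other fields shown are placeholders, and this reflects the reviewer's proposal rather than what the PR actually implements.

```python
# Hypothetical shape of the suggestion: common_prefix_len lives on the shared
# metadata and defaults to 0.
from dataclasses import dataclass


@dataclass
class CommonAttentionMetadata:
    num_reqs: int                 # placeholder fields; the real class in vLLM
    num_actual_tokens: int        # also carries per-request tensors
    max_query_len: int
    common_prefix_len: int = 0    # 0 == no shared prefix / cascade disabled


def effective_common_prefix_len(m: CommonAttentionMetadata,
                                use_cascade_attention: bool) -> int:
    # A backend that does not implement cascade attention can simply
    # ignore the field, as suggested above.
    return m.common_prefix_len if use_cascade_attention else 0
```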
        ))

        attn_metadata_i = self.attn_metadata_builders[
            kv_cache_group_id].build_for_cudagraph_capture(
The `_dummy_run` method is used for more than just cudagraph capture; what if the backend doesn't support `build_for_cudagraph_capture`? We should still be able to run dummy runs.
Oh wait, I see `build_for_cudagraph_capture` is in the base class; I think this is still a bit confusing for backends that don't support full cuda graphs.
- If a backend doesn't support it, this path is not triggered (it shouldn't be running with full cuda graphs).
- This method just calls `build` by default anyway.
Should we add a `cudagraph_capturing` flag to `_dummy_run`, maybe?
I think `skip_attn` is enough; what would change if we added that flag?
Basically do `(attn_metadata_builder.build if not cudagraph_capturing else attn_metadata_builder.build_for_cudagraph_capture)(common_metadata)`, just so we only use `build_for_cudagraph_capture` for cudagraph capture.
This way we can raise `NotImplementedError` in `build_for_cudagraph_capture` if the backend doesn't support it (so we don't accidentally give the impression that a backend supports cuda graphs when it doesn't).
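A sketch of the opt-in default being suggested here; this illustrates the reviewer's proposal, not necessarily what the PR merged.

```python
class AttentionMetadataBuilder:
    def build(self, common_prefix_len, common_attn_metadata):
        raise NotImplementedError

    def build_for_cudagraph_capture(self, common_attn_metadata):
        # Raising here keeps full-cudagraph support opt-in: a backend only
        # looks like it supports capture once it explicitly overrides this.
        raise NotImplementedError(
            f"{type(self).__name__} does not support full CUDA graph capture")
```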
Per offline discussion, agreed this interface is not ideal. But we only use `_dummy_run` with attention when capturing cudagraphs, so I'll rename the flag; if regular attention in `_dummy_run` is needed in the future, a new flag can be added.
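Roughly how that direction could look inside the model runner; the flag name `capture_attn_cudagraph` and the helper `make_common_attn_metadata` are hypothetical stand-ins, not the PR's exact code.

```python
def _dummy_run(model_runner, num_tokens: int,
               capture_attn_cudagraph: bool = False):
    """Sketch only: dispatch between build() and build_for_cudagraph_capture()."""
    attn_metadata = None
    if capture_attn_cudagraph:
        # Only the cudagraph-capture path uses the capture-specific builder
        # entry point, mirroring the call shown in the diff above.
        attn_metadata = [
            builder.build_for_cudagraph_capture(
                model_runner.make_common_attn_metadata(num_tokens))
            for builder in model_runner.attn_metadata_builders
        ]
    # ... run the model with dummy inputs (elided) ...
    return attn_metadata
```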
Overall this is looking much better! Thanks for doing the refactor; left a couple of comments.
LGTM, thanks for the refactor!
Enable full CUDA graph capture for the FlashMLA decode case.
Hacks:
Tested with:
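For orientation only, a hedged example of how a user might enable this path. The `compilation_config` argument, the `full_cuda_graph` flag, the `FLASHMLA` backend value, and the model choice are assumptions based on vLLM conventions around this time, not commands taken from this PR; verify the names against your vLLM version.

```python
# Illustrative usage sketch; names below are assumptions, not authoritative.
import os

from vllm import LLM, SamplingParams

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHMLA"  # assumed backend selector

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",            # an MLA-style model
    compilation_config={"full_cuda_graph": True},    # assumed flag for full capture
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))
```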