[perf] Add fused MLA QKV + strided layernorm #21116
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can add a 🚀 reaction.
Code Review
This pull request introduces two significant optimizations: fusing the QKV projection for MLA models and implementing a strided LayerNorm kernel. The changes are well-implemented and should provide the performance benefits described.
The fusion of Q-LoRA and KV-LoRA projections into a single matrix operation for DeepSeek-V2 models is a smart optimization that reduces kernel launch overhead and memory traffic. The introduction of MergedReplicatedLinear to handle this fusion is a clean way to extend the existing linear layer infrastructure.
The addition of a strided layernorm implementation is crucial for the fusion to be effective, as it avoids expensive .contiguous() calls on tensor slices. The CUDA kernels have been updated correctly to handle the input_stride, and the PyTorch bindings are adjusted accordingly.
The test suite has been properly extended to cover the new strided input case for the layernorm kernels, ensuring the correctness of the new implementation.
Overall, this is a high-quality contribution that improves performance while maintaining code clarity and correctness. I have no major concerns.
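The fusion the review describes can be sketched as follows. This is a minimal numpy illustration of the idea, not vLLM's actual MergedReplicatedLinear code; all sizes are hypothetical.

```python
import numpy as np

# Hypothetical sizes for illustration only (not DeepSeek-V2's real dims).
hidden, q_lora_rank, kv_lora_rank = 64, 32, 16
rng = np.random.default_rng(0)
x = rng.standard_normal((4, hidden))            # a batch of hidden states
w_q = rng.standard_normal((hidden, q_lora_rank))
w_kv = rng.standard_normal((hidden, kv_lora_rank))

# Unfused: two separate projections -> two GEMMs, x read from memory twice.
q_a = x @ w_q
kv_a = x @ w_kv

# Fused: concatenate the weights once, run a single GEMM, then slice.
w_fused = np.concatenate([w_q, w_kv], axis=1)
fused = x @ w_fused
q_fused, kv_fused = fused[:, :q_lora_rank], fused[:, q_lora_rank:]

assert np.allclose(q_a, q_fused) and np.allclose(kv_a, kv_fused)
```

Note that q_fused and kv_fused are views into the fused output, not copies; downstream ops must therefore accept strided inputs, which is what motivates the strided layernorm kernel.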
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Force-pushed from 75b3d50 to e3962ab (Compare)
Thanks for the work!
from vllm.model_executor.layers.quantization.fp8 import (
    Fp8LinearMethod, Fp8MoEMethod)
Could we refactor the code so that we can put this import at the top of the file, without worrying about the circular import here?
Well it's tricky, because FP8Linear already depends on Linear (which makes sense). I don't know how you'd like to proceed.
I lazily copy/pasted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py#L787-L791
Yeah, I am thinking: if A imports B and B imports A, we can add a base file C, move the shared base things into C, and have both A and B import C.
We don't need to do it in this PR if you don't wish; it could be done in a future refactor.
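The base-module refactor suggested above can be sketched like this. The module and class names are hypothetical stand-ins (collapsed into one file for illustration), not vLLM's actual layout:

```python
# Before: linear.py imports fp8.py for isinstance checks, while fp8.py
# imports linear.py for base classes -> circular import.
# After: move the shared base into a third module that both import.

# --- base.py (module "C"): shared base, imports neither A nor B ---
class QuantizeMethodBase:
    """Common base for all quantize methods."""

# --- linear.py (module "A"): imports only base.py ---
class UnquantizedLinearMethod(QuantizeMethodBase):
    pass

# --- fp8.py (module "B"): imports only base.py ---
class Fp8LinearMethod(QuantizeMethodBase):
    pass

# linear.py can now check against the shared base instead of importing fp8:
method = Fp8LinearMethod()
assert isinstance(method, QuantizeMethodBase)
```

With this layout the import graph is a DAG (A -> C, B -> C), so no local imports are needed.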
Sure! Here, the best way would probably be to rely on inheritance, by defining (and overriding) methods like QuantizeMethodBase.supports_block_quantization().
However, I don't have a complete overview of all the supported cases and potential edge cases, and it might make this PR heavier than needed for now.
Happy to help with a follow-up PR though :)
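The inheritance approach mentioned above could look roughly like this. The hook name supports_block_quantization and the caller are taken from the comment as a hypothetical sketch, not from vLLM's actual API:

```python
class QuantizeMethodBase:
    def supports_block_quantization(self) -> bool:
        # Conservative default; most methods don't do block quantization.
        return False

class Fp8LinearMethod(QuantizeMethodBase):
    def supports_block_quantization(self) -> bool:
        # FP8 method opts in by overriding the hook.
        return True

def choose_layout(method: QuantizeMethodBase) -> str:
    # Callers branch on the overridable hook, never on concrete classes,
    # so linear.py no longer needs to import fp8.py at all.
    return "block" if method.supports_block_quantization() else "plain"

assert choose_layout(Fp8LinearMethod()) == "block"
assert choose_layout(QuantizeMethodBase()) == "plain"
```

This removes the need for isinstance checks against Fp8LinearMethod, which is what forced the local import in the first place.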
Sounds great, certainly you can do that in another pr
Nice work!
Signed-off-by: Mickael Seznec <mickael@mistral.ai> Co-authored-by: mgoin <mgoin64@gmail.com>
Essential Elements of an Effective PR Description Checklist
(Template item: update supported_models.md and examples for a new model.)
Purpose
For MLA models that have a q_lora_rank: fuse q_lora and kv_lora into the same matrix (avoids some memory traffic and one kernel call).
Also adds a layernorm implementation that operates on strided input, which avoids a memory copy.
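The strided-input idea can be illustrated in numpy (the real change is a CUDA kernel taking an input_stride argument; this is only a sketch of why it helps). Slicing a fused projection output yields a non-contiguous view, and a stride-aware normalization can consume it directly instead of forcing a copy:

```python
import numpy as np

def layernorm(x, eps=1e-6):
    # Row-wise layernorm over the last axis. numpy reductions honor
    # strides, so this works on non-contiguous views without a copy,
    # analogous to the strided CUDA kernel in this PR.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

fused = np.arange(48, dtype=np.float64).reshape(4, 12)  # fused QKV output
kv_slice = fused[:, 8:]          # non-contiguous view into the fused output
assert not kv_slice.flags["C_CONTIGUOUS"]

out_view = layernorm(kv_slice)                         # strided path, no copy
out_copy = layernorm(np.ascontiguousarray(kv_slice))   # copy-first path
assert np.allclose(out_view, out_copy)
```

The two paths agree numerically; the strided path simply skips the materialization step that .contiguous() would perform on the slice.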
Test Plan
Unit tests added for strided layernorm. E2E testing & benchmark results are in this PR.
Test Result
Accuracy
main (20149d8)
This PR:
Performance
main (20149d8)
This PR:
(Optional) Documentation Update