[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass #16756
**Triton compile issue (resolved).** The code was failing with a Triton compilation error (weird).
The offending line:
Repro steps (even without this PR, no torch.compile, nothing):
**Triton memory issue (resolved).** Repro steps:
Works without attention fusion:
This pull request has merge conflicts that must be resolved before it can be merged.
Hi @zou3519, could you also help review the torch.compile pass part? Thanks.
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
- cleanup backends to release llms
- increase gpu_model_utilization

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
LGTM
That's awesome, thanks!
This PR implements the fusion of fp8 quantization onto attention, described in #16220. It performs this fusion using a new `AttnFusionPass`, which uses the pattern matcher and only performs the fusion if the backend supports it. The pass is currently off by default, pending more robust V1 support and performance measurement.

This PR also makes the following changes:
- `output_scale` added as a parameter to `unified_attention_with_output`. During the `torch.compile` fusion pass, we do not have access to the scale, just the graph node corresponding to it, so we cannot simply set the scale on the layer object.
- `fused_output_quant_supported` added on `AttentionImpl`. This method tells the fusion pass that it is safe to fuse the output quantization onto attention. It is opt-in, so fusion will only be performed if the backend impl supports it.
- Fusion support added to the `ROCmFlashAttentionImpl` attention backend. This is the motivating case for this pass, as AMD is adding support for fused attention quantization to the Triton kernel in #12591 ([Kernel][Triton][FP8] Adding fp8 and variable length sequence support to Triton FAv2 kernel).
- `PostGradPassManager` now accepts the forward context as a parameter, so it can be passed to passes that need it (like `AttnFusionPass`); since #16155 ([Feature] support sequence parallelism using compilation pass) we pass `vllm_config` to passes anyway. We cannot currently pass the whole compilation config, as that would create a cycle when the manager adds itself to `CompilationConfig.inductor_config`.
- Improved `NoOpEliminationPass`: we now also replace a chain of reshapes with the last one, i.e. `t.view(*args1).view(*args2).view(*args3)` -> `t.view(*args3)`. This is needed for correct pattern matching, and it is always good to simplify the graph either way.
- `lazy_format_graph_pass` inside `VllmInductorPass` makes sure the graph gets printed when debugging with `depyf`.
- Added tests for `AttnFusionPass`, using `LLM` instances instead of silly models to avoid redoing metadata setup in test code.

While this PR only adds fusion support on ROCm, it makes it easy to add support for other backends once their attention kernels support fused quantization of the output. This includes V1, although we will either need to use full cudagraphs or address the piecewise problem described in #16220. Additionally, support for other quantization schemes can be added with minor additions to the pass (matching the appropriate quant ops and saving quant metadata into the attention layer).
This PR depends on #12591, #15734, #16431 and #17139. All of them have been merged to main.
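As a side note, the reshape-chain simplification added to `NoOpEliminationPass` can be sketched in isolation. The toy `torch.fx` pass below (assumed names, not the vLLM implementation) rebases each `view` onto the original tensor, leaving intermediate views dead so `t.view(a).view(b).view(c)` collapses to `t.view(c)`.

```python
import torch
import torch.fx as fx

class ChainedViews(torch.nn.Module):
    def forward(self, t: torch.Tensor) -> torch.Tensor:
        return t.view(4, 4).view(2, 8).view(16)

def collapse_view_chains(gm: fx.GraphModule) -> fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op == "call_method" and node.target == "view":
            # walk back through any chain of views to the original tensor
            src = node.args[0]
            while (isinstance(src, fx.Node) and src.op == "call_method"
                   and src.target == "view"):
                src = src.args[0]
            node.args = (src,) + node.args[1:]
    gm.graph.eliminate_dead_code()  # inner views now have no users
    gm.recompile()
    return gm

gm = collapse_view_chains(fx.symbolic_trace(ChainedViews()))
views = [n for n in gm.graph.nodes
         if n.op == "call_method" and n.target == "view"]
# only the final .view(16) should survive
```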
Perf numbers below: an improvement on decode (ITL) and a reduction on prefill (TTFT). Not going to invest more here, as this uses a deprecated V0 prefill Triton kernel; will revisit performance with support for other backends.
Baseline:
`VLLM_USE_V1=0 vllm serve amd/Llama-3.1-8B-Instruct-FP8-KV -O '{"pass_config":{"enable_attn_fusion": false}}'`
With attention fusion:
`VLLM_USE_V1=0 vllm serve amd/Llama-3.1-8B-Instruct-FP8-KV -O '{"pass_config":{"enable_attn_fusion": true}}'`