[Perf] Cuda Kernel for Per Token Group Quant #21083
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review
This pull request introduces a new CUDA kernel for per-token-group FP8 quantization. The changes include the kernel implementation, its C++ and Python bindings, and integration into the quantization utilities. The new CUDA kernel is used when available, with a fallback to the Triton kernel.
There are a few high-severity issues that should be addressed:
- The C++ function signature for the new op should align with existing conventions in the codebase.
- The new CUDA kernel can be further optimized to reduce global memory bandwidth.
- A check in the C++ code is overly restrictive and could lead to runtime errors for valid input shapes.
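
For context on what the new op computes, here is a minimal torch sketch of per-token-group FP8 quantization (illustrative only; the function name and defaults below are not the PR's actual bindings):

```python
import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
    """Reference: quantize each contiguous group of `group_size` elements
    along the last dim to FP8 (e4m3fn) with one float32 scale per group."""
    assert x.shape[-1] % group_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    # Per-group absolute maximum determines the per-group scale.
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scale = amax / fp8_max
    q = (groups / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.squeeze(-1)
```

Roughly speaking, the CUDA kernel fuses this per-group reduction and cast, with the scale layout (row- vs column-major) and e8m0 rounding exposed as options.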
CC @mgoin
I think we need a kernel unit test now to compare the CUDA kernel against the Triton/torch impl for the 4 cases we have to consider (row float, row e8m0, col float, col e8m0).
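
A rough sketch of what such a parameterized test could look like (the import path and exact argument names are assumptions, and the check below is a dequantization round trip rather than a direct CUDA-vs-Triton comparison):

```python
import pytest
import torch

# Assumed import path for the existing Python entry point.
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("column_major_scales", [False, True])
@pytest.mark.parametrize("use_ue8m0", [False, True])
def test_per_token_group_quant_fp8(column_major_scales, use_ue8m0):
    torch.manual_seed(0)
    tokens, hidden, group_size = 64, 512, 128
    x = torch.randn(tokens, hidden, device="cuda", dtype=torch.bfloat16)
    q, s = per_token_group_quant_fp8(
        x, group_size,
        column_major_scales=column_major_scales,
        use_ue8m0=use_ue8m0)
    # Normalize the scale layout to (tokens, groups); the column-major path
    # may return the transposed shape (assumption).
    if s.shape != (tokens, hidden // group_size):
        s = s.t()
    if use_ue8m0:
        # Assuming e8m0 scales come back as float32 holding powers of two.
        assert torch.all(torch.exp2(torch.round(torch.log2(s))) == s)
    # FP8 quantization is lossy, so the tolerance is deliberately loose.
    deq = q.float().reshape(tokens, -1, group_size) * s.reshape(tokens, -1, 1)
    torch.testing.assert_close(deq.reshape(tokens, hidden), x.float(),
                               rtol=0.15, atol=0.15)
```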
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(output_q.is_contiguous());
Contiguity might be a problem for MLA, so please test a couple of DeepSeek evals/benchmarks.
You are right, so I chose to fall back to Triton when the input is not contiguous (see the dispatch sketch after the results below).
Now it works:
VLLM_USE_DEEP_GEMM=1 lm_eval --model vllm --model_args "pretrained=deepseek-ai/DeepSeek-R1,data_parallel_size=8,gpu_memory_utilization=0.95,max_model_len=16384,enable_expert_parallel=True" --tasks gsm8k --batch_size auto --num_fewshot 5
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9560|± |0.0056|
| | |strict-match | 5|exact_match|↑ |0.9545|± |0.0057|
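
The fallback described above boils down to something like the following (a minimal sketch; the two kernel callables are placeholders, not the PR's actual symbols):

```python
import torch
from typing import Callable

def dispatch_per_token_group_quant(
        x: torch.Tensor,
        cuda_kernel: Callable,    # placeholder for the new CUDA op
        triton_kernel: Callable,  # placeholder for the existing Triton kernel
        *args, **kwargs):
    """Use the CUDA op only for contiguous CUDA inputs; fall back to Triton
    otherwise (e.g. for the non-contiguous slices seen with MLA)."""
    if x.is_cuda and x.is_contiguous():
        return cuda_kernel(x, *args, **kwargs)
    return triton_kernel(x, *args, **kwargs)
```

This keeps the strict contiguity TORCH_CHECKs in the C++ op while still covering MLA's strided inputs.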
@@ -366,6 +366,7 @@ def per_token_group_quant_fp8(
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    out_q: Optional[torch.Tensor] = None,
    use_ue8m0: bool = is_blackwell_deep_gemm_used(),
I worry about setting this as a default argument, since this function could be used on Blackwell but with the CUTLASS or FlashInfer FP8 block kernels that are now on SM100.
is_blackwell_deep_gemm_used also checks the env var VLLM_USE_DEEP_GEMM, so it won't cause trouble for now. That said, this default isn't ideal either, since DeepGEMM now supports float32 scales on B200; I will open another PR to let the user decide whether to use e8m0.
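
For reference, the difference between a plain float32 scale and a UE8M0 scale per group is essentially a power-of-two rounding step, along these lines (illustrative helper, not the PR's code):

```python
import torch

def group_scale(amax: torch.Tensor, use_ue8m0: bool) -> torch.Tensor:
    """Turn a per-group absolute maximum into a quantization scale."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = amax.float() / fp8_max
    if use_ue8m0:
        # UE8M0 stores only an 8-bit exponent, so round the scale up to the
        # next power of two; rounding up keeps amax representable in FP8.
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    return scale
```

For example, `group_scale(x.abs().amax(dim=-1), use_ue8m0=True)` would give the power-of-two scale for a group spanning the last dimension.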
I think this PR causes this build error on MI300:
Checking out the previous commit 2c8db17 and cherry-picking 226b452, the build goes fine. (Or maybe something is wrong in my env, cc @gshtras.)
From Float8_e4m3fn alone, it doesn't look like this file is supposed to be used on ROCm.
I was about to report the same thing: MI300A GPU, the build worked before.
Purpose
Add a CUDA kernel for per-token-group quantization instead of Triton.
Modified from SGLang (https://github.com/sgl-project/sglang/blob/570d33437bf0b4ac42e00ad468ddc43f9e0b376f/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu), thanks for the code!
Also adds support for row-major E8M0 and float32 scale tensors, and optimizes the CUDA kernel using shared memory.
Test
B200
Accuracy
Performance
VLLM_USE_DEEP_GEMM=1 vllm bench throughput --model Qwen/Qwen3-30B-A3B-FP8 --load-format dummy --input-len 1000 --output-len 100 --trust_remote_code --enable-expert-parallel

# this PR
Throughput: 27.00 requests/s, 29643.76 total tokens/s, 2700.48 output tokens/s
# main
Throughput: 25.35 requests/s, 27828.75 total tokens/s, 2535.13 output tokens/s
R1(EP+DP):
H100