[Feature][Kernel] Blocked FP8 CUTLASS MoE for Hopper #19983
base: main
Conversation
Summary of Changes
Hello @ElizaWszola, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances vLLM's quantization capabilities by integrating blocked FP8 CUTLASS MoE support for SM90 GPUs. This addition provides a more optimized path for running Mixture of Experts models with fine-grained FP8 quantization, improving efficiency and potentially performance on compatible hardware.
Highlights
- New Feature: Blocked FP8 CUTLASS MoE: This PR introduces support for blocked FP8 quantization within the CUTLASS Mixture of Experts (MoE) implementation, specifically targeting SM90 (NVIDIA Hopper) architectures. This enables more efficient execution of MoE layers with block-wise FP8 weights and activations.
- Core Kernel Implementation: A new CUDA kernel, `grouped_mm_c3x_with_blockwise_scaling.cu`, has been added. This kernel leverages CUTLASS 3.x to perform grouped matrix multiplications with block-wise FP8 scaling, which is crucial for the new quantization scheme.
- API and Integration: New C++ functions (`cutlass_moe_blockwise_mm`, `transpose_cutlass_moe_a_scales`) and their Python bindings have been introduced to expose the blocked MoE functionality. The vLLM FP8 quantization handler is updated to utilize this new CUTLASS implementation when supported by the hardware.
- Comprehensive Testing: Extensive unit tests have been added to validate the new blocked FP8 MoE implementation, including tests for single grouped multiply operations and fused experts operations, ensuring correctness and performance.
Code Review
This PR adds support for blocked FP8 CUTLASS MoE for SM90 architectures. The changes are comprehensive. My review focuses on correctness and maintainability. I've found a critical compilation issue due to a typo in a macro definition. Additionally, there are several areas with commented-out code, debug statements, and redundant checks that should be cleaned up to improve code quality before merging. The new tests are comprehensive and cover the new functionality well.
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_with_blockwise_scaling.cu
```cpp
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
```
```cpp
  // if (n >= 8192) {
  //   cutlass_blockwise_group_gemm_caller<Cutlass3xGemmN8192>(
  //       out_tensors, a_tensors, b_tensors, a_scales, b_scales,
  //       expert_offsets, problem_sizes, a_strides, b_strides, c_strides,
  //       per_act_block);
  // } else if (k >= 8192) {
  //   cutlass_blockwise_group_gemm_caller<Cutlass3xGemmK8192>(
  //       out_tensors, a_tensors, b_tensors, a_scales, b_scales,
  //       expert_offsets, problem_sizes, a_strides, b_strides, c_strides,
  //       per_act_block);
  // } else if (m <= 16) {
  //   cutlass_blockwise_group_gemm_caller<Cutlass3xGemmM16>(
  //       out_tensors, a_tensors, b_tensors, a_scales, b_scales,
  //       expert_offsets, problem_sizes, a_strides, b_strides, c_strides,
  //       per_act_block,);
  // } else {
  cutlass_blockwise_group_gemm_caller<Cutlass3xGemmDefault>(
      out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
      problem_sizes, a_strides, b_strides, c_strides, per_act_block);
  // }
```
@ElizaWszola If you want to get your PR in first, I don't mind following up with the SM100 PR #19757 but modifying …
Let's tune these kernels and see how they measure up to DeepGEMM!
```cpp
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_block);
```
What does `per_act_block` mean? What does it mean if it's true and what does it mean if it's false? Please add some documentation.
When it's true, this means [1x128]-block input scales. When it's false, we use per tensor scales - not sure if this is or will be needed, I can delete if it's too much extra code.
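For illustration, a minimal sketch (not from the PR; shapes are placeholders) of the two activation-scale layouts that `per_act_block` selects between:

```python
import torch

num_tokens, hidden_size, group = 4, 512, 128
a = torch.randn(num_tokens, hidden_size)  # activations to be quantized

# per_act_block=True: fine-grained [1x128] activation scales, i.e. one scale
# per token per 128-wide block along the hidden dimension.
blockwise_scales = torch.ones(num_tokens, hidden_size // group,
                              dtype=torch.float32)  # shape (4, 4)

# per_act_block=False: a single per-tensor activation scale.
per_tensor_scale = torch.ones(1, dtype=torch.float32)
```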
I think it's good to support both cases. But please add some comments documenting what the variable means.
I've added relevant comments in `torch_bindings.cpp` and `cutlass_moe.py`.
```python
if per_act_block:
    a1q_scale_t = torch.empty((a1q_scale.shape[0] * a1q_scale.shape[1]),
                              device=device,
                              dtype=a1q_scale.dtype)
    ops.transpose_cutlass_moe_a_scales(a1q_scale_t, a1q_scale,
                                       expert_offsets, problem_sizes1)
```
Have you run the PyTorch profiler to see how much time this takes? Would the code be cleaner if `ops.transpose_cutlass_moe_a_scales` returned `a1q_scale_t` rather than requiring the caller to allocate it with `torch.empty`?
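For example, a wrapper along these lines (a sketch only; the wrapper name is hypothetical, the op and argument order follow the snippet above, and the import path is assumed to be vLLM's usual custom-ops module):

```python
import torch
from vllm import _custom_ops as ops  # assumed import path for the new binding


def transpose_a_scales(a1q_scale: torch.Tensor,
                       expert_offsets: torch.Tensor,
                       problem_sizes: torch.Tensor) -> torch.Tensor:
    # Allocate the flattened, transposed scale buffer internally and return it,
    # so callers no longer need the explicit torch.empty.
    a1q_scale_t = torch.empty(a1q_scale.shape[0] * a1q_scale.shape[1],
                              device=a1q_scale.device,
                              dtype=a1q_scale.dtype)
    ops.transpose_cutlass_moe_a_scales(a1q_scale_t, a1q_scale,
                                       expert_offsets, problem_sizes)
    return a1q_scale_t
```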
I've profiled the function with a few different sizes, and both transpose functions together take only slightly more time than the quantization of intermediate results (`a2q, a2q_scale = ops.scaled_fp8_quant(...)`). So for large enough inputs, this is an order of magnitude less than the kernel runtimes.
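For reference, a minimal, generic sketch of how such a timing comparison can be made with the PyTorch profiler (the op being timed here is just a stand-in, not the actual transpose or quantization kernel):

```python
import torch
from torch.profiler import ProfilerActivity, profile


def op_under_test(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the op being measured (e.g. a scale transpose).
    return x.t().contiguous()


x = torch.randn(4096, 56, device="cuda")
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(100):
        op_under_test(x)
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```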
@djmmoss Yes, this should address them. This PR might land a bit late though - I still have to do a bit of benchmarking and possibly add a bunch of kernel configs for performance.
```diff
@@ -637,13 +654,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()

   cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+  if(VLLM_COMPILE_FP8_BLOCKWISE_CUTLASS_MOE AND ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
```
Do we want to guard against the existing SM100 kernels here too?
This is where I'm waiting for input from the SM100 kernel's author. I e2e-benchmarked CUTLASS vs. Triton on an SM100 machine and CUTLASS was slower, but I would like them to confirm that CUTLASS is also slower than Triton with their setup.
Thanks for the work!
```diff
@@ -598,6 +598,55 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
     return None


+# Copied and adapted from
+# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
+def per_block_cast_to_fp8(
```
We can reuse the existing `per_block_cast_to_fp8`, similar to #21787.
```python
assert x.dim() == 2
m, n = x.shape

def ceil_div(x: int, y: int) -> int:
```
Please use `cdiv` from vllm utils.
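For reference, a rough sketch of what such a shared helper could look like with `cdiv` from `vllm.utils` in place of the local `ceil_div` (this approximates the DeepGEMM-style helper linked above; the actual shared implementation referenced in #21787 may differ):

```python
import torch
from vllm.utils import cdiv


def per_block_cast_to_fp8_sketch(
        x: torch.Tensor,
        block: tuple[int, int] = (128, 128)) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    bm, bn = block
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Zero-pad so that (bm x bn) blocks tile the tensor exactly.
    padded = torch.zeros(cdiv(m, bm) * bm, cdiv(n, bn) * bn,
                         dtype=x.dtype, device=x.device)
    padded[:m, :n] = x
    # View as (blocks_m, bm, blocks_n, bn) and take a per-block absolute max.
    blocks = padded.view(padded.size(0) // bm, bm, padded.size(1) // bn, bn)
    amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
    # Quantize each block with its own scale, then crop back to the input shape.
    x_q = (blocks * (fp8_max / amax)).to(torch.float8_e4m3fn)
    x_q = x_q.view_as(padded)[:m, :n].contiguous()
    scales = (amax / fp8_max).squeeze(3).squeeze(1)  # (blocks_m, blocks_n)
    return x_q, scales
```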
```python
def native_per_token_group_quant_fp8(x,
                                     group_size,
                                     eps=1e-10,
                                     dtype=torch.float8_e4m3fn):
    """Function to perform per-token-group quantization on an input tensor
    `x` using native torch."""
    assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must "
                                           "be divisible by `group_size`")
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    amax = x_.abs().max(dim=-1,
                        keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / fp8_max
    x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))

    return x_q, x_s
```
Why do we need this instead of the CUDA per-token-group quant kernel?
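If the native version is only meant as a test reference, a small sketch of how it could be cross-checked against the kernel-backed helper already used in this PR's tests (the import path, the scale layout returned by the kernel helper, and the tolerances below are assumptions):

```python
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)  # assumed import path

x = torch.randn(64, 1024, dtype=torch.bfloat16, device="cuda")
group_size = 128

# Reference path: the native torch implementation above.
x_q_ref, x_s_ref = native_per_token_group_quant_fp8(x, group_size)
# Kernel-backed path used elsewhere in this PR's tests.
x_q_ker, x_s_ker = per_token_group_quant_fp8(x, group_size)

# Compare dequantized values; loose tolerances absorb fp8 rounding differences.
deq_ref = x_q_ref.float().reshape(64, -1, group_size) * x_s_ref.unsqueeze(-1)
deq_ker = x_q_ker.float().reshape(64, -1, group_size) * x_s_ker.unsqueeze(-1)
torch.testing.assert_close(deq_ref, deq_ker, atol=0.15, rtol=0.15)
```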
```cpp
                              const float* a_scales,
                              const int32_t* expert_offsets,
                              const int32_t* problem_sizes,
                              int64_t k_scaled) {
  int64_t expert_idx = blockIdx.x;
  int64_t start_k_scaled = threadIdx.x;
  int64_t step_k_scaled = blockDim.x;
  int64_t expert_offset = expert_offsets[expert_idx];
  int64_t num_tokens = problem_sizes[expert_idx * 3];
  int64_t expert_offset_scaled = expert_offset * k_scaled;

  for (int64_t t = 0; t < num_tokens; ++t) {
    for (int64_t k = start_k_scaled; k < k_scaled; k += step_k_scaled) {
      a_scales_t[expert_offset_scaled + k * num_tokens + t] =
          a_scales[expert_offset_scaled + t * k_scaled + k];
    }
  }
}
```
A pure scalar loop would be slow here. Some thoughts on how we could optimize this:
- vectorization using `vectorize_with_alignment`
- a shared memory tile
```diff
   // Swap-AB should be disabled for FP4 path
-  bool may_swap_ab = (!blockscale_offsets.has_value()) &&
-                     (topk_ids.numel() <= SWAP_AB_THRESHOLD);
-
-  if (may_swap_ab) {
+  bool swap_ab = !force_no_swap && topk_ids.numel() <= SWAP_AB_THRESHOLD;
+  if (swap_ab) {
```
If we are no longer using `blockscale_offsets.has_value()`, we should make sure that the nvfp4 path passes in `force_no_swap`.
```cpp
  using ElementScale = typename Gemm::ElementScale;
  using ScaleConfig = typename Gemm::ScaleConfig;
  using LayoutSFA = typename Gemm::LayoutSFA;
  using LayoutSFB = typename Gemm::LayoutSFB;
```
Maybe setting the correct `LayoutSFA` could avoid the transpose of `a_scales` in cutlass_moe.py?
Many thanks! Will look into this tomorrow.
```python
# Get the right scale for tests.
if per_act_block:
    a_q, a_scale = per_token_group_quant_fp8(moe_tensors_fp16.a,
                                             block_size[1])
```
Transpose here?
These are the latest benchmark numbers on H100:
Add support for blocked fp8 CUTLASS MoE for SM90.
Testing:
- Single grouped multiply unit tests:
- Fused experts op unit tests:
- LLM offline inference test:
Performance:
Currently, blocked CUTLASS is slightly worse than Triton on average, but beneficial for some shapes:
Future PRs should focus on speeding up the functions that preprocess data before running the MoE kernels.