fix per token cuda kernel hidden dim cannot divide by 16 #8543
Conversation
Summary of Changes
Hello @hebiao064, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enhances the per_token_quant_fp8 CUDA kernel by extending its compatibility to hidden dimensions divisible by 8, rather than being strictly limited to those divisible by 16. It introduces a new 8-element vectorized kernel path and dynamically selects the appropriate kernel based on the input tensor's dimensions. To ensure correctness and facilitate development, a pure PyTorch reference implementation has been added to the benchmarking suite, allowing direct verification against the optimized CUDA kernels. The build system has also been temporarily streamlined to focus exclusively on this kernel, reflecting an ongoing development effort.
Highlights
- Expanded per_token_quant_fp8 Support: The per_token_quant_fp8 CUDA kernel now supports input tensors where the hidden dimension is divisible by 8, not just 16. This is achieved by introducing a new 8-element vectorized kernel path and dynamically selecting the appropriate kernel based on the hidden_dim.
- Enhanced Benchmarking and Verification: A pure PyTorch reference implementation for per-token FP8 quantization has been added to the benchmark script. This allows direct correctness comparison and more comprehensive performance analysis against the vLLM and SGLang kernel implementations across various hidden_dim values.
- Streamlined Development Environment: The build configuration (CMakeLists.txt) and Python bindings (common_extension.cc, __init__.py) have been temporarily stripped down to focus solely on the per_token_quant_fp8 kernel. This facilitates isolated development, debugging, and benchmarking of this specific feature.
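The dynamic kernel selection described in the first highlight can be sketched as a simple dispatch on the hidden dimension. This is a hedged illustration only: the function name and the idea of picking the widest supported vector width are assumptions based on the PR summary, not the PR's actual CUDA dispatch code.

```python
def select_vec_size(hidden_dim: int) -> int:
    """Pick the widest vectorized path the hidden dim supports.

    Hypothetical dispatch mirroring the PR's description: prefer the
    original 16-element vectorized kernel, fall back to the new
    8-element path, and reject anything narrower.
    """
    if hidden_dim % 16 == 0:
        return 16  # original fast path
    if hidden_dim % 8 == 0:
        return 8   # new path added by this PR
    raise ValueError(f"hidden_dim must be divisible by 8, got {hidden_dim}")
```

With this dispatch, a shape like 1368 (divisible by 8 but not 16) routes to the 8-element kernel instead of raising an error.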
Code Review
This PR adds support for hidden_dim values that are multiples of 8 in the per_token_quant_fp8 kernel. The changes include updating the CUDA kernel, modifying build files, and extending the benchmark. Critical issues related to commented-out code in build files and Python packages need to be resolved. A potential correctness issue in the benchmark reference implementation and opportunities to improve the CUDA code's portability and maintainability were also identified.
Clean up the code and modify sgl-kernel.
LGTM.
Fixed by #9093
File Changes
All other compilation-file changes will be reverted; I made them only to accelerate my development process.
Motivation
First, I updated the benchmark to include torch's quantization implementation, which surfaced that both the vLLM and SGLang quant kernels are somewhat less accurate than torch's. I consider this acceptable, since vLLM and SGLang deliver similar quantization results.
Second, I modified the kernel to allow hidden dims like 1368, which previously failed due to #8460; that case is now solved.
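The PyTorch reference added to the benchmark is not reproduced in this thread. As a rough illustration of what per-token dynamic FP8 quantization computes, here is a minimal pure-Python sketch (the function name is hypothetical; it assumes the float8_e4m3 maximum of 448.0 and omits the final rounding to actual fp8 representable values):

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3

def per_token_quant_fp8_ref(x):
    """Per-token quantization sketch: one dynamic scale per row (token).

    x is a list of rows of floats; returns (quantized rows, per-row scales).
    A real implementation would also cast to an fp8 dtype; this sketch
    only scales and clamps to show the math.
    """
    q_rows, scales = [], []
    for row in x:
        amax = max((abs(v) for v in row), default=0.0)
        scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
        q_rows.append(
            [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in row]
        )
        scales.append(scale)
    return q_rows, scales
```

Dequantizing (multiplying each row by its scale) recovers the original values up to fp8 rounding, which is the property the benchmark's accuracy comparison checks.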
E2E Test
TP 4:
TP 8: (before this pr, TP8 will fail)
Need to use USE_VLLM_CUTLASS_W8A8_FP8_KERNEL, since our cutlass fp8 kernel doesn't support this shape and fails with:
RuntimeError: mat_a must be multiple of 16 bytes for memory alignment
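As a hedged example of how such a run might be launched: the USE_VLLM_CUTLASS_W8A8_FP8_KERNEL flag comes from the PR text, the sglang server entry point follows its usual CLI, and the model path is a placeholder, not the actual model tested.

```shell
# Hypothetical TP8 launch; the env var is named in the PR, the model
# path is a placeholder.
USE_VLLM_CUTLASS_W8A8_FP8_KERNEL=1 \
  python -m sglang.launch_server --model-path <fp8-model> --tp 8
```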
Before this PR, with hidden dim 1368:
RuntimeError: Hidden dimension must be divisible by 16, but got 1368
After this PR:
To be added
Modifications
Accuracy Test
Benchmark & Profiling
Checklist