[Feature] Apply Cublas Grouped Gemm kernel #3629
Merged
yizhang2077 reviewed on Feb 17, 2025
Since PyTorch 2.5.1 only supports CUDA 12.4 according to the official docs, and we cannot easily change the PyTorch version, we need to update the docs to guide users to reinstall PyTorch if they want to use grouped GEMM to accelerate their models.
LGTM cc @zhyncs
yizhang2077 reviewed on Feb 18, 2025
yizhang2077 approved these changes on Feb 18, 2025
Amazing work!
Motivation
#3323
The grouped GEMM kernel added in cuBLAS 12.5 can be applied to the MoE EP (expert parallelism) layer and the LoRA layer to accelerate them.
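To make the idea concrete: a grouped GEMM computes several independent products C_i = A_i × B_i in one call, and each group may have its own shape. In an MoE layer, group i would be expert i, with M_i (the number of rows of A_i) equal to the number of tokens routed to that expert. The pure-Python sketch below only illustrates these reference semantics; it is not the cuBLAS API, and the names `matmul` and `grouped_gemm_reference` are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain row-by-column product of an (M x K) and a (K x N) matrix."""
    k = len(b)
    n = len(b[0]) if b else 0
    if any(len(row) != k for row in a):
        raise ValueError("inner dimensions do not match")
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def grouped_gemm_reference(inputs: List[Matrix], weights: List[Matrix]) -> List[Matrix]:
    """Hypothetical reference for a grouped GEMM: C_i = A_i @ B_i for every group i.

    Each group may have a different number of rows M_i (e.g. tokens per expert).
    A real grouped kernel would compute all products in one launch instead of looping.
    """
    if len(inputs) != len(weights):
        raise ValueError("need one weight matrix per input group")
    return [matmul(a, b) for a, b in zip(inputs, weights)]


# Two "experts": the first receives 1 token, the second receives 2 tokens (K = 2, N = 2).
a_groups = [[[1.0, 2.0]], [[1.0, 0.0], [0.0, 1.0]]]
b_groups = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 3.0], [4.0, 5.0]]]
print(grouped_gemm_reference(a_groups, b_groups))
```

Looping over groups is what a grouped kernel avoids: it batches the differently sized products into a single launch rather than issuing one GEMM per expert.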
Modifications
Adds cublas_grouped_gemm to the sgl-kernel library and provides an accuracy test and a benchmark script.
Environment:
Torch 2.5.1, CUDA 12.5, cuBLAS 12.5.3.2, sglang 0.4.3
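To confirm that the cuBLAS in use is new enough for grouped GEMM, one option is to read the installed wheel version with Python's importlib.metadata. This is a hedged sketch: it assumes cuBLAS comes from the nvidia-cublas-cu12 pip wheel (the package upgraded in this setup), and the helper names `parse_release` and `cublas_is_new_enough` are hypothetical. The 12.5 minimum comes from the motivation above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple

# Assumption: grouped GEMM requires cuBLAS 12.5 or newer.
MIN_CUBLAS = (12, 5)


def parse_release(ver: str) -> Tuple[int, ...]:
    """Turn a version string like '12.5.3.2' into (12, 5, 3, 2).

    Any local suffix such as '+abc' is dropped, and parsing stops at the first
    component that does not start with a digit.
    """
    parts = []
    for piece in ver.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def cublas_is_new_enough(installed: str, minimum: Tuple[int, ...] = MIN_CUBLAS) -> bool:
    """Compare only the leading release numbers, e.g. (12, 5) against (12, 5)."""
    return parse_release(installed)[: len(minimum)] >= minimum


if __name__ == "__main__":
    # Assumption: cuBLAS was installed as the nvidia-cublas-cu12 pip wheel.
    try:
        installed = version("nvidia-cublas-cu12")
    except PackageNotFoundError:
        print("nvidia-cublas-cu12 wheel not found; cuBLAS may come from a system CUDA install")
    else:
        status = "OK" if cublas_is_new_enough(installed) else "too old for grouped GEMM"
        print(f"nvidia-cublas-cu12 {installed}: {status}")
```

A system-wide CUDA toolkit is not visible to this check; in that case `nvcc -V` reports the toolkit version instead.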
Since sglang doesn't support torch 2.6 yet, build the environment by checking the CUDA toolkit version and then upgrading cuBLAS:
nvcc -V
pip install nvidia-cublas-cu12==12.5.3.2
This upgrades cuBLAS to 12.5.3.2.
Accuracy Test
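An accuracy test for a grouped kernel typically compares the output of every group against a per-group reference result. A minimal, framework-free sketch of that comparison step follows; the helper name `outputs_match` and the default tolerances are assumptions, not values taken from the PR's test.

```python
import math
from typing import List

Matrix = List[List[float]]


def outputs_match(candidate: List[Matrix], expected: List[Matrix],
                  rel_tol: float = 1e-2, abs_tol: float = 1e-3) -> bool:
    """Return True if every group's output matches its reference within tolerance.

    The loose default tolerances are an assumption suited to half-precision kernels.
    """
    if len(candidate) != len(expected):
        return False
    for c_mat, e_mat in zip(candidate, expected):
        # Each group must have the same (M_i x N) shape as its reference.
        if len(c_mat) != len(e_mat):
            return False
        for c_row, e_row in zip(c_mat, e_mat):
            if len(c_row) != len(e_row):
                return False
            if not all(math.isclose(c, e, rel_tol=rel_tol, abs_tol=abs_tol)
                       for c, e in zip(c_row, e_row)):
                return False
    return True


expected = [[[1.0, 2.0]], [[2.0, 3.0], [4.0, 5.0]]]
close = [[[1.0005, 2.0]], [[2.0, 3.01], [4.0, 5.0]]]
wrong = [[[1.0, 2.0]], [[2.0, 3.5], [4.0, 5.0]]]
print(outputs_match(close, expected), outputs_match(wrong, expected))
```

Checking shapes group by group matters here: a grouped kernel can silently mix up group boundaries, which a single global norm over all outputs might not catch.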
Kernel Benchmark
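The benchmark results are reported in GB per second. The exact formula used by bench_cublas_grouped_gemm.py is not shown in this description; one common convention, assumed in this hypothetical sketch, counts each A_i, B_i and C_i as moved once and divides by the measured kernel time. The helper names, the 2-byte (bf16/fp16) element size, the per-group M values, and the 1.0 ms timing are all illustrative, not measured.

```python
from typing import Sequence


def grouped_gemm_bytes(ms: Sequence[int], n: int, k: int, bytes_per_elem: int = 2) -> int:
    """Bytes moved by one grouped GEMM, assuming every A_i (M_i x K), B_i (K x N)
    and C_i (M_i x N) is read or written exactly once."""
    return sum((m * k + k * n + m * n) * bytes_per_elem for m in ms)


def gb_per_second(total_bytes: int, elapsed_ms: float) -> float:
    """Convert bytes and a kernel time in milliseconds into GB/s (1 GB = 1e9 bytes)."""
    return total_bytes / (elapsed_ms * 1e-3) / 1e9


# Illustrative only: 20 groups, 128 tokens per group, N = 3072, K = 5120,
# bf16 elements, and a made-up kernel time of 1.0 ms.
moved = grouped_gemm_bytes([128] * 20, n=3072, k=5120)
print(moved, round(gb_per_second(moved, elapsed_ms=1.0), 2))
```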
DeepSeek-V2 setting
With the DeepSeek-V2 setting, TP size = 8 (group size = 20), N = 3072, K = 5120:
python3 sgl-kernel/benchmark/bench_cublas_grouped_gemm.py --models DeepSeek-V2 --tp-size 8
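The group size, N, and K in this setting are consistent with the public DeepSeek-V2 config (160 routed experts, MoE intermediate size 1536, hidden size 5120) if experts are split evenly across ranks and the gate and up projections are fused so that N is twice the intermediate size. Those config values and the fusion are assumptions for this sketch, and `MoEConfig` and `grouped_gemm_shape` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class MoEConfig:
    n_routed_experts: int
    hidden_size: int
    moe_intermediate_size: int


def grouped_gemm_shape(cfg: MoEConfig, tp_size: int) -> Tuple[int, int, int]:
    """Return (group_size, N, K) for the fused gate/up projection on one rank.

    Assumes experts are split evenly across ranks and that gate and up weights
    are fused, so N = 2 * moe_intermediate_size and K = hidden_size.
    """
    if cfg.n_routed_experts % tp_size:
        raise ValueError("experts must divide evenly across ranks")
    group_size = cfg.n_routed_experts // tp_size
    return group_size, 2 * cfg.moe_intermediate_size, cfg.hidden_size


# Values assumed from the public DeepSeek-V2 config.
deepseek_v2 = MoEConfig(n_routed_experts=160, hidden_size=5120, moe_intermediate_size=1536)
print(grouped_gemm_shape(deepseek_v2, tp_size=8))  # (20, 3072, 5120)
```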
Result in GB per second:

DeepSeek-V2-Lite setting
With the DeepSeek-V2-Lite setting, TP size = 2 (group size = 32), N = 2816, K = 2048:
python3 sgl-kernel/benchmark/bench_cublas_grouped_gemm.py --models DeepSeek-V2-Lite --tp-size 2
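As a quick consistency check of these V2-Lite numbers, assuming the public DeepSeek-V2-Lite config (64 routed experts, MoE intermediate size 1408, hidden size 2048), experts split evenly across the 2 ranks, and gate and up projections fused so that N is twice the intermediate size:

```latex
\text{group size} = \frac{64}{2} = 32, \qquad N = 2 \times 1408 = 2816, \qquad K = 2048
```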
Result in GB per second:

Checklist