
Fridge003 (Collaborator) commented Feb 17, 2025

Motivation

#3323

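The grouped GEMM kernel introduced in cuBLAS 12.5 executes a set of independent GEMMs, which may have different shapes, in a single call. It can be used to accelerate the MoE expert-parallel (EP) layer and the LoRA layer.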
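As a rough illustration of what "grouped GEMM" means, this pure-Python sketch spells out the reference semantics: a list of independent matrix products whose row counts (M) may differ per group, as when a different number of tokens is routed to each expert. It does not call cuBLAS or sgl-kernel. The helper names `matmul` and `grouped_gemm` are hypothetical, and the K x N weight layout is an assumption (many MoE implementations store weights as N x K and apply them transposed, which only changes indexing).

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain GEMM: (M x K) @ (K x N) -> (M x N). M may be 0."""
    k = len(b)
    n = len(b[0])
    assert all(len(row) == k for row in a), "inner dimensions must match"
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def grouped_gemm(inputs: List[Matrix], weights: List[Matrix]) -> List[Matrix]:
    """Hypothetical reference semantics: one independent GEMM per group.

    A grouped kernel submits all of these products in one call instead of
    launching len(inputs) separate GEMM kernels.
    """
    assert len(inputs) == len(weights), "one weight matrix per group"
    return [matmul(x, w) for x, w in zip(inputs, weights)]


# Two "experts" that receive 1 and 2 tokens respectively (M differs per group).
x0 = [[1.0, 2.0]]
x1 = [[1.0, 0.0], [0.0, 1.0]]
w0 = [[1.0], [1.0]]              # K=2, N=1
w1 = [[2.0, 0.0], [0.0, 3.0]]    # K=2, N=2
out = grouped_gemm([x0, x1], [w0, w1])
```

An expert that receives no tokens is simply a group with M = 0, which yields an empty output rather than an error.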

Modifications

  • Add cublas_grouped_gemm to the sgl-kernel library, together with an accuracy test and a benchmark script.
  • Update the documentation for this feature.

Environment:

Torch 2.5.1, CUDA 12.5, cuBLAS 12.5.3.2, sglang 0.4.3
Since sglang does not yet support torch 2.6, set up the environment as follows:

  1. Check that the CUDA version is >= 12.5 with nvcc -V.
  2. Install sglang following the official documentation.
  3. Upgrade cuBLAS to 12.5 with pip install nvidia-cublas-cu12==12.5.3.2.
  4. Build the updated sgl-kernel library.
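Because the feature depends on the installed cuBLAS wheel being at least 12.5, a small version check can catch a missed upgrade in step 3. This is a minimal sketch, assuming purely numeric dotted version strings; the helpers `parse_version` and `at_least` are hypothetical and not part of sglang or sgl-kernel. It reads the installed wheel version through the standard-library importlib.metadata, using the nvidia-cublas-cu12 package name from the steps above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """'12.5.3.2' -> (12, 5, 3, 2). A local suffix such as '+cu124' is dropped.

    Assumes purely numeric dotted components (no 'rc', 'post', etc.).
    """
    core = v.split("+")[0]
    return tuple(int(part) for part in core.split("."))


def at_least(found: str, required: str) -> bool:
    """True if `found` >= `required`, compared component by component."""
    return parse_version(found) >= parse_version(required)


if __name__ == "__main__":
    try:
        cublas = version("nvidia-cublas-cu12")
        status = "OK" if at_least(cublas, "12.5") else "too old, upgrade required"
        print(f"nvidia-cublas-cu12 {cublas}: {status}")
    except PackageNotFoundError:
        print("nvidia-cublas-cu12 is not installed in this environment")
```

Tuple comparison is lexicographic, so "12.5.3.2" passes a "12.5" requirement while "12.4.5.8" does not.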

Accuracy Test

python3 sgl-kernel/tests/test_cublas_grouped_gemm.py 
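The test script's internals are not reproduced here. As a hedged sketch, an accuracy check for a grouped kernel typically compares each group's output against an independently computed reference within a tolerance. The helpers `allclose` and `failing_groups` are hypothetical, and the default tolerances are illustrative values for half-precision outputs, not the ones the real test uses. The comparison rule mirrors torch.allclose: |actual - expected| <= atol + rtol * |expected|.

```python
from typing import List

Matrix = List[List[float]]


def allclose(actual: Matrix, expected: Matrix, rtol: float = 1e-2, atol: float = 1e-2) -> bool:
    """Elementwise |actual - expected| <= atol + rtol * |expected|."""
    if len(actual) != len(expected):
        return False
    for row_a, row_e in zip(actual, expected):
        if len(row_a) != len(row_e):
            return False
        for a, e in zip(row_a, row_e):
            if abs(a - e) > atol + rtol * abs(e):
                return False
    return True


def failing_groups(outputs: List[Matrix], references: List[Matrix],
                   rtol: float = 1e-2, atol: float = 1e-2) -> List[int]:
    """Indices of groups whose output does not match its reference."""
    assert len(outputs) == len(references), "one reference per group"
    return [i for i, (o, r) in enumerate(zip(outputs, references))
            if not allclose(o, r, rtol, atol)]


ref = [[[3.0]], [[2.0, 0.0]]]
got = [[[3.01]], [[2.0, 0.5]]]   # group 1 is off by 0.5
bad = failing_groups(got, ref)
```

Reporting per-group failures, rather than a single pass/fail, makes it easier to tell whether a mismatch is tied to a particular group shape.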

Kernel Benchmark

DeepSeek-V2 setting

With the DeepSeek-V2 setting, TP size = 8 (group size = 20), N = 3072, K = 5120:

python3 sgl-kernel/benchmark/bench_cublas_grouped_gemm.py --models DeepSeek-V2 --tp-size 8

Result in GB per second:
[Screenshot: benchmark results in GB/s, 2025-02-17 00:56:17]
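To put a GB/s figure in context for this shape, here is a back-of-the-envelope byte count. It is a sketch under stated assumptions: each group reads an M x K activation and an N x K weight once and writes an M x N output once, all as 2-byte (fp16/bf16) elements. The benchmark script's actual byte accounting is not shown in this PR and may differ. The helper names and the 64-tokens-per-expert and 1 ms figures are hypothetical.

```python
from typing import Sequence


def grouped_gemm_bytes(ms: Sequence[int], n: int, k: int, elem_bytes: int = 2) -> int:
    """Idealized bytes moved: per group, read M*K + N*K elements, write M*N (hypothetical accounting)."""
    return sum((m * k + n * k + m * n) * elem_bytes for m in ms)


def bandwidth_gb_s(num_bytes: int, seconds: float) -> float:
    """Effective bandwidth in GB/s, with 1 GB = 1e9 bytes."""
    return num_bytes / seconds / 1e9


# DeepSeek-V2 setting from the PR: 20 groups, N = 3072, K = 5120.
weights_only = grouped_gemm_bytes([0] * 20, n=3072, k=5120)   # expert weights alone
with_tokens = grouped_gemm_bytes([64] * 20, n=3072, k=5120)   # hypothetical 64 tokens per expert
print(weights_only, with_tokens)
print(bandwidth_gb_s(with_tokens, 1e-3))                      # hypothetical 1 ms kernel time
```

With few tokens per expert, the N x K weights account for almost all of the traffic (629,145,600 of 650,117,120 bytes here), which is why the kernel is reported as a memory-bandwidth figure rather than in FLOP/s.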

DeepSeek-V2-Lite setting

With the DeepSeek-V2-Lite setting, TP size = 2 (group size = 32), N = 2816, K = 2048:

python3 sgl-kernel/benchmark/bench_cublas_grouped_gemm.py --models DeepSeek-V2-Lite --tp-size 2

Result in GB per second:
[Screenshot: benchmark results in GB/s, 2025-02-17 01:00:33]
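The group sizes and (N, K) values in both benchmark settings are consistent with the published DeepSeek model configs: group size = routed experts / TP size, N = 2 x moe_intermediate_size (a fused gate/up projection), and K = hidden_size. The config values below come from the models' public config.json files and should be double-checked there. The benchmark script's own derivation is not shown in this PR, so treat this as a cross-check, not a description of its code. `MoEConfig` and `grouped_gemm_shape` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class MoEConfig:
    hidden_size: int            # K: input width of the expert up/gate projection
    moe_intermediate_size: int  # per-expert intermediate width
    n_routed_experts: int       # routed experts, excluding shared experts


CONFIGS: Dict[str, MoEConfig] = {
    "DeepSeek-V2": MoEConfig(hidden_size=5120, moe_intermediate_size=1536, n_routed_experts=160),
    "DeepSeek-V2-Lite": MoEConfig(hidden_size=2048, moe_intermediate_size=1408, n_routed_experts=64),
}


def grouped_gemm_shape(model: str, tp_size: int) -> Tuple[int, int, int]:
    """Return (group_size, N, K) for one rank, assuming experts split evenly across ranks."""
    cfg = CONFIGS[model]
    if cfg.n_routed_experts % tp_size:
        raise ValueError(f"{cfg.n_routed_experts} experts do not split evenly over {tp_size} ranks")
    group_size = cfg.n_routed_experts // tp_size
    n = 2 * cfg.moe_intermediate_size  # gate and up projections fused into one GEMM
    k = cfg.hidden_size
    return group_size, n, k


print(grouped_gemm_shape("DeepSeek-V2", 8))       # (20, 3072, 5120)
print(grouped_gemm_shape("DeepSeek-V2-Lite", 2))  # (32, 2816, 2048)
```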


yizhang2077 (Collaborator) commented:

Since the official PyTorch 2.5.1 builds only support CUDA 12.4, and we cannot easily change the PyTorch version, we need to update the docs to guide users through reinstalling PyTorch if they want to use grouped GEMM to accelerate their models.

Fridge003 changed the title [Feature] Implement Cublas Grouped Gemm kernel and apply it to MoE EP to [Feature] Implement Cublas Grouped Gemm kernel on Feb 17, 2025
Fridge003 changed the title [Feature] Implement Cublas Grouped Gemm kernel to [Feature] Apply Cublas Grouped Gemm kernel on Feb 17, 2025
yizhang2077 (Collaborator) commented:

LGTM cc @zhyncs

zhyncs merged commit 67fc595 into sgl-project:main on Feb 18, 2025 (12 of 16 checks passed).
zhyncs (Member) commented Feb 18, 2025:

amazing work!
