[2/2] Add python wrapper for CUTLASS FP8 Blockscale MoE Kernel. #5694
Conversation
Hi all, can somebody take a look at this PR? We have tested both the standalone layer and end-to-end, and it's now ready to integrate.
Hi @elfiegg, what command did you use to launch the server when testing?
Wonderful Work!
Hey @Fridge003, I used the command below for testing; let me know if you run into any problems:
For large batch sizes, I assume you're referring to potential optimizations - if so, yes, I believe they will be addressed in the long run. In the short term, we also plan to integrate the TRT-LLM MoE kernels. Ideally, different MoE backends could each accelerate their own specialties. Hopefully this makes sense to you!
Thank you for your answer! Maybe there could be some logic that chooses the better MoE kernel as the batch size changes.
Agreed @Fridge003, maybe let's roll out this CUTLASS path behind a server flag first? There may be other work, like Hopper support and optimizations such as epilogue fusions, before we get to a point where we can confidently and comfortably route traffic based on performance.
This PR can be merged first. Routing logic can be left for future PRs.
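To make the routing idea above concrete, here is a minimal sketch of batch-size-based backend selection. Only the `CUTLASS_MOE` flag mirrors something in this PR (see the launch command below); the threshold variable, its default value, and the function name are hypothetical.

```python
import os

# Hypothetical sketch of batch-size-aware MoE backend selection.
# CUTLASS_MOE mirrors the flag used in the launch command in this PR;
# the threshold variable and function name are made up for illustration.
_CUTLASS_MOE = os.environ.get("CUTLASS_MOE", "0") == "1"
_CUTLASS_MOE_TOKEN_THRESHOLD = int(os.environ.get("CUTLASS_MOE_TOKEN_THRESHOLD", "256"))


def select_moe_backend(num_tokens: int) -> str:
    """Pick a MoE kernel backend from the current batch's token count.

    Which backend wins at which batch size would need to be determined
    empirically per device (see the benchmarking discussion in this thread).
    """
    if _CUTLASS_MOE and num_tokens <= _CUTLASS_MOE_TOKEN_THRESHOLD:
        return "cutlass"
    return "triton"
```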
@elfiegg What device did you run the benchmark on? On B200, it seems slower than the Triton implementation.
And I am facing an error when I launch the server with
@ispobock thanks for reporting. It looks like the error is from the kernel.
I was on a B200 600W machine. My container might have unoptimized Triton settings. But I did notice that for small batches, where it's mostly memory-bound, Triton performed pretty well. Can you help benchmark larger batch sizes (2048, 4096, 8192, etc.) on your machine?
NOTE
The current CUTLASS 3.9 in SGLang will experience: 1. kernel hangs and 2. a performance slowdown with this MoE kernel.
I'll update our CUTLASS dependency in another PR, since the upgrade breaks some of the existing sm90 templates.
Motivation
Using the benchmark provided in this PR, we found that our fused_expert layer with CUTLASS 4.0 in CUDA graph mode achieves a ~30%-40% speedup over Triton (also in CUDA graph mode) at small batch sizes.
For DeepSeek V3/R1 models, where:
{'num_experts': 256, 'topk': 8, 'hidden_size': 7168, 'shard_intermediate_size': 512, 'dtype': torch.bfloat16, 'block_shape': [128, 128]}
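As background on the block_shape entry above, a minimal sketch of block-wise FP8 (e4m3) quantization with [128, 128] tiles is shown below. This is not code from this PR; the weight layout is assumed for illustration, and it requires PyTorch >= 2.1 for torch.float8_e4m3fn.

```python
import torch

def quantize_fp8_blockwise(w: torch.Tensor, block_n: int = 128, block_k: int = 128):
    """Quantize a 2D weight to FP8 (e4m3) with one scale per (block_n, block_k) tile."""
    n, k = w.shape
    assert n % block_n == 0 and k % block_k == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # View the weight as a grid of tiles and take the max-abs value per tile.
    tiles = w.reshape(n // block_n, block_n, k // block_k, block_k)
    amax = tiles.abs().amax(dim=(1, 3)).clamp(min=1e-12)   # [n/block_n, k/block_k]
    scale = amax / fp8_max
    w_q = (tiles / scale[:, None, :, None]).clamp(-fp8_max, fp8_max)
    return w_q.reshape(n, k).to(torch.float8_e4m3fn), scale.float()

# Shard sizes taken from the config above; the weight layout itself is assumed.
w = torch.randn(512, 7168, dtype=torch.bfloat16)
w_fp8, w_scale = quantize_fp8_blockwise(w)
print(w_fp8.shape, w_scale.shape)  # torch.Size([512, 7168]) torch.Size([4, 56])
```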
The result of python3 python/sglang/test/test_cutlass_moe.py:
End-to-end model accuracy validation for DeepSeek-R1:
server:
CUTLASS_MOE=1 python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --trust-remote-code --enable-dp-attention --tp 8 --dp 8
client:
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319
Accuracy: 0.955
Invalid: 0.000
Latency: 477.968 s
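For reference, the CUDA-graph-mode timing mentioned above follows the usual capture-and-replay pattern. The sketch below is a generic illustration rather than the actual benchmark in python/sglang/test/test_cutlass_moe.py; `fn` stands in for the fused expert layer call.

```python
import torch

def time_with_cuda_graph(fn, *args, warmup: int = 3, iters: int = 100) -> float:
    """Capture fn(*args) into a CUDA graph and return the average replay time in ms."""
    # Warm up on a side stream so allocations happen before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(warmup):
            fn(*args)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        fn(*args)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        graph.replay()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```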
cc @depaulmillz @kushanam
Modifications
Add a Python class for the CUTLASS MoE kernel.
This PR also moves all the tensor allocations outside of the kernel implementation.
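A minimal sketch of that allocation-hoisting pattern is shown below, with hypothetical names rather than the actual class added here.

```python
import torch

class CutlassFusedMoE:
    """Illustrative sketch only: the class name, buffer names, and shapes below are
    hypothetical, not the interface added in this PR. The point is the pattern:
    workspace tensors are allocated once here, outside the kernel implementation,
    so the per-call path is allocation-free and CUDA-graph friendly."""

    def __init__(self, topk: int, hidden_size: int, max_tokens: int, device="cuda"):
        self.topk = topk
        # Pre-allocated buffers, reused on every forward call.
        self.gathered = torch.empty(max_tokens * topk, hidden_size,
                                    dtype=torch.bfloat16, device=device)
        self.output = torch.empty(max_tokens, hidden_size,
                                  dtype=torch.bfloat16, device=device)

    def forward(self, hidden_states, topk_weights, topk_ids, run_kernel):
        m = hidden_states.shape[0]
        # run_kernel stands in for the CUTLASS FP8 block-scale grouped GEMM;
        # it writes into the caller-owned buffers instead of allocating its own.
        run_kernel(hidden_states, topk_weights, topk_ids,
                   self.gathered[: m * self.topk], self.output[:m])
        return self.output[:m]
```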
Checklist