
Conversation


@elfiegg elfiegg commented Apr 24, 2025

NOTE

With the current CUTLASS 3.9 in SGLang, this MoE kernel will experience: 1. kernel hangs, 2. a performance slowdown.
I'll update our CUTLASS dependency in another PR, as it breaks some of the existing sm90 templates.

Motivation

Using the benchmark provided in this PR, we found that our fused_experts layer built on CUTLASS 4.0 achieves a ~30%-40% speedup over Triton at small batch sizes, with both running in CUDA graph mode.
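
For context, CUDA-graph-mode timing in PyTorch generally follows the pattern sketched below. This is only an illustration of the measurement approach, not the code in test_cutlass_moe.py; moe_forward and its arguments are placeholders.

import torch

def time_with_cuda_graph(moe_forward, *args, iters: int = 100) -> float:
    # Warm up on a side stream so lazy allocations happen before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            moe_forward(*args)
    torch.cuda.current_stream().wait_stream(s)

    # Capture a single forward pass into a graph, then time replays of it.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        moe_forward(*args)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        g.replay()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call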

For DeepSeek V3/R1 models, where
{'num_experts': 256, 'topk': 8, 'hidden_size': 7168, 'shard_intermediate_size': 512, 'dtype': torch.bfloat16, 'block_shape': [128, 128]}

The result of python3 python/sglang/test/test_cutlass_moe.py:

--- Batch Size: 1 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.106 ms (median) [0.101 - 0.107]
Triton  fused_experts time: 0.148 ms (median) [0.146 - 0.149]

--- Batch Size: 4 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.155 ms (median) [0.150 - 0.157]
Triton  fused_experts time: 0.207 ms (median) [0.205 - 0.207]

--- Batch Size: 8 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.234 ms (median) [0.229 - 0.235]
Triton  fused_experts time: 0.287 ms (median) [0.285 - 0.288]

--- Batch Size: 16 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.350 ms (median) [0.345 - 0.351]
Triton  fused_experts time: 0.372 ms (median) [0.369 - 0.372]

--- Batch Size: 32 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.534 ms (median) [0.520 - 0.535]
Triton  fused_experts time: 0.532 ms (median) [0.527 - 0.534]

--- Batch Size: 64 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.677 ms (median) [0.674 - 0.678]
Triton  fused_experts time: 0.658 ms (median) [0.656 - 0.659]

--- Batch Size: 128 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.959 ms (median) [0.957 - 0.960]
Triton  fused_experts time: 0.760 ms (median) [0.758 - 0.761]

--- Batch Size: 256 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 1.078 ms (median) [1.071 - 1.085]
Triton  fused_experts time: 0.801 ms (median) [0.797 - 0.801]

--- Batch Size: 512 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 1.325 ms (median) [1.293 - 1.326]
Triton  fused_experts time: 0.850 ms (median) [0.848 - 0.852]

End-to-end model accuracy validation for DeepSeek-R1:
server: CUTLASS_MOE=1 python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --trust-remote-code --enable-dp-attention --tp 8 --dp 8
client: python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319

Accuracy: 0.955
Invalid: 0.000
Latency: 477.968 s

cc @depaulmillz @kushanam

Modifications

Add a Python class for CUTLASS MoE.
This PR also moves all tensor allocations outside of the kernel implementation.
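
As a rough sketch of that second point (all buffers allocated once up front, so the captured region itself never allocates and the CUDA graph replays against stable pointers), a hypothetical wrapper could look like the following. The class, buffer, and parameter names here are illustrative, not the exact ones added in this PR.

import torch

class CutlassMoEWorkspace:
    # Illustrative sketch: pre-allocate scratch tensors sized for the largest
    # expected batch, so the fused-experts call never allocates internally.
    def __init__(self, max_tokens, topk, hidden_size, intermediate_size,
                 num_experts, dtype=torch.bfloat16, device="cuda"):
        m = max_tokens * topk
        # Outputs of the first and second grouped GEMMs.
        self.c1 = torch.empty(m, 2 * intermediate_size, dtype=dtype, device=device)
        self.c2 = torch.empty(m, hidden_size, dtype=dtype, device=device)
        # Per-expert token counts/offsets used to group rows by expert.
        self.expert_offsets = torch.empty(num_experts + 1, dtype=torch.int32, device=device)

    def run(self, fused_experts_kernel, hidden_states, w1, w2, topk_ids, topk_weights):
        # Slices of the pre-allocated buffers share storage with them, so the
        # kernel writes into memory whose addresses are fixed across replays.
        n = hidden_states.shape[0] * topk_ids.shape[1]
        return fused_experts_kernel(
            hidden_states, w1, w2, topk_ids, topk_weights,
            c1=self.c1[:n], c2=self.c2[:n], expert_offsets=self.expert_offsets,
        )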

Checklist


elfiegg commented May 14, 2025

Hi all, could somebody take a look at this PR? We have tested both the standalone layer and end-to-end, and it's now ready to integrate.


zhyncs commented May 15, 2025

Hi @elfiegg, could you help upgrade the CUTLASS version on top of #6272?


elfiegg commented May 15, 2025

Done: #6336 @zhyncs


Fridge003 commented May 16, 2025

Hi @elfiegg, what command did you use to launch the server when testing?
Also, you mentioned the CUTLASS kernel is slower than the Triton kernel when the batch size is large. Do you have any plans to implement CUTLASS MoE kernels for large batch sizes?


zhyncs commented May 16, 2025

Done: #6336 @zhyncs

Great work!! It has been merged.


@Fridge003 Fridge003 left a comment

Wonderful Work!


elfiegg commented May 16, 2025

Hey @Fridge003, I used the command below for testing; let me know if you run into any problems:

CUTLASS_MOE=1 python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --trust-remote-code --enable-dp-attention --tp 8 --dp 8 

For large batch sizes, I assume you're referring to potential optimizations - if so, yes, I believe they will be addressed in the long run. In the short term, we also have a plan to integrate TRT-LLM MoE kernels. Ideally, different MoE backends could each accelerate their own specialties. Hopefully this makes sense to you!


Fridge003 commented May 16, 2025

Hey @Fridge003, I used the command below for testing; let me know if you run into any problems:

CUTLASS_MOE=1 python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --trust-remote-code --enable-dp-attention --tp 8 --dp 8 

For large batch sizes, I assume you're referring to potential optimizations - if so, yes, I believe they will be addressed in the long run. In the short term, we also have a plan to integrate TRT-LLM MoE kernels. Ideally, different MoE backends could each accelerate their own specialties. Hopefully this makes sense to you!

Thank you for your answer! Maybe there could be some logic that chooses the better MoE kernel as the batch size changes.


elfiegg commented May 16, 2025

Agree @Fridge003, maybe let's roll out this CUTLASS logic with a server flag first? There might be other implementations like Hopper support and optimizations like epilogue fusions before we actually get to a point where we can confidently and comfortably route traffic based on performance.
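
A flag-gated selection with room to grow into batch-size-based routing could be as simple as the hypothetical sketch below; the function name is made up, and the threshold of 64 only reflects the rough crossover in the numbers above.

import os

def select_moe_backend(num_tokens: int, cutlass_threshold: int = 64) -> str:
    # Hypothetical dispatcher: take the CUTLASS path only when the flag is set
    # and the batch is small enough to benefit from it; otherwise use Triton.
    cutlass_enabled = os.environ.get("CUTLASS_MOE", "0") == "1"
    if cutlass_enabled and num_tokens <= cutlass_threshold:
        return "cutlass"
    return "triton"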

@Fridge003
Collaborator

Agree @Fridge003, maybe let's roll out this CUTLASS logic with a server flag first? There might be other implementations like Hopper support and optimizations like epilogue fusions before we actually get to a point where we can confidently and comfortably route traffic based on performance.

This PR can be merged first. Routing logic can be left for future PRs.

@zhyncs zhyncs merged commit 6fc9357 into sgl-project:main May 16, 2025
42 of 66 checks passed
@ispobock
Collaborator

@elfiegg What device did you run the benchmark on? On B200, it seems slower than the Triton implementation.

Running benchmarks with TP size: 8
Testing batch sizes: [1, 4, 8, 16, 32, 64, 128, 256, 512]
Model Config: {'num_experts': 256, 'topk': 8, 'hidden_size': 7168, 'shard_intermediate_size': 512, 'dtype': torch.bfloat16, 'block_shape': [128, 128]}

--- Batch Size: 1 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.064 ms (median) [0.064 - 0.064]
Triton  fused_experts time: 0.059 ms (median) [0.059 - 0.059]

--- Batch Size: 4 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.099 ms (median) [0.099 - 0.099]
Triton  fused_experts time: 0.092 ms (median) [0.092 - 0.092]

--- Batch Size: 8 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.138 ms (median) [0.138 - 0.138]
Triton  fused_experts time: 0.127 ms (median) [0.127 - 0.127]

--- Batch Size: 16 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.206 ms (median) [0.206 - 0.206]
Triton  fused_experts time: 0.149 ms (median) [0.149 - 0.149]

--- Batch Size: 32 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.312 ms (median) [0.312 - 0.312]
Triton  fused_experts time: 0.246 ms (median) [0.246 - 0.246]

--- Batch Size: 64 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.384 ms (median) [0.384 - 0.384]
Triton  fused_experts time: 0.280 ms (median) [0.280 - 0.280]

--- Batch Size: 128 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.431 ms (median) [0.431 - 0.431]
Triton  fused_experts time: 0.312 ms (median) [0.312 - 0.312]

--- Batch Size: 256 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.448 ms (median) [0.448 - 0.448]
Triton  fused_experts time: 0.335 ms (median) [0.335 - 0.335]

--- Batch Size: 512 ---
Config: E=256, topk=8, H=7168, I_shard=512, dtype=torch.bfloat16, block_shape=[128, 128]
Warming up...
Benchmarking Cutlass fused_experts...
Benchmarking Triton fused_experts...
Cutlass fused_experts time: 0.477 ms (median) [0.477 - 0.477]
Triton  fused_experts time: 0.416 ms (median) [0.414 - 0.418]

@ispobock
Collaborator

And I am facing the following error when I launch the server with SGLANG_CUTLASS_MOE=1:

[2025-06-20 13:21:12 TP5] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 318, in __init__
    self.capture()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 418, in capture
    ) = self.capture_one_batch_size(bs, forward)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 559, in capture_one_batch_size
    run_once()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 547, in run_once
    logits_output_or_pp_proxy_tensors = forward(
                                        ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1760, in forward
    hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1652, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1499, in forward
    hidden_states = self.mlp(hidden_states, forward_batch)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 335, in forward
    return self.forward_normal(hidden_states)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 343, in forward_normal
    final_hidden_states = self.experts(
                          ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 674, in forward
    final_hidden_states = self.quant_method.apply(
                          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8.py", line 988, in apply
    return cutlass_fused_experts_fp8(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/cutlass_moe.py", line 212, in cutlass_fused_experts_fp8
    return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sgl_kernel/moe.py", line 205, in apply_shuffle_mul_sum
    torch.ops.sgl_kernel.apply_shuffle_mul_sum.default(
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Factors must match output dtype


elfiegg commented Jun 23, 2025

@ispobock thanks for reporting. It looks like the error comes from the kernel apply_shuffle_mul_sum() - I will improve the error handling there. The error is saying that your c2 and topk_weights do not have the same dtype - see the link here. To get unblocked, you can quickly check whether that's the case and cast topk_weights to c2's dtype as a workaround.
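
A quick check along those lines (variable names follow the traceback above; this is a workaround sketch, not the final fix):

# In cutlass_fused_experts_fp8, just before the failing call: the scaling
# factors passed to apply_shuffle_mul_sum must match the output dtype.
if topk_weights.dtype != c2.dtype:
    topk_weights = topk_weights.to(c2.dtype)
return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)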


elfiegg commented Jun 23, 2025

@elfiegg What device did you run the benchmark on? On B200, it seems slower than the Triton implementation.

I was on a B200 600W machine. My container might have unoptimized Triton settings. But I did notice that for small batches, where it's mostly memory-bound, Triton performed pretty well. Can you help benchmark larger batch sizes (2048, 4096, 8192, etc.) on your machine?

@ispobock
Collaborator

@elfiegg There are some PRs to fix the issue: #7444 and #7442. Could you help check?

cc: @mickqian @zhyncs
