
Conversation

@trevor-m (Collaborator) commented Jul 25, 2025

Motivation

Follow up to #8333
Fuse the multiplication by routed_scaling_factor into select_experts, following the example of TRT-LLM:
https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/models/modeling_deepseekv3.py#L323
https://github.com/NVIDIA/TensorRT-LLM/blob/738ab615930fd08dccb94fa388bd74dc91c5f235/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu#L651

For the non-FP4 paths, the routed_scaling_factor is fused into moe_sum_reduce. However, we could move it into select_experts for those paths too if we wanted to simplify the code.

Modifications

Add a boolean argument, apply_routed_scaling_factor_on_output, to fused_moe_gate(); a sketch of the intended semantics follows below.
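
For illustration, a minimal Python sketch of what the new flag is meant to do. The real gate is a grouped/biased top-k CUDA kernel in sgl-kernel; the plain softmax top-k below is only a stand-in, and the function signature is illustrative:

```python
import torch

def reference_gate(router_logits, top_k, routed_scaling_factor,
                   apply_routed_scaling_factor_on_output=False):
    # Stand-in for the fused gate: plain softmax + top-k instead of the real
    # grouped/biased routing kernel.
    scores = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    if apply_routed_scaling_factor_on_output:
        # Fused path: fold the scale into the returned top-k weights here,
        # so the separate multiply on the combined MoE output can be dropped.
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights, topk_ids
```

When the flag is set, the caller must not multiply the combined expert output by routed_scaling_factor again, otherwise the result is scaled twice.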

Results

Prefill:
  • 10.46% speedup at BS 1
  • 1.86% speedup at BS 128
Decode:
  • 1.22% speedup at BS 1
  • 0.26% speedup at BS 128

Server command

python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 8 --enable-flashinfer-cutlass-moe --enable-ep-moe --ep-size 8

BEFORE results

python3 -m sglang.bench_one_batch_server --model-path nvidia/DeepSeek-R1-0528-FP4 --base-url http://127.0.0.1:30000/ --batch-size 1 --input-len 1024 --output-len 1024

#Input tokens: 1024
#Output tokens: 1024
batch size: 1
input_len: 1024
output_len: 1024
latency: 10.60 s
ttft: 0.11 s
last generation throughput: 96.61 tok/s
input throughput: 9392.46 tok/s
output throughput: 97.62 tok/s

python3 -m sglang.bench_one_batch_server --model-path nvidia/DeepSeek-R1-0528-FP4 --base-url http://127.0.0.1:30000/ --batch-size 128 --input-len 1024 --output-len 1024

#Input tokens: 131072
#Output tokens: 131072
batch size: 128
input_len: 1024
output_len: 1024
latency: 39.79 s
ttft: 7.56 s
last generation throughput: 3693.98 tok/s
input throughput: 17344.15 tok/s
output throughput: 4066.24 tok/s

AFTER results

python3 -m sglang.bench_one_batch_server --model-path nvidia/DeepSeek-R1-0528-FP4 --base-url http://127.0.0.1:30000/ --batch-size 1 --input-len 1024 --output-len 1024

#Input tokens: 1024
#Output tokens: 1024
batch size: 1
input_len: 1024
output_len: 1024
latency: 10.46 s
ttft: 0.10 s
last generation throughput: 97.88 tok/s
input throughput: 10375.70 tok/s
output throughput: 98.82 tok/s

python3 -m sglang.bench_one_batch_server --model-path nvidia/DeepSeek-R1-0528-FP4 --base-url http://127.0.0.1:30000/ --batch-size 128 --input-len 1024 --output-len 1024

#Input tokens: 131072
#Output tokens: 131072
batch size: 128
input_len: 1024
output_len: 1024
latency: 39.57 s
ttft: 7.42 s
last generation throughput: 3698.44 tok/s
input throughput: 17667.47 tok/s
output throughput: 4076.75 tok/s

Accuracy

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319 --port=30000

# Before
Accuracy: 0.960
Invalid: 0.000
Latency: 23.902 s
Output throughput: 6057.485 token/s

# After
Accuracy: 0.961
Invalid: 0.000
Latency: 23.413 s
Output throughput: 6128.773 token/s

Checklist

@gemini-code-assist (Contributor, bot) left a comment


Summary of Changes

Hello @trevor-m, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a performance optimization for Mixture of Experts (MoE) layers, particularly for FP4 quantized models. The core change involves fusing the application of a routed_scaling_factor directly into the expert selection kernel, which should lead to more efficient computations. This work is a direct follow-up to a previous related pull request.

Highlights

  • MoE Optimization: I've implemented a fusion optimization for Mixture of Experts (MoE) layers, specifically targeting FP4 quantized models. This change moves the application of the routed_scaling_factor directly into the expert selection process.
  • Scaling Factor Fusion Logic: The routed_scaling_factor is now conditionally applied within the CUDA kernels responsible for moe_fused_gate (expert selection) during the expert weight normalization. This aims to improve computational efficiency by reducing separate operations.
  • Conditional Activation: A new boolean parameter, apply_routed_scaling_factor_on_output, has been introduced to control this fusion. It is specifically enabled for FusedMoE experts using ModelOptNvFp4FusedMoEMethod in models like DeepSeekV2 (see the sketch after this list).
  • Shared Expert Fusion Compatibility: The previous restriction that disabled shared expert fusion when Flashinfer MoE was enabled has been removed, potentially allowing for more flexible and optimized configurations.
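
A hedged sketch of how that conditional activation might be wired on the Python side; the variable names and keyword arguments below are illustrative, not the exact sglang symbols:

```python
def route_tokens(hidden_states, router_logits, top_k, routed_scaling_factor,
                 quant_method_is_modelopt_nvfp4):
    # Illustrative wiring only: enable the fusion solely for the ModelOpt NVFP4
    # path, then let the gate scale the top-k weights by routed_scaling_factor.
    apply_on_output = quant_method_is_modelopt_nvfp4
    topk_weights, topk_ids = select_experts(  # sglang's expert-selection entry point
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        routed_scaling_factor=routed_scaling_factor,
        apply_routed_scaling_factor_on_output=apply_on_output,
    )
    # Downstream, the combine / moe_sum_reduce step must skip its own multiply
    # by routed_scaling_factor whenever the gate has already applied it.
    return topk_weights, topk_ids
```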

@gemini-code-assist (Contributor, bot) left a comment


Code Review

This pull request aims to fuse the routed_scaling_factor multiplication into the select_experts function, particularly for FP4 MoE, to improve performance. The changes involve adding a new boolean parameter, apply_routed_scaling_factor_on_output, which is propagated from the Python layers down to the CUDA kernels to control this fused operation.

Overall, the changes are logical and consistently implemented across the Python and C++/CUDA codebases. My review focuses on improving code robustness, maintainability, and readability. Key suggestions include:

  • Replacing string-based type checking with isinstance for better maintainability.
  • Using NotImplementedError instead of assert for unimplemented paths to prevent silent failures in production.
  • Minor refactoring in both Python and CUDA code to improve clarity and reduce redundancy.

These improvements will make the code more robust and easier to maintain in the future.
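
To make those suggestions concrete, a small sketch; the class, variable, and flag names are illustrative and may not match the actual code:

```python
# String-based type check (fragile: breaks on rename and misses subclasses):
use_fusion = "ModelOptNvFp4FusedMoEMethod" in type(quant_method).__name__

# Suggested: isinstance, which is refactor-safe and also covers subclasses.
use_fusion = isinstance(quant_method, ModelOptNvFp4FusedMoEMethod)

# Suggested: raise NotImplementedError instead of assert for unsupported paths,
# since asserts are stripped under `python -O` and would then fail silently.
if use_fusion and not backend_supports_fused_scaling:  # hypothetical capability flag
    raise NotImplementedError(
        "apply_routed_scaling_factor_on_output is not supported by this backend"
    )
```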

@trevor-m force-pushed the fuse-routed-scaling branch from 0b7967e to 7626d88 on July 28, 2025 19:40
@trevor-m changed the title from "Draft: Fuse routed scaling factor into select_experts for FP4 MoE" to "Fuse routed scaling factor into select_experts for FP4 MoE" on Jul 28, 2025
@trevor-m force-pushed the fuse-routed-scaling branch from 7626d88 to acede7d on July 28, 2025 19:48
@trevor-m (Collaborator, Author) commented:

@merrymercy Could you please take a look? Thanks
cc @BBuf

@pavanimajety (Collaborator) commented Jul 30, 2025

Thanks for the PR! Can we also do sanity checks for TP cases (i.e., --enable-ep-moe not passed in) and with an FP8 model, to verify that there is no accuracy regression?

@trevor-m force-pushed the fuse-routed-scaling branch from fb5135a to 44a2e5b on August 1, 2025 00:16

@trevor-m (Collaborator, Author) commented Aug 1, 2025

The SGL_CUTLASS_MOE=1 FP8 path was giving 0.832 accuracy because we weren't multiplying by routed_scaling_factor. I enabled the fusion for that path to fix it, and it now gives 0.965.

@trevor-m force-pushed the fuse-routed-scaling branch from d4325ea to 7b1b7d4 on August 1, 2025 04:07
@trevor-m force-pushed the fuse-routed-scaling branch from c57b2d7 to 14123dd on August 2, 2025 00:10
@trevor-m changed the title from "Fuse routed scaling factor into select_experts for FP4 MoE" to "[1/2] sgl-kernel: Fuse routed scaling factor into select_experts" on Aug 2, 2025
@merrymercy merged commit f642524 into sgl-project:main on Aug 2, 2025
50 of 53 checks passed
@hnyls2002 (Collaborator) commented Aug 2, 2025

@trevor-m Please fix the broken unit test for the fused gate.

TypeError: biased_grouped_topk_gpu() got an unexpected keyword argument 'apply_routed_scaling_factor_on_output'
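
One hedged sketch of a fix, assuming the non-fused reference path should simply accept the new keyword and emulate it; the real parameter list of biased_grouped_topk_gpu is longer than shown here, and the helper name is made up:

```python
def biased_grouped_topk_gpu(*args, routed_scaling_factor=1.0,
                            apply_routed_scaling_factor_on_output=False, **kwargs):
    # Existing top-k routing path; name and signature are illustrative.
    topk_weights, topk_ids = _biased_grouped_topk_impl(*args, **kwargs)
    if apply_routed_scaling_factor_on_output:
        # Match the fused kernel's behaviour by scaling the returned weights.
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights, topk_ids
```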

hnyls2002 added a commit that referenced this pull request Aug 2, 2025