pavanimajety
Collaborator

Motivation

Refactors Cutlass MoE to keep the interface cleaner.

Modifications

Introduces a CutlassMoEParams class that creates all the Cutlass metadata from the intermediate size and the hidden size.

Checklist

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Hello @pavanimajety, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello team, Gemini here with a summary of this pull request. The main goal of this PR, as indicated by the title and description, is to refactor the Cutlass MoE implementation to improve the interface cleanliness. This is achieved by introducing a dedicated CutlassMoEParams class to encapsulate the various metadata required for the Cutlass Grouped GEMM operations, rather than passing numerous individual parameters.

Highlights

  • New Parameter Class: Introduced a new dataclass CutlassMoEParams to hold all the necessary parameters (strides, problem sizes, offsets, dimensions, device) for Cutlass MoE operations, simplifying function signatures.
  • Refactored Cutlass MoE Functions: The core cutlass_moe_fp4 function in cutlass_moe.py and the underlying kernel call cutlass_fp4_group_mm in sgl_kernel/moe.py were updated to accept the new CutlassMoEParams object or a dictionary derived from it, reducing the number of direct arguments.
  • NVFP4 MoE Integration: Added a new ModelOptNvFp4FusedMoEMethod class within the modelopt_quant.py framework to specifically handle the creation, processing, and application of weights for NVFP4 quantized MoE layers using the refactored Cutlass functions.
  • Code Cleanups and Updates: Minor cleanups include renaming an FP8 function (cutlass_fused_experts to cutlass_fused_experts_fp8), updating tests to use the new interfaces, and enhancing the Deepseek model loading logic to support FP4 MoE weights.
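
As a rough illustration of the parameter-class idea, here is a minimal sketch in plain Python. It is not the PR's implementation: the class name MoEParamsSketch, its field names, and the assumed GEMM layouts (GEMM1 producing 2 * n columns for fused gate and up projections, GEMM2 projecting back to the hidden size) are all assumptions made for the example. Only the method names to_gemm1_args() and to_gemm2_args() come from the summary above.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class MoEParamsSketch:
    """Hypothetical stand-in for CutlassMoEParams; field names are illustrative."""

    num_experts: int        # e
    intermediate_size: int  # n
    hidden_size: int        # k
    device: str = "cpu"
    # Derived metadata, filled in by __post_init__ rather than by the caller.
    gemm1_shape: Dict[str, int] = field(init=False)
    gemm2_shape: Dict[str, int] = field(init=False)

    def __post_init__(self) -> None:
        # Assumed layout: GEMM1 emits gate and up projections (2 * n columns),
        # GEMM2 projects the activated result back to the hidden size.
        self.gemm1_shape = {"n": 2 * self.intermediate_size, "k": self.hidden_size}
        self.gemm2_shape = {"n": self.hidden_size, "k": self.intermediate_size}

    def to_gemm1_args(self) -> Dict[str, Any]:
        # One dictionary per GEMM, so the kernel wrapper takes a single argument.
        return {"num_experts": self.num_experts, "device": self.device, **self.gemm1_shape}

    def to_gemm2_args(self) -> Dict[str, Any]:
        return {"num_experts": self.num_experts, "device": self.device, **self.gemm2_shape}


params = MoEParamsSketch(num_experts=8, intermediate_size=256, hidden_size=128)
gemm1_args = params.to_gemm1_args()
gemm2_args = params.to_gemm2_args()
```

The caller supplies only the basic dimensions; everything derived from them lives in one place, which is what lets the function signatures shrink.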

Changelog

  • python/sglang/srt/layers/moe/cutlass_moe.py
    • Imported the new CutlassMoEParams class.
    • Renamed cutlass_fused_experts to cutlass_fused_experts_fp8 (Diff 2).
    • Refactored cutlass_moe_fp4 to accept a params: CutlassMoEParams object instead of individual stride, size, and offset tensors (Diff 3).
    • Updated the internal logic of cutlass_moe_fp4 to use parameters from the params object and call params.to_gemm1_args() and params.to_gemm2_args() for the kernel calls (Diff 4).
    • Added shuffle_rows import and usage in cutlass_moe_fp4 for the final output (Diff 1, Diff 4).
  • python/sglang/srt/layers/moe/cutlass_moe_params.py
    • Added a new file defining CutlassMoEType enum and CutlassMoEParams dataclass.
    • CutlassMoEParams encapsulates MoE parameters like strides, problem sizes, offsets, dimensions (m, n, k, e), and device.
    • Includes an __init__ method to initialize these parameters based on basic dimensions.
    • Provides to_gemm1_args() and to_gemm2_args() methods to return dictionaries suitable for the kernel calls.
  • python/sglang/srt/layers/moe/fused_moe_triton/layer.py
    • Modified weight loading logic to check for 'compressed' in the quant method name before raising an error (Diff 1).
    • Added a new block in weight_loader to handle 'ModelOpt' quantization specifically for weight and input scales (Diff 2).
  • python/sglang/srt/layers/quantization/fp8.py
    • Updated import and usage of the renamed FP8 MoE function from cutlass_fused_experts to cutlass_fused_experts_fp8.
  • python/sglang/srt/layers/quantization/modelopt_quant.py
    • Imported CutlassMoEParams and CutlassMoEType (Diff 1).
    • Imported is_sm100_supported and is_layer_skipped (Diff 2).
    • Modified ModelOptFp4Config.from_config to handle empty kv_cache_quant_algo and add a warning for missing config parameters (Diff 3).
    • Added is_layer_excluded method using regex for module exclusion (Diff 4).
    • Updated get_quant_method to use the new exclusion logic and include ModelOptNvFp4FusedMoEMethod for FusedMoE layers (Diff 4).
    • Added the ModelOptNvFp4FusedMoEMethod class to handle NVFP4 MoE quantization, including weight creation, processing (swizzling scales, calculating alphas), and calling the refactored cutlass_moe_fp4 with the new params object (Diff 5).
  • python/sglang/srt/models/deepseek_v2.py
    • Updated a log message to mention FP4 support for shared experts fusion (Diff 1).
    • Added a blank line in load_weights (Diff 2).
    • Added logic in load_weights to handle 'modelopt_fp4' quantization for shared expert fusion, specifying the relevant weight names (Diff 3).
    • Fixed a small issue in the logic for determining the parameter name for fused QKV/A proj weights (Diff 5).
  • python/sglang/test/test_cutlass_moe.py
    • Updated import and usage of the renamed FP8 MoE function from cutlass_fused_experts to cutlass_fused_experts_fp8.
  • python/sglang/test/test_fp4_moe.py
    • Imported CutlassMoEParams and CutlassMoEType (Diff 1).
    • Refactored test_cutlass_fp4_moe_no_graph to create and pass a CutlassMoEParams object to cutlass_moe_fp4, removing individual parameter arguments (Diff 2, Diff 3).
  • sgl-kernel/python/sgl_kernel/moe.py
    • Imported Dict and Any (Diff 1).
    • Refactored cutlass_fp4_group_mm to accept a single params: Dict[str, Any] argument instead of multiple individual parameters (Diff 2).
    • Updated the internal kernel call to access parameters from the params dictionary (Diff 3).
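
The regex-based module exclusion mentioned for modelopt_quant.py could look roughly like the sketch below. The diff itself is not shown in this conversation, so this assumes that exclude_modules holds glob-style patterns in which `*` matches any run of characters; the function signature is likewise hypothetical and may differ from the method added in the PR.

```python
import re
from typing import List


def is_layer_excluded(prefix: str, exclude_modules: List[str]) -> bool:
    """Return True if the layer name matches any exclusion pattern.

    Assumption: patterns are glob-like, with '*' as the only wildcard.
    """
    for pattern in exclude_modules:
        # Escape regex metacharacters, then turn the escaped '*' into '.*'.
        regex = re.escape(pattern).replace(r"\*", ".*")
        if re.fullmatch(regex, prefix):
            return True
    return False


patterns = ["lm_head", "model.layers.*.self_attn*"]
attn_excluded = is_layer_excluded("model.layers.0.self_attn.q_proj", patterns)
moe_excluded = is_layer_excluded("model.layers.0.mlp.experts", patterns)
```

Using re.fullmatch rather than re.search keeps a pattern such as lm_head from accidentally matching a longer name that merely contains it.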

Parameters galore,
A class to hold them all,
MoE code is clean.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request effectively refactors the Cutlass MoE (Mixture of Experts) interface by introducing the CutlassMoEParams class, which significantly cleans up function signatures and improves code organization. A major addition is the ModelOptNvFp4FusedMoEMethod, enabling support for NVFP4 MoE quantization. The changes are generally well-structured and consistent. I've identified a few areas for potential improvement and clarification, detailed below.

Summary of Findings

  • Unused/Uninitialized m field in CutlassMoEParams: The m field (total tokens) in CutlassMoEParams is declared but not initialized via the constructor and appears unused, potentially causing confusion.
  • Removed Assertion for Input Shapes: An assertion checking the consistency between the input activation shape and topk_weights/topk_ids shapes was removed from cutlass_moe_fp4. Reinstating a similar check against m_a could improve robustness.
  • Clarity of Layer Exclusion Logic: The use of two separate functions (is_layer_skipped and is_layer_excluded) for determining if a layer should be skipped in ModelOptFp4Config could be clarified regarding their distinct purposes and the expected format of exclude_modules.
  • Commented-out code in CutlassMoEParams: The to_gemm1_args and to_gemm2_args methods in CutlassMoEParams contain commented-out pointer arguments (e.g., a_ptrs, b_ptrs). If these are not intended for future use, they should be removed. (Not commented due to review settings: severity low)
  • Clarity of swizzle_blockscale logic: The swizzle_blockscale method in ModelOptNvFp4FusedMoEMethod contains complex permutation logic. Additional comments explaining the swizzling pattern could improve maintainability. (Not commented due to review settings: severity low)
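
For the removed-assertion finding, a reinstated check might be sketched as follows. This is an illustration only: the helper name check_moe_input_shapes is hypothetical, shapes are passed as plain tuples so the example runs without torch, and the exact condition that the original assertion tested is not shown in this conversation.

```python
from typing import Tuple


def check_moe_input_shapes(
    a_shape: Tuple[int, ...],
    topk_weights_shape: Tuple[int, ...],
    topk_ids_shape: Tuple[int, ...],
) -> int:
    """Validate routing tensors against the activation row count m_a."""
    m_a = a_shape[0]  # number of tokens in the input activation
    assert topk_weights_shape == topk_ids_shape, (
        f"topk_weights {topk_weights_shape} and topk_ids {topk_ids_shape} must match"
    )
    assert topk_weights_shape[0] == m_a, (
        f"expected {m_a} routed rows, got {topk_weights_shape[0]}"
    )
    return m_a


# 4 tokens, hidden size 128, top-2 routing.
m_a = check_moe_input_shapes((4, 128), (4, 2), (4, 2))
```

In the real function the same comparison would be made on tensor shapes (for example a.shape[0] against topk_ids.shape[0]), so that a mismatched batch fails early with a readable message rather than inside the kernel.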

Merge Readiness

This pull request introduces valuable refactoring and new functionality for FP4 MoE. However, there are a few medium-severity issues identified regarding the CutlassMoEParams.m field, a removed assertion, and the clarity of layer exclusion logic. I recommend addressing these points to enhance code clarity and robustness before merging. As an AI, I am not authorized to approve pull requests; please ensure further review and approval by team members.

@pavanimajety
Collaborator Author

pavanimajety commented Jun 5, 2025

@zhyncs The failed tests are unrelated to the PR. Please merge if you think it is good to go

@elfiegg
Collaborator

elfiegg commented Jun 5, 2025

LGTM

@kushanam kushanam enabled auto-merge (squash) June 5, 2025 20:09
@zhyncs zhyncs disabled auto-merge June 5, 2025 20:13
@zhyncs zhyncs merged commit 0df6765 into sgl-project:main Jun 5, 2025
jianan-gu pushed a commit to jianan-gu/sglang that referenced this pull request Jun 12, 2025
…ation of Cutlass Grouped Gems Metadata (sgl-project#6887)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
xwu-intel pushed a commit to xwu-intel/sglang that referenced this pull request Jun 17, 2025
…ation of Cutlass Grouped Gems Metadata (sgl-project#6887)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
walker-ai pushed a commit to walker-ai/sglang that referenced this pull request Jul 8, 2025
Merge branch 'sgl_20250610_sync_tag047 of git@code.alipay.com:Theta/SGLang.git into main

https://code.alipay.com/Theta/SGLang/pull_requests/52


Reviewed-by: 剑川 <jianchuan.gys@antgroup.com>


* [Bugfix] Fix slice operation when chunk size mismatch (sgl-project#6697)
* [Bugfix] Fix ChatCompletion endpoint of mini_lb when stream is set (sgl-project#6703)
* [CI] Fix setup of disaggregation with different tp (sgl-project#6706)
* [PD] Remove Unnecessary Exception Handling for FastQueue.get() (sgl-project#6712)
* Fuse routed_scaling_factor in DeepSeek (sgl-project#6710)
* Overlap two kernels in DeepSeek with communication (sgl-project#6711)
* Minor refactor two-batch overlap (sgl-project#6682)
* Speed up when having padding tokens two-batch overlap (sgl-project#6668)
* [Feature] Support Flashinfer fp8 blockwise GEMM kernel on Blackwell (sgl-project#6479)
* Fix LoRA bench (sgl-project#6719)
* temp
* Fix PP for Qwen3 MoE (sgl-project#6709)
* [feat] triton kernel for get_last_loc (sgl-project#6676)
* [fix] more mem for draft_extend cuda_graph (sgl-project#6726)
* [PD] bug fix:  Update status if nixl receiver send a a dummy req. (sgl-project#6720)
* Tune memory arguments on B200 (sgl-project#6718)
* Add DeepSeek-R1-0528 function call chat template (sgl-project#6725)
* refactor(tool call): Fix BaseFormatDetector tool_index issue and refactor `parse_streaming_increment` (sgl-project#6715)
* Add draft extend CUDA graph for Triton backend (sgl-project#6705)
* refactor apply_w8a8_block_fp8_linear in fp (sgl-project#6545)
* [PD] Support completion endpoint (sgl-project#6729)
* PD Rust LB (PO2) (sgl-project#6437)
* Super tiny enable sole usage of expert distribution metrics and update doc (sgl-project#6680)
* Support picking variants of EPLB algorithms (sgl-project#6728)
* Support tuning DeepEP configs (sgl-project#6742)
* [test] add ut and bm for get_last_loc (sgl-project#6746)
* Fix mem_fraction_static for AMD CI (sgl-project#6748)
* [fix][RL] Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (sgl-project#6265)
* Improve EPLB logical to physical dispatch map (sgl-project#6727)
* Update DeepSeek-R1-0528 function call chat template (sgl-project#6765)
* [PD] Optimize time out logic and add env var doc for mooncake (sgl-project#6761)
* Fix aiohttp 'Chunk too big' in bench_serving (sgl-project#6737)
* Support sliding window in triton backend (sgl-project#6509)
* Fix shared experts fusion error (sgl-project#6289)
* Fix one bug in the grouped-gemm triton kernel (sgl-project#6772)
* update llama4 chat template and pythonic parser (sgl-project#6679)
* feat(tool call): Enhance Llama32Detector for improved JSON parsing in non-stream (sgl-project#6784)
* Support token-level quantization for EP MoE (sgl-project#6782)
* Temporarily lower mmlu threshold for triton sliding window backend (sgl-project#6785)
* ci: relax test_function_call_required (sgl-project#6786)
* Add intel_amx backend for Radix Attention for CPU (sgl-project#6408)
* Fix incorrect LoRA weight loading for fused gate_up_proj (sgl-project#6734)
* fix(PD-disaggregation): Can not get local ip (sgl-project#6792)
* [FIX] mmmu bench serving result display error (sgl-project#6525) (sgl-project#6791)
* Bump torch to 2.7.0 (sgl-project#6788)
* chore: bump sgl-kernel v0.1.5 (sgl-project#6794)
* Improve profiler and integrate profiler in bench_one_batch_server (sgl-project#6787)
* chore: upgrade sgl-kernel v0.1.5 (sgl-project#6795)
* [Minor] Always append newline after image token when parsing chat message (sgl-project#6797)
* Update CI tests for Llama4 models (sgl-project#6421)
* [Feat] Enable PDL automatically on Hopper architecture (sgl-project#5981)
* chore: update blackwell docker (sgl-project#6800)
* misc: cache is_hopper_arch (sgl-project#6799)
* Remove contiguous before Flashinfer groupwise fp8 gemm (sgl-project#6804)
* Correctly abort the failed grammar requests & Improve the handling of abort (sgl-project#6803)
* [EP] Add cuda kernel for moe_ep_pre_reorder (sgl-project#6699)
* Add draft extend CUDA graph for flashinfer backend  (sgl-project#6805)
* Refactor CustomOp to avoid confusing bugs (sgl-project#5382)
* Tiny log prefill time (sgl-project#6780)
* Tiny fix EPLB assertion about rebalancing period and recorder window size (sgl-project#6813)
* Add simple utility to dump tensors for debugging (sgl-project#6815)
* Fix profiles do not have consistent names (sgl-project#6811)
* Speed up rebalancing when using non-static dispatch algorithms (sgl-project#6812)
* [1/2] Add Kernel support for Cutlass based Fused FP4 MoE (sgl-project#6093)
* [Router] Fix k8s Service Discovery (sgl-project#6766)
* Add CPU optimized kernels for topk and rope fusions  (sgl-project#6456)
* fix new_page_count_next_decode (sgl-project#6671)
* Fix wrong weight reference in dynamic EPLB (sgl-project#6818)
* Minor add metrics to expert location updater (sgl-project#6816)
* [Refactor] Rename `n_share_experts_fusion` as `num_fused_shared_experts` (sgl-project#6735)
* [FEAT] Add transformers backend support  (sgl-project#5929)
* [fix] recover auto-dispatch for rmsnorm and rope (sgl-project#6745)
* fix ep_moe_reorder kernel bugs (sgl-project#6858)
* [Refactor] Multimodal data processing for VLM (sgl-project#6659)
* Decoder-only Scoring API (sgl-project#6460)
* feat: add dp-rank to KV events (sgl-project#6852)
* Set `num_fused_shared_experts` as `num_shared_experts` when shared_experts fusion is not disabled (sgl-project#6736)
* Fix one missing arg in DeepEP (sgl-project#6878)
* Support LoRA in TestOpenAIVisionServer and fix fused kv_proj loading bug. (sgl-project#6861)
* support 1 shot allreduce  in 1-node and 2-node using mscclpp (sgl-project#6277)
* Fix Qwen3MoE missing token padding optimization (sgl-project#6820)
* Tiny update error hints (sgl-project#6846)
* Support layerwise rebalancing experts (sgl-project#6851)
* Tiny allow profiler API to auto create directory (sgl-project#6865)
* Support Blackwell DeepEP docker images (sgl-project#6868)
* [EP] Add cuda kernel for moe_ep_post_reorder (sgl-project#6837)
* [theta]merge 0605
* oai: fix openAI client error with single request via batch api (sgl-project#6170)
* [PD] Fix potential perf spike caused by tracker gc and optimize doc (sgl-project#6764)
* Use deepgemm instead of triton for fused_qkv_a_proj_with_mqa (sgl-project#6890)
* [CUTLASS-FP4-MOE]  Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (sgl-project#6887)
* bugfix(OAI): Fix image_data processing for jinja chat templates (sgl-project#6877)
* [CPU] enable CI for PRs, add Dockerfile and auto build task (sgl-project#6458)
* AITER backend extension and workload optimizations (sgl-project#6838)
* [theta]merge
* [theta]merge
* [Feature] Support Flashinfer fmha on Blackwell (sgl-project#6930)
* Fix a bug in abort & Improve docstrings for abort (sgl-project#6931)
* Tiny support customize DeepEP max dispatch tokens per rank (sgl-project#6934)
* Sync the changes on cuda graph runners (sgl-project#6932)
* [PD] Optimize transfer queue forward logic for dummy rank (sgl-project#6922)
* [Refactor] image data process in bench_serving (sgl-project#6879)
* [fix] logical_to_all_physical_map index 256 is out of bounds in EP parallel. (sgl-project#6767)
* Add triton fused moe kernel config for E=257 on B200 (sgl-project#6939)
* [sgl-kernel] update deepgemm (sgl-project#6942)
* chore: bump sgl-kernel v0.1.6 (sgl-project#6943)
* Minor compile fused topk (sgl-project#6944)
* [Bugfix] pipeline parallelism and Eagle Qwen2 (sgl-project#6910)
* Tiny re-introduce profile id logging (sgl-project#6912)
* Add triton version as a fused_moe_triton config search key to avoid performace decrease in different Triton version (sgl-project#5955)
* reduce torch.zeros overhead in moe align block size kernel (sgl-project#6369)
* chore: upgrade sgl-kernel v0.1.6 (sgl-project#6945)
* add fbgemm moe grouped gemm kernel benchmark (sgl-project#6924)
* [Docker] Add docker file for SGL Router (sgl-project#6915)
* Disabling mixed chunked prefill when eagle is enabled (sgl-project#6874)
* Add canary for EPLB rebalancing (sgl-project#6895)
* Refactor global_server_args_dict (sgl-project#6866)
* Fuse routed scaling factor in topk_reduce kernel (sgl-project#6220)
* Update server timeout time in AMD CI. (sgl-project#6953)
* [misc] add is_cpu() (sgl-project#6950)
* Add H20 fused MoE kernel tuning configs for DeepSeek-R1/V3 (sgl-project#6885)
* Add a CUDA kernel for fusing mapping and weighted sum for MoE. (sgl-project#6916)
* chore: bump sgl-kernel v0.1.6.post1 (sgl-project#6955)
* chore: upgrade sgl-kernel v0.1.6.post1 (sgl-project#6957)
* [DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (sgl-project#6853)
* Revert "Fuse routed scaling factor in topk_reduce kernel (sgl-project#6220)" (sgl-project#6968)
* [AMD] Add more tests to per-commit-amd (sgl-project#6926)
* chore: bump sgl-kernel v0.1.7 (sgl-project#6963)
* Slightly improve the sampler to skip unnecessary steps (sgl-project#6956)
* rebase h20 fused_moe config (sgl-project#6966)
* Fix CI and triton moe Configs (sgl-project#6974)
* Remove unnecessary kernels of num_token_non_padded (sgl-project#6965)
* Extend cuda graph capture bs for B200 (sgl-project#6937)
* Fuse routed scaling factor in deepseek (sgl-project#6970)
* Sync cuda graph runners (sgl-project#6976)
* Fix draft extend ut stability with flush cache (sgl-project#6979)
* Fix triton sliding window test case (sgl-project#6981)
* Fix expert distribution dumping causes OOM (sgl-project#6967)
* Minor remove one kernel for DeepSeek (sgl-project#6977)
* [perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (sgl-project#6929)
* Enable more unit tests for AMD CI. (sgl-project#6983)
* Use torch.compile to fuse flash attention decode metadata preparation (sgl-project#6973)
* Eliminate stream sync to speed up LoRA batch init  (sgl-project#6960)
* support qwen3 emebedding (sgl-project#6990)
* Fix torch profiler bugs for bench_offline_throughput.py (sgl-project#6557)
* chore: upgrade flashinfer v0.2.6.post1 jit (sgl-project#6958)
* cleanup tmp dir (sgl-project#7007)
* chore: update pr test xeon (sgl-project#7008)
* Fix cutlass MLA gets almost zero accuracy (sgl-project#6998)
* Update amd nightly models CI. (sgl-project#6992)
* feat: add direct routing strategy to DP worker (sgl-project#6884)
* Fallback to lower triton version for unfound fused moe configs (sgl-project#7013)
* Fix torchvision version for Blackwell (sgl-project#7015)
* Simplify prepare_extend_after_decode (sgl-project#6987)
* Migrate to assertEqual (sgl-project#6741)
* Fix torch version in blackwell dockerfile (sgl-project#7017)
* chore: update pr test xeon (sgl-project#7018)
* Update default settings for blackwell (sgl-project#7023)
* Support both approximate and exact expert distribution collection (sgl-project#6964)
* Add decode req pool (sgl-project#6980)
* [theta]merge 0610
* [theta]merge 0610
* [CI] Add CI workflow for sgl-router docker build (sgl-project#7027)
* Fix fused_moe triton configs (sgl-project#7029)
* CPU: map changes from developing branch in sgl-kernel (sgl-project#6833)
* chore: bump v0.4.7 (sgl-project#7038)
* Update README.md (sgl-project#7040)