
Conversation

@yuan-luo (Contributor) commented Jun 3, 2025

Motivation

moe_post_reorder is one of the important kernels in EP MoE. It is currently implemented as a Triton kernel; this PR introduces a CUDA implementation of it, which is expected to improve performance.

Modifications

Checklist

@gemini-code-assist (bot) left a comment

Hello @yuan-luo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello! Gemini here, providing a summary of this pull request. This PR, marked as Work In Progress by @yuan-luo, aims to replace the existing Triton kernel for moe_post_reorder with a new CUDA implementation. The primary motivation is to achieve performance improvements for this important kernel within the EP MoE (Expert Parallelism Mixture of Experts) framework. The changes involve adding the core CUDA kernel logic and the necessary C++ and Python wrappers to integrate it into the existing codebase.

Highlights

  • New CUDA Kernel: Introduces a new CUDA kernel, ep_post_reorder_cuda_kernel, to handle the post-reordering step in the EP MoE process. This kernel processes the output from the experts, applying the routing weights and reordering the results per token (a sketch follows this list).
  • C++ and Python Wrappers: Adds corresponding C++ (ep_moe_post_reorder) and Python (ep_moe_post_reorder) wrapper functions to expose the new CUDA kernel functionality to the Python frontend, allowing it to be called from the SGLang framework.
  • Performance Focus: The explicit goal stated in the PR description is to improve performance compared to the current Triton implementation.
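
For orientation, here is a minimal sketch of the per-token weighted gather the first highlight describes. It is an illustration only: the parameter names, layout, and float-only scalar path are assumptions, and the real kernel in sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu uses vectorized loads.

```cuda
// Sketch (not the PR's code): one block per token; threads stride over the
// hidden dimension. Each token's output is the weighted sum of the expert
// rows selected by topk_ids, located through the src2dst mapping. Experts
// outside [start_expert_id, end_expert_id] belong to other EP ranks and are
// skipped.
__global__ void post_reorder_sketch(
    const float* __restrict__ down_output,   // [num_tokens * topk, hidden_size]
    float* __restrict__ output,               // [num_tokens, hidden_size]
    const int* __restrict__ src2dst,          // [num_tokens, topk]
    const int* __restrict__ topk_ids,         // [num_tokens, topk]
    const float* __restrict__ topk_weights,   // [num_tokens, topk]
    int start_expert_id, int end_expert_id, int topk, int hidden_size) {
  const int token = blockIdx.x;
  for (int h = threadIdx.x; h < hidden_size; h += blockDim.x) {
    float acc = 0.f;
    for (int k = 0; k < topk; ++k) {
      const int expert = topk_ids[token * topk + k];
      if (expert < start_expert_id || expert > end_expert_id) continue;
      const int64_t src_row = src2dst[token * topk + k];
      acc += topk_weights[token * topk + k] * down_output[src_row * hidden_size + h];
    }
    output[static_cast<int64_t>(token) * hidden_size + h] = acc;
  }
}
```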

Changelog

  • sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu
    • Added the ep_post_reorder_cuda_kernel implementing the post-reorder logic using CUDA.
    • Added the C++ wrapper function ep_moe_post_reorder to launch the new CUDA kernel (a launcher sketch follows this changelog).
    • Minor refactor in ep_pre_reorder_cuda_kernel to use a constant for loop iterations.
    • Minor refactor in ep_moe_pre_reorder to add const qualifiers to variable declarations.
  • sgl-kernel/include/sgl_kernel_ops.h
    • Added the declaration for the new ep_moe_post_reorder C++ function.
  • sgl-kernel/python/sgl_kernel/__init__.py
    • Imported the new ep_moe_post_reorder function.
  • sgl-kernel/python/sgl_kernel/moe.py
    • Added the Python wrapper function ep_moe_post_reorder that calls the C++ implementation.
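
To show how the pieces in this changelog fit together, here is a hedged sketch of a host-side launcher in the same style. The function name, parameter list, block size, and float-only path are illustrative assumptions, not the actual ep_moe_post_reorder signature declared in sgl_kernel_ops.h.

```cuda
// Sketch of a host-side launcher (not the merged code). It launches the
// sketch kernel from the Highlights section above, one block per token.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

// Forward declaration of the sketch kernel defined earlier.
__global__ void post_reorder_sketch(const float*, float*, const int*, const int*,
                                    const float*, int, int, int, int);

void ep_moe_post_reorder_sketch(const torch::Tensor& down_output,
                                torch::Tensor& output,
                                const torch::Tensor& src2dst,
                                const torch::Tensor& topk_ids,
                                const torch::Tensor& topk_weights,
                                int64_t start_expert_id,
                                int64_t end_expert_id,
                                int64_t topk) {
  const int num_tokens = output.size(0);
  const int hidden_size = output.size(1);
  const dim3 grid(num_tokens);   // one block per token
  const dim3 block(512);         // threads stride over hidden_size
  post_reorder_sketch<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
      down_output.data_ptr<float>(), output.data_ptr<float>(),
      src2dst.data_ptr<int>(), topk_ids.data_ptr<int>(),
      topk_weights.data_ptr<float>(),
      static_cast<int>(start_expert_id), static_cast<int>(end_expert_id),
      static_cast<int>(topk), hidden_size);
}
```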


CUDA code takes flight,
Replacing Triton's soft glow,
Speed up the reorder.


@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a CUDA kernel implementation for moe_ep_post_reorder, aiming to improve performance over the existing Triton kernel. This is a valuable addition for optimizing MoE operations.

The core logic is in the new ep_post_reorder_cuda_kernel and its corresponding host-side launch function. The changes to existing files are mostly to integrate this new functionality.

Overall, the CUDA kernel implementation looks reasonable, but there are a few areas for potential improvement regarding correctness, robustness, and best practices. Given this is a WIP PR, some of these points might already be on your radar. Please see the detailed comments below.

Summary of Findings

  • Kernel Correctness: Hidden Size Divisibility: The CUDA kernel ep_post_reorder_cuda_kernel assumes hidden_size is divisible by vec_size (4 for float). If it is not, the trailing elements are left unprocessed; this needs clarification or explicit remainder handling (see the kernel-side sketch after this list). (Severity: Medium)
  • Kernel Clarity/Efficiency: computed Flag: The computed flag and an extra loop for out_vec in ep_post_reorder_cuda_kernel can be simplified, as acc is already initialized to zero. Removing them would improve clarity and potentially performance. (Severity: Medium)
  • Host Code Safety: int64_t to int Cast: Casting start_expert_id, end_expert_id, topk from int64_t to int in ep_moe_post_reorder (host) is potentially unsafe if values exceed INT_MAX. (Severity: Medium)
  • Host Code Robustness: Tensor Checks: The ep_moe_post_reorder host function could benefit from TORCH_CHECK assertions on input tensor properties (device, contiguity, dtype) for better robustness (see the host-side sketch after this list). (Severity: Medium)
  • API Design: Tensor Parameter Passing: In sgl_kernel_ops.h, the declaration for ep_moe_post_reorder should ideally use const torch::Tensor& for read-only input tensors and torch::Tensor& for the modified output tensor. (Severity: Medium)
  • Minor Optimization: Loop Invariant Calculation: In ep_pre_reorder_cuda_kernel, pre-calculating hidden_size / vec_size into vec_iters is a good minor optimization. (Severity: Low, not commented due to settings)
  • CUDA Kernel Style: Vector Initialization: In ep_post_reorder_cuda_kernel, flashinfer::vec_t might offer a more concise way to zero-initialize acc (e.g., acc.fill(0.f)), if available. (Severity: Low, not commented due to settings)
  • CUDA Kernel Style: Redundant Cast: The static_cast<float>(src_vec[i]) in ep_post_reorder_cuda_kernel might be redundant if src_vec[i] is already a float. (Severity: Low, not commented due to settings)
  • Host Code Style: Const Correctness: Adding const to local variables in ep_moe_pre_reorder (host) is good practice. (Severity: Low, not commented due to settings)
  • Python API: Documentation: The new Python wrapper ep_moe_post_reorder in sgl_kernel/python/sgl_kernel/moe.py is missing a docstring. (Severity: Low, not commented due to settings)
  • Python API: Type Hinting: The new Python wrapper ep_moe_post_reorder is missing type hints for parameters and return type. (Severity: Low, not commented due to settings)
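
On the hidden-size divisibility point, one common pattern is a vectorized body plus a scalar tail. The helper below is only a sketch under that assumption (it is not code from this PR) and assumes the row pointers are 16-byte aligned for the float4 path.

```cuda
// Illustrative device helper: accumulate weight * src into acc_row with a
// 4-wide vectorized body plus a scalar tail, so hidden_size need not be a
// multiple of 4.
__device__ void weighted_accumulate(float* acc_row, const float* src,
                                    float weight, int hidden_size) {
  const int vec_size = 4;
  const int vec_elems = (hidden_size / vec_size) * vec_size;  // aligned prefix
  // Vectorized body: float4 loads/stores over the aligned prefix.
  for (int h = threadIdx.x * vec_size; h < vec_elems; h += blockDim.x * vec_size) {
    float4 s = *reinterpret_cast<const float4*>(src + h);
    float4 a = *reinterpret_cast<float4*>(acc_row + h);
    a.x += weight * s.x; a.y += weight * s.y;
    a.z += weight * s.z; a.w += weight * s.w;
    *reinterpret_cast<float4*>(acc_row + h) = a;
  }
  // Scalar tail for the remaining hidden_size % 4 elements.
  for (int h = vec_elems + threadIdx.x; h < hidden_size; h += blockDim.x) {
    acc_row[h] += weight * src[h];
  }
}
```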
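On the robustness and casting points, here is a sketch of the kind of TORCH_CHECK validation the review suggests, with const-reference parameters. The function name and argument list are illustrative, not the PR's actual interface.

```cuda
// Sketch of host-side input validation for the wrapper (illustrative only).
#include <torch/extension.h>
#include <limits>

void check_post_reorder_inputs(const torch::Tensor& down_output,
                               const torch::Tensor& output,
                               const torch::Tensor& topk_ids,
                               const torch::Tensor& topk_weights,
                               int64_t topk) {
  TORCH_CHECK(down_output.is_cuda() && output.is_cuda(), "tensors must be on CUDA");
  TORCH_CHECK(down_output.is_contiguous() && output.is_contiguous(),
              "tensors must be contiguous");
  TORCH_CHECK(topk_ids.scalar_type() == torch::kInt32, "topk_ids must be int32");
  TORCH_CHECK(topk_weights.scalar_type() == torch::kFloat32,
              "topk_weights must be float32");
  // Guard the int64_t -> int narrowing the review flags.
  TORCH_CHECK(topk <= std::numeric_limits<int>::max(), "topk exceeds int range");
}
```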

Merge Readiness

This pull request is marked as WIP, and the introduction of a custom CUDA kernel for moe_ep_post_reorder is a significant step towards performance improvement.

Before this PR can be considered ready for merging, I recommend addressing the medium-severity issues identified in the review comments, particularly those related to potential correctness (hidden_size divisibility, int64_t to int casts) and robustness (tensor checks). The suggestions for code simplification and API consistency should also be considered.

Additionally, as per the PR checklist, completing unit tests, documentation, and benchmark results will be crucial for validating the changes and ensuring maintainability.

I am not authorized to approve pull requests. Please ensure further review and approval from other maintainers after addressing the feedback and completing the WIP items.

@yuan-luo force-pushed the moe_post_reorder_cuda branch from f932820 to 8419197 on June 3, 2025 09:58
@yuan-luo (Contributor, Author) commented Jun 3, 2025

[root  /home/root/luoyuan.luo/sglang] Tue Jun 03 20:39:29
$python ./sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
ep-moe-post-reorder-performance:
   batch_size  CUDA Kernel  Triton Kernel
0        64.0    33.856001      48.223998
1       128.0    43.839999      57.663999
2       256.0    64.768001      78.240000
3       512.0   112.992004     127.360001
4       640.0   137.055993     146.880001
5       768.0   161.024004     167.968005
6      1024.0   207.936004     207.791999
7      2048.0   393.216014     396.991998
8      4096.0   762.544036     773.952007

@yuan-luo changed the title from "WIP: [EP] Add cuda kernel for moe_ep_post_reorder" to "[EP] Add cuda kernel for moe_ep_post_reorder" on Jun 3, 2025
for (uint32_t i = 0; i < vec_size; ++i)
  acc[i] = 0.f;

bool computed = false;
A Collaborator commented on this hunk:

As Gemini said, we do not need the computed flag.

@yuan-luo (Contributor, Author) replied:

Removed the computed flag.
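
For reference, the flag-free form being discussed looks roughly like the fragment below. It is a sketch that reuses the kernel's apparent variable names (token, offset, src2dst, and the pointer names are assumptions here) and the flashinfer vec_t type mentioned in the review; it is not the merged code.

```cuda
// Because acc starts at zero, every in-range expert adds into it and the
// final store happens unconditionally, so the computed flag and the extra
// copy loop are unnecessary; out-of-range tokens simply write zeros.
flashinfer::vec_t<float, vec_size> acc;
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) acc[i] = 0.f;   // zero-init once

for (int k = 0; k < topk; ++k) {
  const int expert = topk_ids[token * topk + k];
  if (expert < start_expert_id || expert > end_expert_id) continue;  // other EP rank
  flashinfer::vec_t<float, vec_size> src_vec;
  src_vec.load(down_output + (int64_t)src2dst[token * topk + k] * hidden_size + offset);
  const float w = topk_weights[token * topk + k];
#pragma unroll
  for (uint32_t i = 0; i < vec_size; ++i) acc[i] += w * src_vec[i];
}
acc.store(output + (int64_t)token * hidden_size + offset);  // no computed guard needed
```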

@Alcanderian (Collaborator) commented:

Hello @yuan-luo, we have to support fp16/bf16/fp32 for this kernel. Ref: #6858

@yuan-luo (Contributor, Author) commented Jun 4, 2025

> Hello @yuan-luo, we have to support fp16/bf16/fp32 for this kernel. Ref: #6858

Will follow up.
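
One way to cover fp16/bf16/fp32 from the host wrapper is ATen's dispatch macro over a scalar-type-templated kernel. The sketch below uses a placeholder kernel body just to show the dispatch plumbing; it is not the implementation that landed in the PR.

```cuda
// Dtype-dispatch sketch (illustrative names, placeholder kernel body).
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

template <typename scalar_t>
__global__ void post_reorder_kernel_t(const scalar_t* src, scalar_t* dst, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = src[i];  // placeholder; the real kernel does the weighted reorder
}

void launch_post_reorder(const torch::Tensor& src, torch::Tensor& dst) {
  const int n = src.numel();
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  // Expands the lambda once per supported dtype: float, half, bfloat16.
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(),
      "ep_moe_post_reorder", [&] {
        post_reorder_kernel_t<scalar_t><<<blocks, threads, 0,
            at::cuda::getCurrentCUDAStream()>>>(
            src.data_ptr<scalar_t>(), dst.data_ptr<scalar_t>(), n);
      });
}
```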

@yuan-luo force-pushed the moe_post_reorder_cuda branch 3 times, most recently from d6fca3f to 053caf1 on June 4, 2025 06:56
@yuan-luo force-pushed the moe_post_reorder_cuda branch from 053caf1 to c29b5e4 on June 4, 2025 13:39
@yuan-luo (Contributor, Author) commented Jun 4, 2025

Performance improved, but the precision test case did not pass yet.

$python bench_moe_ep_post_reorder.py
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
ep-moe-post-reorder-performance:
   batch_size  CUDA Kernel  Triton Kernel
0        64.0    32.736000      47.680002
1       128.0    42.431999      57.087999
2       256.0    60.927998      75.167999
3       512.0    99.391997     112.640001
4       640.0   119.263999     133.616000
5       768.0   141.343996     156.496003
6      1024.0   184.159994     196.319997
7      2048.0   352.064013     374.944001
8      4096.0   683.264017     726.624012

@yuan-luo force-pushed the moe_post_reorder_cuda branch 2 times, most recently from ee01689 to ad602b7 on June 4, 2025 14:20
@yuan-luo force-pushed the moe_post_reorder_cuda branch from ad602b7 to 1b7a4ab on June 4, 2025 14:48
@yuan-luo (Contributor, Author) commented Jun 4, 2025

Test passed.

$python test_ep_moe_post_reorder_kernel.py
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
============================================================================================= test session starts =============================================================================================
platform linux -- Python 3.10.13, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/root/luoyuan.luo/cuda-kernel-opt/pre_reorder
plugins: anyio-4.8.0, typeguard-4.3.0
collected 54 items                                                                                                                                                                                            

test_ep_moe_post_reorder_kernel.py ......................................................                                                                                                               [100%]
============================================================================================== warnings summary ===============================================================================================
../../../../../opt/conda/lib/python3.10/site-packages/_pytest/config/__init__.py:1277
  /opt/conda/lib/python3.10/site-packages/_pytest/config/__init__.py:1277: PytestAssertRewriteWarning: Module already imported so cannot be rewritten: anyio
    self._mark_plugins_for_rewrite(hook)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================== 54 passed, 1 warning in 0.72s ========================================================================================

@yuan-luo (Contributor, Author) commented Jun 4, 2025

$python bench_moe_ep_post_reorder.py
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
ep-moe-post-reorder-performance:
   batch_size  CUDA Kernel  Triton Kernel
0        64.0    33.728000      48.416000
1       128.0    42.431999      56.928001
2       256.0    62.944002      76.063998
3       512.0   101.968005     113.760002
4       640.0   121.023998     134.368002
5       768.0   143.040001     157.215998
6      1024.0   185.696006     197.039992
7      2048.0   354.016006     375.488013
8      4096.0   686.303973     728.511989

@Alcanderian added the "ready-to-merge" label (The PR is ready to merge after the CI is green.) on Jun 4, 2025
@yuan-luo (Contributor, Author) commented Jun 5, 2025

Updated performance results after revising the benchmark's timing calculation.

$python ./bench_moe_ep_post_reorder.py 
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
ep-moe-post-reorder-performance:
   batch_size  CUDA Kernel  Triton Kernel
0        64.0    10.976000      33.760000
1       128.0    11.936000      34.336001
2       256.0    13.344000      35.776000
3       512.0    20.479999      38.368002
4       640.0    24.256000      40.704001
5       768.0    27.807999      41.568000
6      1024.0    34.688000      43.744002
7      2048.0    58.752000      79.296000
8      4096.0   109.279998     149.504006
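
The numbers above come from the repository's Python benchmark script. For readers curious what a device-side timing loop looks like, here is a generic CUDA-event sketch; it is unrelated to the script's actual implementation and is shown only as background on kernel timing.

```cuda
// Generic CUDA-event timing harness (sketch, not the benchmark used above).
#include <cuda_runtime.h>

float time_kernel_ms(void (*launch)(), int warmup = 10, int iters = 100) {
  for (int i = 0; i < warmup; ++i) launch();   // warm up before timing
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i) launch();
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);                   // wait for all launches to finish
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms / iters;                            // average time per launch
}
```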

@zhyncs merged commit 43baba6 into sgl-project:main on Jun 5, 2025
74 of 80 checks passed
jianan-gu pushed a commit to jianan-gu/sglang that referenced this pull request Jun 12, 2025
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
xwu-intel pushed a commit to xwu-intel/sglang that referenced this pull request Jun 17, 2025
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
walker-ai pushed a commit to walker-ai/sglang that referenced this pull request Jul 8, 2025
Merge branch 'sgl_20250610_sync_tag047 of git@code.alipay.com:Theta/SGLang.git into main

https://code.alipay.com/Theta/SGLang/pull_requests/52


Reviewed-by: 剑川 <jianchuan.gys@antgroup.com>


* [Bugfix] Fix slice operation when chunk size mismatch (sgl-project#6697)
* [Bugfix] Fix ChatCompletion endpoint of mini_lb when stream is set (sgl-project#6703)
* [CI] Fix setup of disaggregation with different tp (sgl-project#6706)
* [PD] Remove Unnecessary Exception Handling for FastQueue.get() (sgl-project#6712)
* Fuse routed_scaling_factor in DeepSeek (sgl-project#6710)
* Overlap two kernels in DeepSeek with communication (sgl-project#6711)
* Minor refactor two-batch overlap (sgl-project#6682)
* Speed up when having padding tokens two-batch overlap (sgl-project#6668)
* [Feature] Support Flashinfer fp8 blockwise GEMM kernel on Blackwell (sgl-project#6479)
* Fix LoRA bench (sgl-project#6719)
* temp
* Fix PP for Qwen3 MoE (sgl-project#6709)
* [feat] triton kernel for get_last_loc (sgl-project#6676)
* [fix] more mem for draft_extend cuda_graph (sgl-project#6726)
* [PD] bug fix:  Update status if nixl receiver send a a dummy req. (sgl-project#6720)
* Tune memory arguments on B200 (sgl-project#6718)
* Add DeepSeek-R1-0528 function call chat template (sgl-project#6725)
* refactor(tool call): Fix BaseFormatDetector tool_index issue and refactor `parse_streaming_increment` (sgl-project#6715)
* Add draft extend CUDA graph for Triton backend (sgl-project#6705)
* refactor apply_w8a8_block_fp8_linear in fp (sgl-project#6545)
* [PD] Support completion endpoint (sgl-project#6729)
* PD Rust LB (PO2) (sgl-project#6437)
* Super tiny enable sole usage of expert distribution metrics and update doc (sgl-project#6680)
* Support picking variants of EPLB algorithms (sgl-project#6728)
* Support tuning DeepEP configs (sgl-project#6742)
* [test] add ut and bm for get_last_loc (sgl-project#6746)
* Fix mem_fraction_static for AMD CI (sgl-project#6748)
* [fix][RL] Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (sgl-project#6265)
* Improve EPLB logical to physical dispatch map (sgl-project#6727)
* Update DeepSeek-R1-0528 function call chat template (sgl-project#6765)
* [PD] Optimize time out logic and add env var doc for mooncake (sgl-project#6761)
* Fix aiohttp 'Chunk too big' in bench_serving (sgl-project#6737)
* Support sliding window in triton backend (sgl-project#6509)
* Fix shared experts fusion error (sgl-project#6289)
* Fix one bug in the grouped-gemm triton kernel (sgl-project#6772)
* update llama4 chat template and pythonic parser (sgl-project#6679)
* feat(tool call): Enhance Llama32Detector for improved JSON parsing in non-stream (sgl-project#6784)
* Support token-level quantization for EP MoE (sgl-project#6782)
* Temporarily lower mmlu threshold for triton sliding window backend (sgl-project#6785)
* ci: relax test_function_call_required (sgl-project#6786)
* Add intel_amx backend for Radix Attention for CPU (sgl-project#6408)
* Fix incorrect LoRA weight loading for fused gate_up_proj (sgl-project#6734)
* fix(PD-disaggregation): Can not get local ip (sgl-project#6792)
* [FIX] mmmu bench serving result display error (sgl-project#6525) (sgl-project#6791)
* Bump torch to 2.7.0 (sgl-project#6788)
* chore: bump sgl-kernel v0.1.5 (sgl-project#6794)
* Improve profiler and integrate profiler in bench_one_batch_server (sgl-project#6787)
* chore: upgrade sgl-kernel v0.1.5 (sgl-project#6795)
* [Minor] Always append newline after image token when parsing chat message (sgl-project#6797)
* Update CI tests for Llama4 models (sgl-project#6421)
* [Feat] Enable PDL automatically on Hopper architecture (sgl-project#5981)
* chore: update blackwell docker (sgl-project#6800)
* misc: cache is_hopper_arch (sgl-project#6799)
* Remove contiguous before Flashinfer groupwise fp8 gemm (sgl-project#6804)
* Correctly abort the failed grammar requests & Improve the handling of abort (sgl-project#6803)
* [EP] Add cuda kernel for moe_ep_pre_reorder (sgl-project#6699)
* Add draft extend CUDA graph for flashinfer backend  (sgl-project#6805)
* Refactor CustomOp to avoid confusing bugs (sgl-project#5382)
* Tiny log prefill time (sgl-project#6780)
* Tiny fix EPLB assertion about rebalancing period and recorder window size (sgl-project#6813)
* Add simple utility to dump tensors for debugging (sgl-project#6815)
* Fix profiles do not have consistent names (sgl-project#6811)
* Speed up rebalancing when using non-static dispatch algorithms (sgl-project#6812)
* [1/2] Add Kernel support for Cutlass based Fused FP4 MoE (sgl-project#6093)
* [Router] Fix k8s Service Discovery (sgl-project#6766)
* Add CPU optimized kernels for topk and rope fusions  (sgl-project#6456)
* fix new_page_count_next_decode (sgl-project#6671)
* Fix wrong weight reference in dynamic EPLB (sgl-project#6818)
* Minor add metrics to expert location updater (sgl-project#6816)
* [Refactor] Rename `n_share_experts_fusion` as `num_fused_shared_experts` (sgl-project#6735)
* [FEAT] Add transformers backend support  (sgl-project#5929)
* [fix] recover auto-dispatch for rmsnorm and rope (sgl-project#6745)
* fix ep_moe_reorder kernel bugs (sgl-project#6858)
* [Refactor] Multimodal data processing for VLM (sgl-project#6659)
* Decoder-only Scoring API (sgl-project#6460)
* feat: add dp-rank to KV events (sgl-project#6852)
* Set `num_fused_shared_experts` as `num_shared_experts` when shared_experts fusion is not disabled (sgl-project#6736)
* Fix one missing arg in DeepEP (sgl-project#6878)
* Support LoRA in TestOpenAIVisionServer and fix fused kv_proj loading bug. (sgl-project#6861)
* support 1 shot allreduce  in 1-node and 2-node using mscclpp (sgl-project#6277)
* Fix Qwen3MoE missing token padding optimization (sgl-project#6820)
* Tiny update error hints (sgl-project#6846)
* Support layerwise rebalancing experts (sgl-project#6851)
* Tiny allow profiler API to auto create directory (sgl-project#6865)
* Support Blackwell DeepEP docker images (sgl-project#6868)
* [EP] Add cuda kernel for moe_ep_post_reorder (sgl-project#6837)
* [theta]merge 0605
* oai: fix openAI client error with single request via batch api (sgl-project#6170)
* [PD] Fix potential perf spike caused by tracker gc and optimize doc (sgl-project#6764)
* Use deepgemm instead of triton for fused_qkv_a_proj_with_mqa (sgl-project#6890)
* [CUTLASS-FP4-MOE]  Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (sgl-project#6887)
* bugfix(OAI): Fix image_data processing for jinja chat templates (sgl-project#6877)
* [CPU] enable CI for PRs, add Dockerfile and auto build task (sgl-project#6458)
* AITER backend extension and workload optimizations (sgl-project#6838)
* [theta]merge
* [theta]merge
* [Feature] Support Flashinfer fmha on Blackwell (sgl-project#6930)
* Fix a bug in abort & Improve docstrings for abort (sgl-project#6931)
* Tiny support customize DeepEP max dispatch tokens per rank (sgl-project#6934)
* Sync the changes on cuda graph runners (sgl-project#6932)
* [PD] Optimize transfer queue forward logic for dummy rank (sgl-project#6922)
* [Refactor] image data process in bench_serving (sgl-project#6879)
* [fix] logical_to_all_physical_map index 256 is out of bounds in EP parallel. (sgl-project#6767)
* Add triton fused moe kernel config for E=257 on B200 (sgl-project#6939)
* [sgl-kernel] update deepgemm (sgl-project#6942)
* chore: bump sgl-kernel v0.1.6 (sgl-project#6943)
* Minor compile fused topk (sgl-project#6944)
* [Bugfix] pipeline parallelism and Eagle Qwen2 (sgl-project#6910)
* Tiny re-introduce profile id logging (sgl-project#6912)
* Add triton version as a fused_moe_triton config search key to avoid performace decrease in different Triton version (sgl-project#5955)
* reduce torch.zeros overhead in moe align block size kernel (sgl-project#6369)
* chore: upgrade sgl-kernel v0.1.6 (sgl-project#6945)
* add fbgemm moe grouped gemm kernel benchmark (sgl-project#6924)
* [Docker] Add docker file for SGL Router (sgl-project#6915)
* Disabling mixed chunked prefill when eagle is enabled (sgl-project#6874)
* Add canary for EPLB rebalancing (sgl-project#6895)
* Refactor global_server_args_dict (sgl-project#6866)
* Fuse routed scaling factor in topk_reduce kernel (sgl-project#6220)
* Update server timeout time in AMD CI. (sgl-project#6953)
* [misc] add is_cpu() (sgl-project#6950)
* Add H20 fused MoE kernel tuning configs for DeepSeek-R1/V3 (sgl-project#6885)
* Add a CUDA kernel for fusing mapping and weighted sum for MoE. (sgl-project#6916)
* chore: bump sgl-kernel v0.1.6.post1 (sgl-project#6955)
* chore: upgrade sgl-kernel v0.1.6.post1 (sgl-project#6957)
* [DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (sgl-project#6853)
* Revert "Fuse routed scaling factor in topk_reduce kernel (sgl-project#6220)" (sgl-project#6968)
* [AMD] Add more tests to per-commit-amd (sgl-project#6926)
* chore: bump sgl-kernel v0.1.7 (sgl-project#6963)
* Slightly improve the sampler to skip unnecessary steps (sgl-project#6956)
* rebase h20 fused_moe config (sgl-project#6966)
* Fix CI and triton moe Configs (sgl-project#6974)
* Remove unnecessary kernels of num_token_non_padded (sgl-project#6965)
* Extend cuda graph capture bs for B200 (sgl-project#6937)
* Fuse routed scaling factor in deepseek (sgl-project#6970)
* Sync cuda graph runners (sgl-project#6976)
* Fix draft extend ut stability with flush cache (sgl-project#6979)
* Fix triton sliding window test case (sgl-project#6981)
* Fix expert distribution dumping causes OOM (sgl-project#6967)
* Minor remove one kernel for DeepSeek (sgl-project#6977)
* [perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (sgl-project#6929)
* Enable more unit tests for AMD CI. (sgl-project#6983)
* Use torch.compile to fuse flash attention decode metadata preparation (sgl-project#6973)
* Eliminate stream sync to speed up LoRA batch init  (sgl-project#6960)
* support qwen3 emebedding (sgl-project#6990)
* Fix torch profiler bugs for bench_offline_throughput.py (sgl-project#6557)
* chore: upgrade flashinfer v0.2.6.post1 jit (sgl-project#6958)
* cleanup tmp dir (sgl-project#7007)
* chore: update pr test xeon (sgl-project#7008)
* Fix cutlass MLA gets almost zero accuracy (sgl-project#6998)
* Update amd nightly models CI. (sgl-project#6992)
* feat: add direct routing strategy to DP worker (sgl-project#6884)
* Fallback to lower triton version for unfound fused moe configs (sgl-project#7013)
* Fix torchvision version for Blackwell (sgl-project#7015)
* Simplify prepare_extend_after_decode (sgl-project#6987)
* Migrate to assertEqual (sgl-project#6741)
* Fix torch version in blackwell dockerfile (sgl-project#7017)
* chore: update pr test xeon (sgl-project#7018)
* Update default settings for blackwell (sgl-project#7023)
* Support both approximate and exact expert distribution collection (sgl-project#6964)
* Add decode req pool (sgl-project#6980)
* [theta]merge 0610
* [theta]merge 0610
* [CI] Add CI workflow for sgl-router docker build (sgl-project#7027)
* Fix fused_moe triton configs (sgl-project#7029)
* CPU: map changes from developing branch in sgl-kernel (sgl-project#6833)
* chore: bump v0.4.7 (sgl-project#7038)
* Update README.md (sgl-project#7040)