

@yuan-luo yuan-luo commented Jun 6, 2025

Motivation

Rewrites the moe_ep_silu_and_mul Triton kernel in CUDA, yielding roughly a 10% performance improvement.
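For readers unfamiliar with the fused op, what it computes can be sketched in plain Python (a reference-only sketch; the function and variable names here are illustrative, not the kernel's actual interface):

```python
import math

def silu(x: float) -> float:
    # SiLU (a.k.a. swish) activation: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def silu_and_mul(row, hidden_size):
    # Each input row holds a gate half and an up half, each hidden_size // 2
    # wide; the fused op writes silu(gate) * up into a half-width output row.
    half = hidden_size // 2
    gate, up = row[:half], row[half:]
    return [silu(g) * u for g, u in zip(gate, up)]
```

For example, `silu_and_mul([0.0, 1.0, 2.0, 3.0], 4)` multiplies `silu(0.0)` with `2.0` and `silu(1.0)` with `3.0`. Fusing the activation and the multiply into one kernel avoids materializing the intermediate `silu(gate)` tensor.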

Modifications

Checklist


@gemini-code-assist gemini-code-assist bot left a comment


Hello @yuan-luo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

This PR, marked as Work In Progress (WIP), introduces a new CUDA kernel for the SiLU-activation-followed-by-multiplication operation used within Mixture-of-Experts (MoE) layers, with the goal of providing a highly optimized GPU implementation of this common operation.

The changes involve adding the core CUDA kernel implementation, registering it with the PyTorch extension mechanism, creating C++ and Python wrappers to make it callable from Python, and adding a unit test to verify its correctness by comparing its output against an existing Triton kernel implementation.

Highlights

  • New CUDA Kernel: Adds a new CUDA kernel (ep_moe_act_and_mul_cuda_kernel) to perform the SiLU activation and multiplication operation efficiently on the GPU for MoE layers.
  • PyTorch Integration: Registers the new CUDA kernel with the PyTorch extension (sgl_kernel) and provides C++ and Python wrappers (ep_moe_silu_and_mul) to expose it to the Python API.
  • Unit Testing: Includes a new unit test (test_ep_moe_silu_and_mul_vs_triton) that compares the output of the new CUDA kernel against an existing Triton implementation for various configurations (token counts, hidden sizes, data types) to ensure correctness.

Changelog

  • sgl-kernel/csrc/common_extension.cc
    • Registers the new ep_moe_silu_and_mul kernel with the PyTorch sgl_kernel library fragment.
  • sgl-kernel/csrc/moe/ep_moe_silu_and_mul.cu
    • Adds a new CUDA source file containing the ep_moe_act_and_mul_cuda_kernel which implements the SiLU and multiplication logic for MoE.
    • Includes a C++ wrapper function ep_moe_silu_and_mul to launch the CUDA kernel.
  • sgl-kernel/include/sgl_kernel_ops.h
    • Adds the C++ function declaration for ep_moe_silu_and_mul.
  • sgl-kernel/include/utils.h
    • Adds a device function silu for the SiLU activation.
  • sgl-kernel/python/sgl_kernel/init.py
    • Imports the new ep_moe_silu_and_mul function into the Python package's public API.
  • sgl-kernel/python/sgl_kernel/moe.py
    • Adds a Python wrapper function ep_moe_silu_and_mul that calls the registered C++ kernel via torch.ops.sgl_kernel.
  • sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py
    • Adds a new test file with a Pytest test test_ep_moe_silu_and_mul_vs_triton.
    • This test compares the output of the new CUDA kernel against a Triton kernel implementation for correctness across different parameters and data types.
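The shape of such a correctness test can be sketched in a simplified, self-contained form (the real test drives the CUDA and Triton kernels on GPU tensors through torch; `candidate` below is a hypothetical stand-in for the kernel under test):

```python
import itertools
import math

def reference(row, half):
    # straightforward silu(gate) * up, used as ground truth
    return [(g / (1.0 + math.exp(-g))) * u
            for g, u in zip(row[:half], row[half:])]

def candidate(row, half):
    # stand-in for the optimized kernel; the PR's test instead invokes the
    # new CUDA op and the existing Triton kernel on real tensors
    return reference(row, half)

def compare(num_tokens, hidden_size, tol=1e-5):
    half = hidden_size // 2
    rows = [[(t * hidden_size + i) * 0.01 - 1.0 for i in range(hidden_size)]
            for t in range(num_tokens)]
    for row in rows:
        got, want = candidate(row, half), reference(row, half)
        assert all(abs(a - b) <= tol for a, b in zip(got, want))

# sweep a few (token count, hidden size) combinations, mirroring the
# parametrized test in the PR
for tokens, hidden in itertools.product([1, 8], [16, 64]):
    compare(tokens, hidden)
```

In the real test the tolerance also has to account for bf16/fp16 rounding differences between the two kernels, not just algorithmic equivalence.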


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new CUDA kernel ep_moe_silu_and_mul for Mixture-of-Experts, along with its PyTorch bindings and Python tests. Adding a custom CUDA kernel for this operation is a sound performance optimization.

However, the CUDA kernel implementation (ep_moe_silu_and_mul.cu) has several critical issues that will prevent compilation or lead to incorrect behavior. These include incorrect variable declarations, undefined variables, missing kernel arguments, and incorrect logic for loop bounds. The host wrapper function also has issues with kernel launch configuration and argument passing.

The Python tests comparing against a Triton kernel are a good step towards ensuring correctness once the CUDA kernel issues are resolved.

Given the WIP (Work In Progress) status in the PR title and the nature of the identified issues, significant revisions are needed before this can be merged.

Summary of Findings

  • CUDA Kernel Compilation Errors: The CUDA kernel ep_moe_act_and_mul_cuda_kernel and its host wrapper ep_moe_silu_and_mul have several issues that will prevent compilation:
  1. const float scale modified (critical).
  2. Undefined total_blocks in grid configuration (critical).
  3. Missing hidden_size argument in kernel launch (critical).
  4. Incorrect type cast for scales.data_ptr() in kernel launch (critical).
  5. Undefined tid in kernel loop (critical).
  6. Undefined vec_t type in kernel (critical).
  • CUDA Kernel Logic Error: The vec_elements calculation for loop bounds in the CUDA kernel is based on hidden_size but should be based on half_hidden_size (high).
  • Unused Variable in CUDA Kernel: An unused variable vec is declared in the CUDA kernel (medium).
  • Non-English Comment in Test Code: A comment in sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py (line 68) is in Chinese. While not a functional issue, English comments are generally preferred for broader accessibility in open-source projects. (Severity: low, not commented directly due to settings).
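The half-width loop-bound finding can be illustrated with a little arithmetic (a plain-Python sketch; `vec_size` and the function name are illustrative, not taken from the kernel source):

```python
def vec_loop_bound(hidden_size: int, vec_size: int) -> int:
    # The gate and up halves are fused into one output row of width
    # hidden_size // 2, so a vectorized loop that loads/stores vec_size
    # elements per iteration must cover the half width, not the full width.
    half_hidden_size = hidden_size // 2
    return half_hidden_size // vec_size

# With hidden_size = 4096 and 8-element vectors, each output row needs
# 4096 // 2 // 8 = 256 vector iterations; deriving the bound from the full
# hidden_size would double that and run past the end of the output row.
```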

Merge Readiness

This pull request is currently a Work In Progress (WIP) and has several critical issues in the CUDA kernel implementation that need to be addressed. These issues will prevent compilation and/or lead to incorrect functionality. Therefore, the PR is not ready for merging.

I recommend addressing all the critical and high severity issues identified in the review comments. Once these are resolved, further review, especially focusing on performance and edge cases, would be beneficial. As a reviewer, I am not authorized to approve pull requests; please ensure other maintainers review and approve the changes before merging.

@yuan-luo yuan-luo force-pushed the moe_silu_and_mul_cuda branch 2 times, most recently from 5694052 to 084e41f Compare June 6, 2025 09:52
@yuan-luo yuan-luo force-pushed the moe_silu_and_mul_cuda branch 4 times, most recently from d071ba5 to 32b09cc Compare June 8, 2025 08:46
@yuan-luo

yuan-luo commented Jun 8, 2025

Benchmark results: the CUDA kernel is roughly 10% faster than the Triton kernel.

$python sgl-kernel/benchmark/bench_moe_silu_and_mul.py
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
ep-moe-silu-and-mul-performance:
   batch_size  CUDA Kernel  Triton Kernel
0        64.0     6.656000          8.704
1       128.0     6.880000          8.896
2       256.0     7.488000          9.568
3       512.0     9.280000         10.656
4       640.0     9.952000         10.752
5       768.0    10.880000         11.168
6      1024.0    12.192000         12.288
7      2048.0    17.503999         18.336
8      4096.0    28.896000         30.112

@yuan-luo yuan-luo force-pushed the moe_silu_and_mul_cuda branch from 32b09cc to ce154f7 Compare June 8, 2025 12:12
@yuan-luo

yuan-luo commented Jun 8, 2025

[root  /home/root/luoyuan.luo/sglang] Sun Jun 08 20:09:15
$python ./sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
============================================================================================= test session starts =============================================================================================
platform linux -- Python 3.10.13, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/root/luoyuan.luo/sglang/sgl-kernel
configfile: pyproject.toml
plugins: anyio-4.8.0, typeguard-4.3.0
collected 27 items                                                                                                                                                                                            

sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py ...........................                                                                                                                         [100%]

============================================================================================== warnings summary ===============================================================================================
../../../../opt/conda/lib/python3.10/site-packages/_pytest/config/__init__.py:1277
  /opt/conda/lib/python3.10/site-packages/_pytest/config/__init__.py:1277: PytestAssertRewriteWarning: Module already imported so cannot be rewritten: anyio
    self._mark_plugins_for_rewrite(hook)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================== 27 passed, 1 warning in 0.64s ========================================================================================

@yuan-luo yuan-luo changed the title WIP: Add cuda kernel for moe_ep_silu_and_mul Add cuda kernel for moe_ep_silu_and_mul Jun 8, 2025
@yuan-luo yuan-luo force-pushed the moe_silu_and_mul_cuda branch from 0048d0f to 6b88282 Compare June 9, 2025 05:10
def benchmark(batch_size, provider):
    dtype = torch.bfloat16
    device = torch.device("cuda")
    hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512
Collaborator


Can these variables also take different value combinations in the configs above?

Contributor Author


$python ./sgl-kernel/benchmark/bench_moe_silu_and_mul.py
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
ep-moe-silu-and-mul-performance:
     batch_size  hidden_size  block_size  CUDA Kernel  Triton Kernel
0          64.0       1024.0       128.0     6.656000       6.528000
1          64.0       1024.0       256.0     6.752000       6.784000
2          64.0       1024.0       512.0     6.784000       6.528000
3          64.0       2048.0       128.0     6.912000       7.552000
4          64.0       2048.0       256.0     6.848000       7.296000
5          64.0       2048.0       512.0     6.592000       7.296000
6          64.0       4096.0       128.0     7.040000       8.960000
7          64.0       4096.0       256.0     7.008000       8.672000
8          64.0       4096.0       512.0     6.752000       8.672000
9          64.0       8192.0       128.0     7.232000      11.360000
10         64.0       8192.0       256.0     6.976000      11.616000
11         64.0       8192.0       512.0     7.232000      11.328000
12        128.0       1024.0       128.0     6.592000       6.624000
13        128.0       1024.0       256.0     6.592000       6.656000
14        128.0       1024.0       512.0     6.848000       6.880000
15        128.0       2048.0       128.0     7.008000       7.744000
16        128.0       2048.0       256.0     7.008000       7.776000
17        128.0       2048.0       512.0     7.008000       7.488000
18        128.0       4096.0       128.0     7.232000       9.248000
19        128.0       4096.0       256.0     6.976000       9.248000
20        128.0       4096.0       512.0     7.200000       9.024000
21        128.0       8192.0       128.0     7.888000      12.064000
22        128.0       8192.0       256.0     7.616000      12.288000
23        128.0       8192.0       512.0     7.904000      12.064000
24        256.0       1024.0       128.0     6.816000       6.848000
25        256.0       1024.0       256.0     7.040000       6.848000
26        256.0       1024.0       512.0     7.072000       7.104000
27        256.0       2048.0       128.0     7.008000       7.872000
28        256.0       2048.0       256.0     6.976000       7.872000
29        256.0       2048.0       512.0     7.232000       8.128000
30        256.0       4096.0       128.0     7.680000       9.664000
31        256.0       4096.0       256.0     7.648000       9.696000
32        256.0       4096.0       512.0     7.904000       9.888000
33        256.0       8192.0       128.0     8.736000      13.408000
34        256.0       8192.0       256.0     8.960000      13.408000
35        256.0       8192.0       512.0     8.928000      13.376000
36        512.0       1024.0       128.0     8.512000       7.136000
37        512.0       1024.0       256.0     8.544000       7.424000
38        512.0       1024.0       512.0     8.544000       7.168000
39        512.0       2048.0       128.0     8.896000       8.704000
40        512.0       2048.0       256.0     8.896000       8.704000
41        512.0       2048.0       512.0     8.928000       8.448000
42        512.0       4096.0       128.0     9.696000      10.560000
43        512.0       4096.0       256.0     9.728000      10.592000
44        512.0       4096.0       512.0     9.472000      10.592000
45        512.0       8192.0       128.0    11.360000      14.592000
46        512.0       8192.0       256.0    11.360000      14.784000
47        512.0       8192.0       512.0    11.360000      14.560000
48        640.0       1024.0       128.0     9.024000       7.424000
49        640.0       1024.0       256.0     8.800000       7.648000
50        640.0       1024.0       512.0     9.024000       7.392000
51        640.0       2048.0       128.0     9.408000       8.736000
52        640.0       2048.0       256.0     9.472000       8.736000
53        640.0       2048.0       512.0     9.184000       8.960000
54        640.0       4096.0       128.0    10.432000      10.880000
55        640.0       4096.0       256.0    10.432000      10.912000
56        640.0       4096.0       512.0    10.144000      11.104000
57        640.0       8192.0       128.0    12.224000      15.424000
58        640.0       8192.0       256.0    12.480000      15.392000
59        640.0       8192.0       512.0    12.448000      15.616000
60        768.0       1024.0       128.0     9.472000       7.872000
61        768.0       1024.0       256.0     9.472000       7.872000
62        768.0       1024.0       512.0     9.728000       7.616000
63        768.0       2048.0       128.0    10.784000       9.408000
64        768.0       2048.0       256.0    10.528000       9.184000
65        768.0       2048.0       512.0    10.496000       9.184000
66        768.0       4096.0       128.0    11.040000      11.584000
67        768.0       4096.0       256.0    11.072000      11.584000
68        768.0       4096.0       512.0    11.072000      11.584000
69        768.0       8192.0       128.0    13.760000      16.192000
70        768.0       8192.0       256.0    13.760000      16.128000
71        768.0       8192.0       512.0    13.760000      16.416000
72       1024.0       1024.0       128.0    10.688000       8.000000
73       1024.0       1024.0       256.0    10.688000       8.000000
74       1024.0       1024.0       512.0    10.944000       8.000000
75       1024.0       2048.0       128.0    11.360000       9.632000
76       1024.0       2048.0       256.0    11.632000       9.824000
77       1024.0       2048.0       512.0    11.392000       9.632000
78       1024.0       4096.0       128.0    12.864000      12.544000
79       1024.0       4096.0       256.0    12.896000      12.768000
80       1024.0       4096.0       512.0    12.896000      12.512000
81       1024.0       8192.0       128.0    15.904000      18.912001
82       1024.0       8192.0       256.0    15.840000      18.912001
83       1024.0       8192.0       512.0    15.807999      18.912001
84       2048.0       1024.0       128.0    14.944000      10.208000
85       2048.0       1024.0       256.0    15.232000       9.920000
86       2048.0       1024.0       512.0    14.976000      10.240000
87       2048.0       2048.0       128.0    16.160000      13.088000
88       2048.0       2048.0       256.0    15.936000      13.120000
89       2048.0       2048.0       512.0    16.192000      13.312000
90       2048.0       4096.0       128.0    18.495999      18.688001
91       2048.0       4096.0       256.0    18.495999      18.464001
92       2048.0       4096.0       512.0    18.495999      18.464001
93       2048.0       8192.0       128.0    25.792001      30.592000
94       2048.0       8192.0       256.0    25.823999      30.880000
95       2048.0       8192.0       512.0    25.536001      30.880000
96       4096.0       1024.0       128.0    22.879999      13.632000
97       4096.0       1024.0       256.0    22.848001      13.648000
98       4096.0       1024.0       512.0    22.848001      13.440000
99       4096.0       2048.0       128.0    24.992000      18.816000
100      4096.0       2048.0       256.0    25.216000      18.592000
101      4096.0       2048.0       512.0    24.992000      18.816000
102      4096.0       4096.0       128.0    30.751999      30.015999
103      4096.0       4096.0       256.0    30.432001      30.080000
104      4096.0       4096.0       512.0    30.751999      30.272000
105      4096.0       8192.0       128.0    43.040000      50.783999
106      4096.0       8192.0       256.0    43.008000      50.816000
107      4096.0       8192.0       512.0    43.040000      50.816000

out_vec.store(dst_ptr + idx * vec_size);
}

#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
Collaborator


Is this necessary?

Contributor Author


Removed.

@yuan-luo yuan-luo force-pushed the moe_silu_and_mul_cuda branch from 551289c to 7682138 Compare June 9, 2025 08:10
@yuan-luo yuan-luo force-pushed the moe_silu_and_mul_cuda branch from 7682138 to 88c834e Compare June 9, 2025 08:19
@yuan-luo yuan-luo changed the title Add cuda kernel for moe_ep_silu_and_mul [sgl-kernel] Add cuda kernel for moe_ep_silu_and_mul Jun 9, 2025

@BBuf BBuf left a comment


LGTM.

@yuan-luo yuan-luo force-pushed the moe_silu_and_mul_cuda branch from 88c834e to de3a17e Compare June 9, 2025 15:38
@yuan-luo

yuan-luo commented Jun 9, 2025

[root  /home/root/luoyuan.luo/sglang] Mon Jun 09 23:35:21
$python ./sgl-kernel/benchmark/bench_moe_silu_and_mul.py
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
ep-moe-silu-and-mul-performance:
     batch_size  hidden_size  block_size  CUDA Kernel  Triton Kernel
0          64.0       1024.0       128.0     6.448000       6.592000
1          64.0       1024.0       256.0     6.560000       6.656000
2          64.0       1024.0       512.0     6.400000       6.432000
3          64.0       2048.0       128.0     6.720000       7.392000
4          64.0       2048.0       256.0     6.496000       7.360000
5          64.0       2048.0       512.0     6.496000       7.200000
6          64.0       4096.0       128.0     6.624000       8.768000
7          64.0       4096.0       256.0     6.624000       8.736000
8          64.0       4096.0       512.0     6.848000       8.512000
9          64.0       8192.0       128.0     6.848000      11.424000
10         64.0       8192.0       256.0     6.848000      11.232000
11         64.0       8192.0       512.0     6.848000      11.424000
12        128.0       1024.0       128.0     6.688000       6.720000
13        128.0       1024.0       256.0     6.464000       6.528000
14        128.0       1024.0       512.0     6.688000       6.528000
15        128.0       2048.0       128.0     6.592000       7.584000
16        128.0       2048.0       256.0     6.624000       7.584000
17        128.0       2048.0       512.0     6.816000       7.584000
18        128.0       4096.0       128.0     7.040000       9.088000
19        128.0       4096.0       256.0     6.848000       8.896000
20        128.0       4096.0       512.0     6.816000       8.864000
21        128.0       8192.0       128.0     7.680000      11.936000
22        128.0       8192.0       256.0     7.520000      12.128000
23        128.0       8192.0       512.0     7.488000      12.128000
24        256.0       1024.0       128.0     6.656000       6.944000
25        256.0       1024.0       256.0     6.880000       6.688000
26        256.0       1024.0       512.0     6.656000       6.944000
27        256.0       2048.0       128.0     6.848000       7.744000
28        256.0       2048.0       256.0     6.848000       7.936000
29        256.0       2048.0       512.0     7.072000       7.712000
30        256.0       4096.0       128.0     7.552000       9.728000
31        256.0       4096.0       256.0     7.552000       9.728000
32        256.0       4096.0       512.0     7.712000       9.728000
33        256.0       8192.0       128.0     8.576000      13.024000
34        256.0       8192.0       256.0     8.608000      13.216000
35        256.0       8192.0       512.0     8.576000      13.216000
36        512.0       1024.0       128.0     8.352000       7.232000
37        512.0       1024.0       256.0     8.352000       7.232000
38        512.0       1024.0       512.0     8.352000       7.008000
39        512.0       2048.0       128.0     8.736000       8.320000
40        512.0       2048.0       256.0     8.704000       8.512000
41        512.0       2048.0       512.0     8.704000       8.288000
42        512.0       4096.0       128.0     9.344000      10.464000
43        512.0       4096.0       256.0     9.536000      10.400000
44        512.0       4096.0       512.0     9.504000      10.624000
45        512.0       8192.0       128.0    10.976000      14.624000
46        512.0       8192.0       256.0    10.976000      14.624000
47        512.0       8192.0       512.0    10.976000      14.464000
48        640.0       1024.0       128.0     8.672000       7.296000
49        640.0       1024.0       256.0     8.640000       7.456000
50        640.0       1024.0       512.0     8.640000       7.264000
51        640.0       2048.0       128.0     9.248000       8.800000
52        640.0       2048.0       256.0     9.248000       8.576000
53        640.0       2048.0       512.0     9.024000       8.608000
54        640.0       4096.0       128.0    10.240000      10.784000
55        640.0       4096.0       256.0    10.272000      10.752000
56        640.0       4096.0       512.0    10.240000      10.944000
57        640.0       8192.0       128.0    12.064000      15.168000
58        640.0       8192.0       256.0    12.288000      15.200000
59        640.0       8192.0       512.0    12.064000      15.200000
60        768.0       1024.0       128.0     9.568000       7.680000
61        768.0       1024.0       256.0     9.568000       7.712000
62        768.0       1024.0       512.0     9.376000       7.456000
63        768.0       2048.0       128.0     9.984000       9.056000
64        768.0       2048.0       256.0    10.176000       8.896000
65        768.0       2048.0       512.0    10.176000       9.056000
66        768.0       4096.0       128.0    10.912000      11.408000
67        768.0       4096.0       256.0    10.928000      11.424000
68        768.0       4096.0       512.0    11.136000      11.392000
69        768.0       8192.0       128.0    13.376000      16.256001
70        768.0       8192.0       256.0    13.152000      16.031999
71        768.0       8192.0       512.0    13.376000      16.256001
72       1024.0       1024.0       128.0    10.752000       8.064000
73       1024.0       1024.0       256.0    10.752000       7.872000
74       1024.0       1024.0       512.0    10.752000       8.064000
75       1024.0       2048.0       128.0    11.456000       9.696000
76       1024.0       2048.0       256.0    11.456000       9.728000
77       1024.0       2048.0       512.0    11.264000       9.711999
78       1024.0       4096.0       128.0    12.704000      12.384000
79       1024.0       4096.0       256.0    12.704000      12.384000
80       1024.0       4096.0       512.0    12.512000      12.384000
81       1024.0       8192.0       128.0    16.128000      18.528000
82       1024.0       8192.0       256.0    15.840000      18.528000
83       1024.0       8192.0       512.0    16.192000      18.560000
84       2048.0       1024.0       128.0    14.816000      10.048000
85       2048.0       1024.0       256.0    14.848000      10.048000
86       2048.0       1024.0       512.0    14.816000       9.792000
87       2048.0       2048.0       128.0    16.192000      12.928000
88       2048.0       2048.0       256.0    15.904000      12.960000
89       2048.0       2048.0       512.0    15.936000      12.960000
90       2048.0       4096.0       128.0    18.368000      18.528000
91       2048.0       4096.0       256.0    18.592000      18.304000
92       2048.0       4096.0       512.0    18.592000      18.528000
93       2048.0       8192.0       128.0    25.408000      30.719999
94       2048.0       8192.0       256.0    25.664000      30.432001
95       2048.0       8192.0       512.0    25.376000      30.719999
96       4096.0       1024.0       128.0    22.944000      13.280000
97       4096.0       1024.0       256.0    22.944000      13.280000
98       4096.0       1024.0       512.0    22.720000      13.472000
99       4096.0       2048.0       128.0    25.056001      18.464001
100      4096.0       2048.0       256.0    25.056001      18.464001
101      4096.0       2048.0       512.0    25.087999      18.495999
102      4096.0       4096.0       128.0    30.560000      29.888000
103      4096.0       4096.0       256.0    30.336000      30.112000
104      4096.0       4096.0       512.0    30.304000      30.112000
105      4096.0       8192.0       128.0    43.168001      50.976001
106      4096.0       8192.0       256.0    43.184001      50.912000
107      4096.0       8192.0       512.0    43.136001      50.719999

@yuan-luo yuan-luo force-pushed the moe_silu_and_mul_cuda branch from de3a17e to a244149 Compare June 9, 2025 16:08
@Alcanderian

ep-moe-silu-and-mul-performance:
     batch_size  hidden_size  block_size  CUDA Kernel  Triton Kernel
0          64.0       1024.0       128.0     6.496000       6.848000
1          64.0       1024.0       256.0     6.592000       6.624000
2          64.0       1024.0       512.0     6.592000       6.912000
3          64.0       2048.0       128.0     6.656000       7.424000
4          64.0       2048.0       256.0     6.912000       7.680000
5          64.0       2048.0       512.0     6.688000       7.648000
6          64.0       4096.0       128.0     6.784000       8.800000
7          64.0       4096.0       256.0     7.040000       9.024000
8          64.0       4096.0       512.0     7.072000       9.024000
9          64.0       8192.0       128.0     7.296000      11.616000
10         64.0       8192.0       256.0     7.296000      11.648000
11         64.0       8192.0       512.0     7.296000      11.648000
12        128.0       1024.0       128.0     6.656000       6.752000
13        128.0       1024.0       256.0     6.944000       6.688000
14        128.0       1024.0       512.0     6.656000       6.720000
15        128.0       2048.0       128.0     7.040000       7.584000
16        128.0       2048.0       256.0     6.816000       7.616000
17        128.0       2048.0       512.0     6.784000       7.840000
18        128.0       4096.0       128.0     7.296000       9.088000
19        128.0       4096.0       256.0     7.072000       9.088000
20        128.0       4096.0       512.0     7.296000       9.088000
21        128.0       8192.0       128.0     8.128000      11.968000
22        128.0       8192.0       256.0     7.872000      12.288000
23        128.0       8192.0       512.0     7.872000      12.288000
24        256.0       1024.0       128.0     6.784000       7.072000
25        256.0       1024.0       256.0     6.752000       6.848000
26        256.0       1024.0       512.0     7.008000       7.072000
27        256.0       2048.0       128.0     7.040000       8.032000
28        256.0       2048.0       256.0     7.040000       8.032000
29        256.0       2048.0       512.0     7.040000       7.808000
30        256.0       4096.0       128.0     7.840000       9.824000
31        256.0       4096.0       256.0     8.096000       9.856000
32        256.0       4096.0       512.0     7.840000       9.600000
33        256.0       8192.0       128.0     8.896000      12.864000
34        256.0       8192.0       256.0     8.896000      13.088000
35        256.0       8192.0       512.0     8.896000      13.088000
36        512.0       1024.0       128.0     7.040000       7.168000
37        512.0       1024.0       256.0     7.264000       7.392000
38        512.0       1024.0       512.0     7.040000       7.392000
39        512.0       2048.0       128.0     7.872000       8.640000
40        512.0       2048.0       256.0     8.096000       8.384000
41        512.0       2048.0       512.0     7.840000       8.640000
42        512.0       4096.0       128.0     8.928000      10.368000
43        512.0       4096.0       256.0     9.184000      10.336000
44        512.0       4096.0       512.0     8.928000      10.368000
45        512.0       8192.0       128.0    11.296000      14.528000
46        512.0       8192.0       256.0    11.296000      14.048000
47        512.0       8192.0       512.0    11.040000      14.528000
48        640.0       1024.0       128.0     7.264000       7.584000
49        640.0       1024.0       256.0     7.488000       7.328000
50        640.0       1024.0       512.0     7.232000       7.328000
51        640.0       2048.0       128.0     8.384000       8.992000
52        640.0       2048.0       256.0     8.384000       8.736000
53        640.0       2048.0       512.0     8.384000       8.704000
54        640.0       4096.0       128.0     9.728000      11.072000
55        640.0       4096.0       256.0     9.472000      11.072000
56        640.0       4096.0       512.0     9.472000      10.816000
57        640.0       8192.0       128.0    12.640000      14.944000
58        640.0       8192.0       256.0    12.640000      15.232000
59        640.0       8192.0       512.0    12.640000      14.944000
60        768.0       1024.0       128.0     7.488000       7.584000
61        768.0       1024.0       256.0     7.488000       7.872000
62        768.0       1024.0       512.0     7.712000       7.600000
63        768.0       2048.0       128.0     8.608000       9.280000
64        768.0       2048.0       256.0     8.352000       8.992000
65        768.0       2048.0       512.0     8.608000       8.992000
66        768.0       4096.0       128.0     9.920000      11.296000
67        768.0       4096.0       256.0    10.144000      11.264000
68        768.0       4096.0       512.0     9.920000      11.264000
69        768.0       8192.0       128.0    13.728000      16.063999
70        768.0       8192.0       256.0    13.504000      16.063999
71        768.0       8192.0       512.0    13.728000      16.319999
72       1024.0       1024.0       128.0     8.160000       8.288000
73       1024.0       1024.0       256.0     7.888000       8.032000
74       1024.0       1024.0       512.0     8.160000       8.000000
75       1024.0       2048.0       128.0     8.928000       9.536000
76       1024.0       2048.0       256.0     9.216000       9.792000
77       1024.0       2048.0       512.0     9.216000       9.760000
78       1024.0       4096.0       128.0    11.072000      12.320000
79       1024.0       4096.0       256.0    11.072000      12.576000
80       1024.0       4096.0       512.0    11.072000      12.320000
81       1024.0       8192.0       128.0    15.968001      18.495999
82       1024.0       8192.0       256.0    15.744001      18.495999
83       1024.0       8192.0       512.0    15.712000      18.495999
84       2048.0       1024.0       128.0     9.568000       9.280000
85       2048.0       1024.0       256.0     9.568000       9.536000
86       2048.0       1024.0       512.0     9.568000       9.280000
87       2048.0       2048.0       128.0    11.584000      11.776000
88       2048.0       2048.0       256.0    11.584000      11.744000
89       2048.0       2048.0       512.0    11.584000      11.488000
90       2048.0       4096.0       128.0    15.744001      16.928000
91       2048.0       4096.0       256.0    15.776001      16.928000
92       2048.0       4096.0       512.0    16.000001      17.152000
93       2048.0       8192.0       128.0    24.800001      27.360000
94       2048.0       8192.0       256.0    24.800001      27.616000
95       2048.0       8192.0       512.0    24.800001      27.616000
96       4096.0       1024.0       128.0    12.416000      12.032000
97       4096.0       1024.0       256.0    12.384000      12.032000
98       4096.0       1024.0       512.0    12.384000      12.032000
99       4096.0       2048.0       128.0    16.256001      16.736001
100      4096.0       2048.0       256.0    16.224001      16.736001
101      4096.0       2048.0       512.0    16.224001      16.736001
102      4096.0       4096.0       128.0    24.448000      27.136000
103      4096.0       4096.0       256.0    24.448000      27.168000
104      4096.0       4096.0       512.0    24.416000      27.168000
105      4096.0       8192.0       128.0    41.023999      45.984000
106      4096.0       8192.0       256.0    41.216001      45.728002
107      4096.0       8192.0       512.0    41.023999      45.984000
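The two timing columns above compare the new CUDA kernel against the existing Triton implementation. For readers unfamiliar with the operation being fused, here is a minimal pure-Python sketch of the SiLU-and-mul semantics the kernel computes. It assumes the common gate/up layout used by SwiGLU-style MoE layers (each row holds the gate projection followed by the up projection); the function name and layout are illustrative, not the actual `sgl_kernel` API.

```python
import math

def silu_and_mul(gate_up, hidden_size):
    # Reference semantics (illustrative, not the real kernel interface):
    # each row of `gate_up` has 2 * hidden_size entries; the first half is
    # the gate projection, the second half the up projection.
    # output[i] = silu(gate[i]) * up[i], with silu(x) = x * sigmoid(x).
    def silu(x):
        return x / (1.0 + math.exp(-x))

    out = []
    for row in gate_up:
        gate, up = row[:hidden_size], row[hidden_size:]
        out.append([silu(g) * u for g, u in zip(gate, up)])
    return out
```

The CUDA version parallelizes this elementwise loop across tokens and hidden dimensions; since the operation is memory-bound, the speedup over Triton mostly reflects better memory-access scheduling rather than arithmetic throughput.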

@Alcanderian Alcanderian added the ready-to-merge The PR is ready to merge after the CI is green. label Jun 9, 2025
@merrymercy merrymercy merged commit 84727a5 into sgl-project:main Jun 12, 2025
105 of 125 checks passed
jianan-gu pushed a commit to jianan-gu/sglang that referenced this pull request Jun 12, 2025
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>