

varun-sundar-rabindranath
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath commented Jul 2, 2025

Purpose

Bug 1:

  • TritonOrDeepGemmExperts chooses DeepGemmExperts when num_tokens >= 128 and TritonExperts otherwise. Consider a case where num_tokens is 130 and the chunk size is 128: the first chunk (128 tokens) needs DeepGemmExperts::workspace_shapes, while the second chunk (2 tokens) needs TritonExperts::workspace_shapes.
  • On main, the workspace shapes are computed once and reused for every chunk. This crashes because the two implementations have different workspace shapes.

Fix: Compute the workspace shapes individually for each chunk.
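A minimal sketch of the per-chunk fix, in plain Python rather than vLLM's actual code. `select_impl`, `workspace_shape`, and `plan_chunks` are hypothetical stand-ins, the workspace shapes are made up (only the fact that they differ matters), and the `>= 128` dispatch threshold is an assumption chosen so that a full 128-token chunk picks DeepGEMM, as in the example above.

```python
from typing import List, Tuple

# Assumption: DeepGEMM is chosen for chunks of at least this many tokens.
DEEP_GEMM_MIN_TOKENS = 128


def select_impl(num_tokens: int) -> str:
    """Hypothetical stand-in for TritonOrDeepGemmExperts' per-call dispatch."""
    return "deep_gemm" if num_tokens >= DEEP_GEMM_MIN_TOKENS else "triton"


def workspace_shape(impl: str, m: int, topk: int, n: int, k: int) -> Tuple[int, ...]:
    """Made-up workspace shapes; the two implementations just need to differ."""
    if impl == "deep_gemm":
        return (m * topk, max(n, k))
    return (m, topk, n)


def plan_chunks(
    num_tokens: int, chunk_size: int, topk: int, n: int, k: int
) -> List[Tuple[str, Tuple[int, ...]]]:
    """Choose the implementation and its workspace shape for every chunk."""
    plans = []
    for start in range(0, num_tokens, chunk_size):
        m = min(chunk_size, num_tokens - start)  # the last chunk may be short
        impl = select_impl(m)
        # The fix: shapes come from this chunk's impl, not a one-time query.
        plans.append((impl, workspace_shape(impl, m, topk, n, k)))
    return plans


if __name__ == "__main__":
    for impl, shape in plan_chunks(130, 128, topk=8, n=1024, k=2048):
        print(impl, shape)
```

With num_tokens=130 and chunk_size=128, the first plan uses "deep_gemm" and the second uses "triton", so a single workspace shape computed up front cannot serve both chunks.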

Bug 2 (soft):

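  • Some all2all implementations return an expert_num_tokens tensor, which holds the number of tokens assigned to each expert. When the MoE computation is chunked, this tensor is not recomputed for each chunk, so its counts do not match the tokens in the chunk being processed.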

Fix: The PR introduces a count_expert_num_tokens kernel and computes expert_num_tokens for each chunk individually. This is termed a "soft bug" because, for experts that support chunking, expert_num_tokens is None or unused.
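What the per-chunk count computes can be sketched as a pure-Python reference. This is not the PR's Triton kernel: `count_expert_num_tokens_ref` and `per_chunk_counts` are hypothetical names, and the `expert_map` convention (global expert id to local id, -1 for experts owned by another rank) is an assumption.

```python
from typing import List, Optional, Sequence


def count_expert_num_tokens_ref(
    topk_ids: Sequence[Sequence[int]],
    num_local_experts: int,
    expert_map: Optional[Sequence[int]] = None,
) -> List[int]:
    """Count how many (token, expert) assignments land on each local expert."""
    counts = [0] * num_local_experts
    for row in topk_ids:
        for global_id in row:
            local_id = global_id if expert_map is None else expert_map[global_id]
            if 0 <= local_id < num_local_experts:  # skip experts on other ranks
                counts[local_id] += 1
    return counts


def per_chunk_counts(
    topk_ids: Sequence[Sequence[int]],
    num_local_experts: int,
    chunk_size: int,
    expert_map: Optional[Sequence[int]] = None,
) -> List[List[int]]:
    """Recompute the per-expert counts for each chunk of tokens."""
    return [
        count_expert_num_tokens_ref(
            topk_ids[s:s + chunk_size], num_local_experts, expert_map
        )
        for s in range(0, len(topk_ids), chunk_size)
    ]


if __name__ == "__main__":
    topk_ids = [[0, 2], [1, 2], [3, 0]]  # 3 tokens, topk = 2
    expert_map = [0, -1, 1, -1]  # global experts 0 and 2 live on this rank
    full = count_expert_num_tokens_ref(topk_ids, 2, expert_map)
    chunks = per_chunk_counts(topk_ids, 2, chunk_size=2, expert_map=expert_map)
    print(full, chunks)
```

Here the full-batch counts ([2, 2]) describe neither chunk ([1, 2] and [1, 0]), which is why reusing them inside the chunked loop is wrong; the per-chunk counts still sum to the full-batch counts.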

As part of the fixes, this PR moves the chunking logic out of the main FusedMoEModularKernel::forward pass for clarity.

Test Plan

Machine: H100

VLLM_ALL2ALL_BACKEND="deepep_high_throughput" VLLM_USE_DEEP_GEMM=1  canhazgpu run -g 2  vllm serve Qwen/Qwen3-30B-A3B-FP8  --trust-remote-code --enable-expert-parallel --data-parallel-size 2 --port 9010
lm_eval --model local-completions --tasks gsm8k --model_args model=Qwen/Qwen3-30B-A3B-FP8,base_url=http://127.0.0.1:9010/v1/completions,num_concurrent=30,max_retries=3 --limit 100 

pytest: pytest -s tests/kernels/moe/test_modular_kernel_combinations.py (from #20449)

Test Result

without chunking:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.86|±  |0.0349|
|     |       |strict-match    |     5|exact_match|↑  | 0.93|±  |0.0256|

with VLLM_FUSED_MOE_CHUNK_SIZE=64:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.82|±  |0.0386|
|     |       |strict-match    |     5|exact_match|↑  | 0.91|±  |0.0288|
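One quick way to read the two tables is to compare each accuracy drop with its standard error. The sketch below treats the two runs as independent samples, which is an assumption: both runs score the same 100 gsm8k questions (--limit 100), so a paired comparison would be more appropriate, and the independent-sample estimate likely overstates the noise.

```python
import math

# (value, stderr) pairs copied from the two lm_eval tables above:
# (without chunking, with VLLM_FUSED_MOE_CHUNK_SIZE=64).
RESULTS = {
    "flexible-extract": ((0.86, 0.0349), (0.82, 0.0386)),
    "strict-match": ((0.93, 0.0256), (0.91, 0.0288)),
}


def z_score(base, chunked):
    """Difference of two means divided by its independent-sample standard error."""
    (a, se_a), (b, se_b) = base, chunked
    return (a - b) / math.hypot(se_a, se_b)


if __name__ == "__main__":
    for name, (base, chunked) in RESULTS.items():
        print(f"{name}: z = {z_score(base, chunked):.2f}")
```

Both z-scores come out below 1 (about 0.77 and 0.52), so on this 100-question sample the chunk-size-64 run sits within roughly one standard error of the unchunked run.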

pytests: Pass

(Optional) Documentation Update


github-actions bot commented Jul 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @varun-sundar-rabindranath, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the chunking mechanism within the Fused Mixture of Experts (MoE) modular kernel. The changes improve the organization and maintainability of the code by extracting chunking responsibilities into dedicated functions. It also introduces a new kernel for counting tokens per expert, which the chunking loop uses to recompute per-expert token counts for each chunk.

Highlights

  • Refactoring MoE Chunking Logic: The core chunking logic for the Fused MoE Modular Kernel has been refactored into dedicated private methods (_do_fused_experts and _maybe_chunk_fused_experts) within modular_kernel.py. This improves modularity and readability by separating the chunking orchestration from the direct application of fused experts.
  • New Expert Token Counting Kernel: A new Triton-based kernel, _count_expert_num_tokens, and its Python wrapper count_expert_num_tokens have been introduced in utils.py. This kernel efficiently counts the number of tokens assigned to each expert, which is now utilized during the chunked processing in the MoE kernel.
  • Output Shape Adjustment in DeepGEMM MoE: The deep_gemm_moe.py module has been updated to change the expected output shape from (M * topk, K) to (M, topk, K) in its workspace_shapes method (line 86). Correspondingly, the apply method now uses output.view((-1, K)) (line 159) to adapt to this new shape during torch.index_select operations.
  • Comprehensive Testing for New Kernel: A new test file, test_count_expert_num_tokens.py, has been added to thoroughly validate the count_expert_num_tokens kernel across various configurations of tokens, top-k values, experts, and expert parallelism sizes. It includes a reference CPU implementation for correctness verification.
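The split described in the first highlight can be pictured with a toy sketch. Only the method names `_do_fused_experts` and `_maybe_chunk_fused_experts` come from the summary; their signatures, the list-of-rows "tensors", and the doubling "expert" are hypothetical placeholders, not vLLM's implementation.

```python
from typing import Callable, List, Sequence

Rows = List[List[float]]


def _do_fused_experts(hidden_states: Sequence[List[float]]) -> Rows:
    """Placeholder for applying the experts to one chunk (here: scale by 2)."""
    return [[2.0 * x for x in row] for row in hidden_states]


def _maybe_chunk_fused_experts(
    hidden_states: Sequence[List[float]],
    chunk_size: int,
    do_fused_experts: Callable[[Sequence[List[float]]], Rows] = _do_fused_experts,
) -> Rows:
    """Run the experts in one shot when possible, otherwise chunk by chunk."""
    if len(hidden_states) <= chunk_size:
        return do_fused_experts(hidden_states)
    output: Rows = []
    for start in range(0, len(hidden_states), chunk_size):
        chunk = hidden_states[start:start + chunk_size]
        # Per-chunk state (workspaces, expert_num_tokens) would be derived here.
        output.extend(do_fused_experts(chunk))
    return output


def forward(hidden_states: Sequence[List[float]], chunk_size: int) -> Rows:
    """The top-level pass stays chunk-agnostic; prepare/finalize are omitted."""
    return _maybe_chunk_fused_experts(hidden_states, chunk_size)
```

Keeping the loop in its own helper means the top-level pass reads the same whether or not chunking happens, and any per-chunk recomputation has one obvious home.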

mergify bot commented Jul 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 2, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the chunking logic in the Fused MoE modular kernel and introduces a new utility function count_expert_num_tokens with a corresponding Triton kernel and tests. The refactoring improves code structure, but I've identified a potential performance regression related to workspace memory allocation within the new chunking loop. Additionally, I've raised concerns about the output shape in deep_gemm_moe.py and the reshaping of the output tensor before torch.index_select.

c_expert_num_tokens = None
if expert_num_tokens is not None:
    c_expert_num_tokens = slice_expert_num_tokens(
        c_topk_ids, local_num_experts, expert_map)
Contributor Author


expert_num_tokens holds the number of tokens assigned to each expert. It needs to be recalculated for each chunk for correctness.

Collaborator

@tlrmchlsmth tlrmchlsmth left a comment


Could you add a summary of why you're making these changes to the PR description?

tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))


def count_expert_num_tokens(
Collaborator


Could you add a docstring?
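The excerpt above ends with `tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))`, which suggests a per-expert reduction: each program owns one expert, accumulates matches over blocks of `topk_ids`, and stores the sum. The pure-Python emulation below sketches that pattern, with the kind of docstring the reviewer asks for. The function name `count_expert_num_tokens_blocked`, the block loop, the `block_size` masking, and the `expert_map` convention are assumptions inferred from that one line, not the PR's kernel.

```python
from typing import List, Optional, Sequence


def count_expert_num_tokens_blocked(
    topk_ids: Sequence[int],
    num_local_experts: int,
    expert_map: Optional[Sequence[int]] = None,
    block_size: int = 4,
) -> List[int]:
    """Count the tokens routed to each local expert.

    Args:
        topk_ids: Flattened (num_tokens * topk) global expert ids.
        num_local_experts: Number of experts owned by this rank.
        expert_map: Optional global-to-local id map; -1 marks a non-local
            expert (assumed convention).
        block_size: Elements processed per simulated block load.

    Returns:
        A list of length num_local_experts with per-expert token counts.
    """
    expert_num_tokens = [0] * num_local_experts
    n = len(topk_ids)
    for curr_expert in range(num_local_experts):  # one "program" per expert
        acc = [0] * block_size  # per-lane accumulator, like a Triton block
        for start in range(0, n, block_size):
            for lane in range(block_size):
                idx = start + lane
                if idx >= n:  # masked-out lane past the end of topk_ids
                    continue
                global_id = topk_ids[idx]
                local_id = global_id if expert_map is None else expert_map[global_id]
                acc[lane] += int(local_id == curr_expert)
        # Mirrors tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc)).
        expert_num_tokens[curr_expert] = sum(acc)
    return expert_num_tokens
```

Giving each expert its own program avoids atomics at the cost of every program scanning all of `topk_ids`; that trade-off is plausible for small expert counts but is an assumption here.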


mergify bot commented Jul 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 9, 2025
@varun-sundar-rabindranath varun-sundar-rabindranath changed the title from "[Misc] Fused MoE Modular Kernel : Refactor Chunking loop" to "[Bugfix] Fused MoE Modular Kernel chunking loop" Jul 9, 2025
@varun-sundar-rabindranath varun-sundar-rabindranath force-pushed the varun/refactor-chunking-loop branch 2 times, most recently from d485f3c to a339feb on July 9, 2025 21:18
@mergify mergify bot removed the needs-rebase label Jul 9, 2025

mergify bot commented Jul 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 10, 2025
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment


lgtm once merge conflicts are resolved

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
@varun-sundar-rabindranath varun-sundar-rabindranath force-pushed the varun/refactor-chunking-loop branch from a339feb to 2658f56 on July 10, 2025 14:56
@mergify mergify bot removed the needs-rebase label Jul 10, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) July 10, 2025 16:46
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 10, 2025
@tlrmchlsmth tlrmchlsmth merged commit fdadb6f into vllm-project:main Jul 10, 2025
76 checks passed
Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jul 15, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Labels
ready ONLY add when PR is ready to merge/full CI is needed