
Conversation

shixianc
Contributor

@shixianc shixianc commented Jul 9, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

We integrated the new moe_align_block_size CUDA kernel but realized that its sorted_ids output does not preserve the original input token order the way the Triton kernel does. This does not impact quality, but it changes the matrix accumulation order of the final output, which we caught in our internal quality tests.

The existing test_moe_align_block_size.py only compares the CUDA kernel against the previous Triton kernel, which is slated for deprecation according to the comment in the code.
Therefore I added a torch golden implementation of moe_align_block_size (a pure-PyTorch sketch of the idea is shown after this list) to ensure:

  1. the final sorted tokens assigned to each expert are identical between the CUDA kernel and the torch golden implementation
  2. determinism across multiple CUDA kernel invocations for reproducibility
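
For context, here is a pure-PyTorch sketch of what such a golden reference can look like. It is illustrative only: the function name, padding convention, and output layout below are assumptions for the sketch, not necessarily the exact helper added in this PR.

```python
import torch


def torch_moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int,
                                   num_experts: int):
    """Loop-based reference: group flattened topk positions by expert and pad
    each expert's group up to a multiple of block_size."""
    flat = topk_ids.flatten()
    pad_value = flat.numel()  # assumed padding value: one past the last real slot

    sorted_chunks, expert_chunks = [], []
    for expert in range(num_experts):
        # Flattened positions routed to this expert, kept in original order.
        positions = torch.nonzero(flat == expert, as_tuple=False).flatten()
        num_blocks = (positions.numel() + block_size - 1) // block_size
        padded = torch.full((num_blocks * block_size,), pad_value, dtype=torch.int64)
        padded[:positions.numel()] = positions
        sorted_chunks.append(padded)
        expert_chunks.append(torch.full((num_blocks,), expert, dtype=torch.int64))

    sorted_ids = torch.cat(sorted_chunks)
    expert_ids = torch.cat(expert_chunks)
    num_tokens_post_pad = torch.tensor([sorted_ids.numel()])
    return sorted_ids, expert_ids, num_tokens_post_pad
```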

Test Plan

pytest tests/kernels/moe/test_moe_align_block_size.py

Test Result

tests/kernels/moe/test_moe_align_block_size.py .............................................................................................................................. [ 11%]
............................................................................................................................................................................. [ 26%]
............................................................................................................................................................................. [ 41%]
............................................................................................................................................................................. [ 57%]
............................................................................................................................................................................. [ 72%]
............................................................................................................................................................................. [ 88%]
...................................................................................................................................... [100%]

1125 passed in 289.91s (0:04:49)

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @shixianc, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the testing framework for the moe_align_block_size CUDA kernel. My primary goal was to provide a robust PyTorch-based 'golden' reference implementation and comprehensive validation tests to ensure the CUDA kernel's correctness, determinism, and proper handling of expert mapping, addressing observed discrepancies in token ordering compared to the deprecated Triton kernel.

Highlights

  • New Golden Reference Implementation: I've introduced a pure PyTorch implementation, torch_moe_align_block_size, to serve as a 'golden' reference for validating the moe_align_block_size CUDA kernel. This addresses the need for a stable comparison baseline after the new CUDA kernel's introduction.
  • Enhanced Kernel Validation Logic: I've added new helper functions, _group_tokens_by_expert and _verify_expert_level_sorting, to enable robust expert-level validation. This ensures that while the exact token order within an expert's block might differ from the golden reference, the set of tokens assigned to each expert remains consistent and correct.
  • Comprehensive Test Coverage: The existing comparison test (CUDA vs Triton) has been replaced with a more extensive test suite. New parameterized tests cover various configurations, including scenarios with expert mapping and explicit checks for the CUDA kernel's determinism across multiple invocations (a minimal sketch of such a determinism check follows this list).
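
As a rough illustration of the determinism check mentioned above (here `align_fn` stands in for the real moe_align_block_size entry point, whose exact import path and signature are not pinned down in this sketch):

```python
import torch


def check_determinism(align_fn, topk_ids: torch.Tensor, block_size: int,
                      num_experts: int) -> None:
    """Call the kernel twice on identical inputs and require identical outputs."""
    first = align_fn(topk_ids, block_size, num_experts)
    second = align_fn(topk_ids, block_size, num_experts)
    for a, b in zip(first, second):
        # Determinism means bit-identical tensors, not just approximate equality.
        assert torch.equal(a, b)
```
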
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request adds a golden PyTorch implementation for the moe_align_block_size kernel, enhancing testing. Consider refactoring the golden implementation to improve efficiency by vectorizing operations.
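
One way such a vectorization could look (an illustrative sketch, not the reviewer's specific proposal, and assuming all expert ids are valid non-negative integers): a stable sort over the flattened expert ids groups token positions per expert in a single operation, and torch.bincount yields the per-expert counts that drive the block padding.

```python
import torch


def vectorized_grouping(topk_ids: torch.Tensor, num_experts: int):
    """Group flattened token positions by expert without a Python-level loop."""
    flat = topk_ids.flatten()
    # A stable sort preserves the original order of positions within each expert.
    _, order = torch.sort(flat, stable=True)
    counts = torch.bincount(flat, minlength=num_experts)
    # `order` lists flattened positions grouped by expert; counts[e] tells how
    # many belong to expert e, which determines that expert's padded blocks.
    return order, counts
```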

@shixianc
Contributor Author

shixianc commented Jul 9, 2025

@yewentao256 This is related to your previous change; could you take a look? Thanks.


github-actions bot commented Jul 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@shixianc shixianc force-pushed the moe-unittest branch 2 times, most recently from 6a7e7ef to 408565b on July 9, 2025 at 20:25
Collaborator

@yewentao256 yewentao256 left a comment

Thanks for the work!

We integrated the new moe_align_block_size CUDA kernel but realized that its sorted_ids output does not preserve the original input token order the way the Triton kernel does. This does not impact quality, but it changes the matrix accumulation order of the final output, which we caught in our internal quality tests.

Could you give more details about this? Will this test fail in current development?

@shixianc
Contributor Author

shixianc commented Jul 10, 2025

This updated test will pass because I'm not checking for an exact match on sorted_ids; instead, the _verify_expert_level_sorting function checks that the token ids match for each expert (a simplified sketch of this check is at the end of this comment).

Here is an example of what happens:

  • topk_ids = [[0, 1], [0, 0]]; assume 2 experts, topk=2, m=2, block_size=4
  • sorted_ids from Triton would look like [0, 1, 4, 4, 0, 4, 4, 4] (padding token == 4)
  • however, CUDA may return [1, 0, 4, 4, 0, 4, 4, 4]

This does not impact model output quality; the new kernel just changes the matrix accumulation order of the output, so we occasionally saw an output token swap to a synonym.

Maybe this is also the reason we didn't do torch.allclose on sorted_ids previously?
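
To make the expert-level comparison concrete, here is a minimal sketch (hypothetical helper name, and it assumes sorted_ids is laid out in block_size-sized blocks per expert with a fixed padding value) of collecting each expert's tokens so that order within a block no longer matters:

```python
import torch


def tokens_per_expert(sorted_ids: torch.Tensor, expert_ids: torch.Tensor,
                      block_size: int, pad_value: int) -> dict[int, list[int]]:
    """Collect the real (non-padding) token ids owned by each expert."""
    grouped: dict[int, list[int]] = {}
    for block_idx, expert in enumerate(expert_ids.tolist()):
        block = sorted_ids[block_idx * block_size:(block_idx + 1) * block_size]
        grouped.setdefault(expert, []).extend(block[block != pad_value].tolist())
    # Sort within each expert so that [0, 1] and [1, 0] compare equal.
    return {expert: sorted(ids) for expert, ids in grouped.items()}
```

With the example above (and assuming blocks 0 and 1 belong to experts 0 and 1), both the Triton-style and CUDA-style layouts map to {0: [0, 1], 1: [0]}, so the expert-level check passes even though the in-block order differs.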

@shixianc
Contributor Author

@yewentao256 Could you see the comment above? Any more concerns?

Collaborator

@yewentao256 yewentao256 left a comment

Looks good to me, thank you for the contribution!

@yewentao256
Collaborator

I don't have permission to merge, so @houseroad, could you take a look if possible?

Collaborator

@houseroad houseroad left a comment

Thanks!

@houseroad houseroad added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Jul 18, 2025
@houseroad
Collaborator

Could you rebase and re-trigger the CI? @shixianc

@houseroad houseroad enabled auto-merge (squash) July 18, 2025 19:57
auto-merge was automatically disabled July 18, 2025 20:27

Head branch was pushed to by a user without write access

@vllm-bot vllm-bot merged commit 7d94577 into vllm-project:main Jul 19, 2025
41 of 43 checks passed
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
…ect#20653)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
…ect#20653)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
…ect#20653)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
Signed-off-by: x22x22 <wadeking@qq.com>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
…ect#20653)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
…ect#20653)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
…ect#20653)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
…ect#20653)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
…ect#20653)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
…ect#20653)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
…ect#20653)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
…ect#20653)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>