
@ElizaWszola (Contributor) commented Jul 17, 2025

N, as returned by _moe_problem_size(), is already the full output width of the first GEMM, yet cutlass_moe's workspace_shapes() sized workspace1 from 2 * N and workspace2 from N. As a result, it allocated up to twice as much workspace memory as needed. This fix shrinks the workspaces to the size actually required to store cutlass_moe's intermediate tensors.
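A minimal sketch of the sizing arithmetic, assuming (as the fix implies) that N from _moe_problem_size() is already the fused gate/up width, i.e. 2 × the per-expert intermediate size. The helper names and the example sizes are hypothetical; only the shape formulas are taken from this PR's diff.

```python
from math import prod


def shapes_before_fix(M: int, N: int, K: int, topk: int):
    """Workspace shapes as computed before this PR (N effectively doubled again)."""
    return (M * topk, max(2 * N, K)), (M * topk, N)


def shapes_after_fix(M: int, N: int, K: int, topk: int):
    """Workspace shapes as computed after this PR."""
    return (M * topk, max(N, K)), (M * topk, N // 2)


def total_elements(shapes) -> int:
    """Sum of element counts over a tuple of 2-D shapes."""
    return sum(prod(s) for s in shapes)


if __name__ == "__main__":
    # Hypothetical sizes, for illustration only.
    M, topk, K = 8, 6, 2048
    intermediate = 1408
    N = 2 * intermediate  # assumed: fused gate + up projection width

    before = shapes_before_fix(M, N, K, topk)
    after = shapes_after_fix(M, N, K, topk)
    print("before:", before, total_elements(before))
    print("after: ", after, total_elements(after))
```

With these sizes the total drops from 405,504 to 202,752 elements. When N >= K both workspaces shrink by exactly half; when K >= 2 * N, workspace1 is unchanged and only workspace2 shrinks, which is why the saving is "up to" 2×.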

Testing

Tested for correctness with offline inference on nm-testing/DeepSeek-Coder-V2-Lite-Instruct-FP8, and with the following unit tests:

pytest tests/kernels/moe/test_cutlass_moe.py -k test_cutlass_moe_8_bit_EP_large
pytest tests/kernels/moe/test_cutlass_moe.py -k test_cutlass_moe_8_bit_cuda_graph

Signed-off-by: ElizaWszola <ewszola@redhat.com>

@gemini-code-assist (bot) left a comment

Code Review

The pull request reduces memory allocation in non-batched CUTLASS MoE by optimizing the calculation of N in _moe_problem_size(), leading to smaller workspace sizes. No specific issues were identified requiring immediate attention based on the provided changes and review criteria.

Comment on lines -288 to +289
Before:
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
After:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, N // 2)
Review comment (Member):

Is there a reason why this is the inverse of TritonExperts?

workspace1 = (M, topk, max(N // 2, K))
workspace2 = (M, topk, max(N, K))

Reply (Contributor Author):

Looks like the variables are just flipped.

In Triton (workspace13 is the same tensor as workspace1):
https://github.com/vllm-project/vllm/blob/9fb2d22032cee577a189f8c4cddd88a3c190cb0c/vllm/model_executor/layers/fused_moe/fused_moe.py#L1702C1-L1707C72

In CUTLASS:
https://github.com/vllm-project/vllm/blob/9fb2d22032cee577a189f8c4cddd88a3c190cb0c/vllm/model_executor/layers/fused_moe/cutlass_moe.py#L168C1-L170C55

The CUTLASS names seem to make more sense, given which intermediate tensors are assigned to each resized workspace.
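The reasoning behind the new shapes can be sketched as a small capacity check: each workspace must be at least as wide as every intermediate tensor resized into it. The assignment below (first GEMM output and second GEMM output sharing workspace13, activation output in workspace2) is an assumption consistent with the post-fix shapes and with the author's description, not a copy of the linked CUTLASS code; the function and tensor names are hypothetical.

```python
from typing import Dict, List, Tuple

Assignment = Dict[str, List[Tuple[str, int]]]


def required_widths(assignment: Assignment) -> Dict[str, int]:
    """For each workspace, return the width of the widest tensor resized into it.

    Tensors sharing a workspace must not be live at the same time; this
    sketch only checks capacity, not lifetimes.
    """
    return {
        ws: max(width for _, width in tensors)
        for ws, tensors in assignment.items()
    }


def cutlass_assignment(N: int, K: int) -> Assignment:
    # Assumed mapping (hypothetical names):
    #   gemm1_out: fused gate+up output, width N, in workspace13
    #   act_out:   activation output, width N // 2, in workspace2
    #   gemm2_out: second GEMM output, width K, reuses workspace13 once
    #              gemm1_out has been consumed by the activation
    return {
        "workspace13": [("gemm1_out", N), ("gemm2_out", K)],
        "workspace2": [("act_out", N // 2)],
    }


if __name__ == "__main__":
    # Hypothetical sizes, for illustration only.
    print(required_widths(cutlass_assignment(N=2816, K=2048)))
```

Under this assumed mapping the widths come out as max(N, K) for workspace13 and N // 2 for workspace2, matching the post-fix shapes; reusing workspace13 for the second GEMM output is what keeps a third buffer from being needed.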

@mgoin mgoin added the bug Something isn't working label Jul 17, 2025
@mgoin (Member) left a comment

LGTM, thanks Eliza!

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 17, 2025
Signed-off-by: ElizaWszola <ewszola@redhat.com>
@DarkLight1337 DarkLight1337 merged commit 4adc66f into vllm-project:main Jul 18, 2025
67 checks passed