Conversation

@bnellnm (Contributor) commented Apr 2, 2025

This PR defines a set of base classes used to make MoE kernels more modular. The goal is to be able to use different communication mechanisms with any fused MoE kernel without needing a combinatorial set of implementations.

The fused MoE kernels are broken down into the following components:

[Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]                                                                                                       

Each component will be independent of the others except for [Quantize-Dispatch] and [Combine] (see below). The components can then be mixed and matched so that DP+EP can be supported easily for multiple MoE kernel implementations.

The following main classes are defined:

  • FusedMoEQuantizeDispatchCombine - an abstract base class for quantization, dispatching and combining. The dispatch method takes care of any needed quantization, and the combine method applies the topk weights and does the final reduction of the output.
  • FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused MoE operation. One important feature to note is that this class does not apply the topk weights or reduce the final output.
  • FusedMoEModularKernel - an interface class that combines a FusedMoEQuantizeDispatchCombine and a FusedMoEPermuteExpertsUnpermute to provide the standard fused MoE kernel interface.
  • StandardDispatchCombine - a concrete class that can be used for the serial Triton, DeepGemm and CUTLASS MoE implementations.

The implementations of the DeepGemm and CUTLASS MoE functions have been replaced with the modularized versions. There is also a modularized version of the Triton kernels, but it will not be enabled by default.

[Quantize-Dispatch] and [Combine] functionality are bundled into a single class, FusedMoEQuantizeDispatchCombine, since they may use collective communication mechanisms that need to be consistent with each other.
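
For illustration, here is a minimal sketch of how these pieces could compose. The class names follow this PR, but the method names, signatures and tensor handling are simplified assumptions rather than the exact vLLM API.

# Illustrative sketch only: class names follow the PR, but method names and
# signatures are simplified assumptions, not the exact vLLM API.
from abc import ABC, abstractmethod

import torch


class FusedMoEQuantizeDispatchCombine(ABC):
    """Bundles quantization + dispatch with the matching combine step."""

    @abstractmethod
    def dispatch(self, hidden_states: torch.Tensor,
                 topk_ids: torch.Tensor) -> torch.Tensor:
        """Quantize (if needed) and route tokens to their experts."""

    @abstractmethod
    def combine(self, output: torch.Tensor, expert_output: torch.Tensor,
                topk_weights: torch.Tensor) -> None:
        """Apply the topk weights and reduce per-expert outputs into output."""


class FusedMoEPermuteExpertsUnpermute(ABC):
    """Core fused-experts computation; does not apply topk weights or reduce."""

    @abstractmethod
    def apply(self, dispatched: torch.Tensor, w1: torch.Tensor,
              w2: torch.Tensor) -> torch.Tensor:
        """Run permute -> grouped expert computation -> unpermute."""


class FusedMoEModularKernel:
    """Composes a dispatch/combine object with an experts object to provide
    the standard fused MoE kernel interface."""

    def __init__(self, dispatch_combine: FusedMoEQuantizeDispatchCombine,
                 experts: FusedMoEPermuteExpertsUnpermute):
        self.dispatch_combine = dispatch_combine
        self.experts = experts

    def forward(self, hidden_states, w1, w2, topk_weights, topk_ids):
        dispatched = self.dispatch_combine.dispatch(hidden_states, topk_ids)
        expert_out = self.experts.apply(dispatched, w1, w2)
        output = torch.empty_like(hidden_states)
        self.dispatch_combine.combine(output, expert_out, topk_weights)
        return output

With this split, for example, a StandardDispatchCombine can be paired with the Triton, DeepGemm or CUTLASS experts implementations, and a pplx-based dispatch/combine can be swapped in without touching the experts code.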

cc @ElizaWszola, @varun-sundar-rabindranath

github-actions bot commented Apr 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which covers a small but essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@bnellnm force-pushed the modular-fused-experts branch from 96b7d5b to fc3243d on April 3, 2025 19:39
@bnellnm marked this pull request as ready for review on April 3, 2025 23:20
@tlrmchlsmth (Collaborator) commented:

For readers: We're doing this to support the pplx-kernel integration. We can use this structure for DeepEP as well.

Right now our fused MoE is implemented as something very very roughly like:

[Router] → [Quantize] → [Experts + topk_weight scaling + reduction]

This is a problem as the topk_weight scaling and reduction now need to happen during combine. We need to fit dispatch in there as well.

This PR defines a set of base classes used to make MoE kernels more modular. The goal is to be able to use different communication mechanisms with any fused MoE kernel without needing a combinatorial set of implementations.

The fused MoE kernels are broken down into the following components:

[Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]

☝️ This is what @bnellnm and I agreed on. However:

The other option we originally considered was:

[Router] → [Quantize-Dispatch-Permute] → [Experts] → [Unpermute-Combine]

Right now I am thinking that permute/unpermute will (unfortunately) depend on the implementation of both dispatch/combine and experts, so we should consider breaking that out.
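
To make that concrete, here is a rough, purely hypothetical sketch of what breaking permute/unpermute out into its own interface could look like; none of these names exist in this PR.

# Hypothetical illustration of factoring permute/unpermute into a separate
# interface, as discussed above; not part of this PR.
from abc import ABC, abstractmethod

import torch


class TokenPermuter(ABC):
    """Owns the token layout that both the dispatch/combine implementation
    and the experts kernel must agree on."""

    @abstractmethod
    def permute(self, tokens: torch.Tensor,
                topk_ids: torch.Tensor) -> torch.Tensor:
        """Group tokens by expert before the expert computation."""

    @abstractmethod
    def unpermute(self, expert_output: torch.Tensor) -> torch.Tensor:
        """Restore the original token order after the expert computation."""

The catch, as noted above, is that a concrete TokenPermuter would still need to know about both the dispatch/combine layout and the experts kernel's expectations, so the coupling does not disappear; it just becomes explicit.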

@tlrmchlsmth (Collaborator) left a review comment:

nice and clean

@bnellnm changed the title from "Modular fused experts" to "Modularize fused experts and integrate pplx kernels" on Apr 4, 2025
@bnellnm changed the title from "Modularize fused experts and integrate pplx kernels" to "Modularize fused experts and integrate PPLX kernels" on Apr 4, 2025
abcdabcd987 pushed a commit to perplexityai/pplx-kernels that referenced this pull request Apr 4, 2025
…nt for users. (#2)

Being able to query some of the setup parameters from the AllToAll class
would make client code a bit simpler/safer, e.g. see
pplx_dispatch_combine.py from
vllm-project/vllm#15956

cc @abcdabcd987 , @tlrmchlsmth

Signed-off-by: Bill Nell <bnell@redhat.com>
@abcdabcd987 commented:

One thing I forgot to put in our examples is -- Please call the destructor! ata.destroy()

@bnellnm (Contributor, Author) commented Apr 4, 2025

One thing I forgot to put in our examples is -- Please call the destructor! ata.destroy()

Do all references to an AllToAll need to be destroyed? The current plan is to have a cache (coming in a separate PR) manage all the AllToAll instances, with PplxDispatchCombine holding a reference to one of the cached objects.

@abcdabcd987 commented:

Do all references to an AllToAll need to be destroyed?

No. Call destroy() only when you are shutting down the engine (or removing the model from GPU, etc...)

You are right to cache the object. It is supposed to be reused across layers and across runs.
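
For illustration, a minimal sketch of the caching approach described above, assuming the cache owns the lifetime of the AllToAll objects. The key/factory machinery is a hypothetical placeholder; only the destroy() call and the "destroy once at shutdown" rule come from this discussion.

# Hypothetical sketch of an AllToAll cache; only destroy() is taken from the
# discussion above, everything else is a placeholder.
class AllToAllCache:
    def __init__(self):
        self._instances = {}

    def get(self, key, factory):
        # Reuse one AllToAll per configuration across layers and across runs.
        if key not in self._instances:
            self._instances[key] = factory()
        return self._instances[key]

    def destroy_all(self):
        # Call destroy() once, when shutting down the engine or removing the
        # model from the GPU, not once per reference holder.
        for ata in self._instances.values():
            ata.destroy()
        self._instances.clear()

PplxDispatchCombine would then just hold one of the cached instances and never call destroy() itself.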

mergify bot commented Apr 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Apr 9, 2025
@bnellnm force-pushed the modular-fused-experts branch from 63f1297 to 8635098 on April 29, 2025 02:09
mergify bot removed the needs-rebase label Apr 29, 2025
mergify bot added the documentation label Apr 29, 2025
@bnellnm force-pushed the modular-fused-experts branch 2 times, most recently from 9803866 to f74ab61, on April 30, 2025 21:44

mergify bot commented May 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label May 2, 2025
@bnellnm force-pushed the modular-fused-experts branch from 5d960df to b04e5d3 on May 7, 2025 15:24
bnellnm and others added 11 commits May 14, 2025 14:55
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
@bnellnm force-pushed the modular-fused-experts branch from 705da89 to 1f91cfd on May 14, 2025 15:35
mergify bot removed the needs-rebase label May 14, 2025
@tlrmchlsmth (Collaborator) left a review comment:

LGTM once green!

@tlrmchlsmth enabled auto-merge (squash) May 14, 2025 16:09
@bnellnm (Contributor, Author) commented May 14, 2025

I've verified that the following tests fail on main

pytest -sv tests/spec_decode/e2e/test_medusa_correctness.py::test_medusa_e2e_greedy_correctness_with_preemption[-1-1-4-128-test_llm_kwargs0-baseline_llm_kwargs0-per_test_common_llm_kwargs0-common_llm_kwargs0]
pytest -sv tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness_with_preemption[1-4-128-test_llm_kwargs0-baseline_llm_kwargs0-per_test_common_llm_kwargs0-common_llm_kwargs0]

I'm not sure what is happening with the v1 test. I can't get it to fail locally on main or with this branch. But the model being run does not contain any MoE layers, so I think this failure is not related to this PR.

@simon-mo merged commit f9c069c into vllm-project:main May 14, 2025
87 of 90 checks passed
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Labels: ci/build, documentation, ready, tpu, v1
10 participants