feat: integrate deepgemm into EPMoE #5805
@@ -47,6 +55,8 @@

logger = logging.getLogger(__name__)

epmoe_use_deepgemm = get_bool_env_var("EPMOE_USE_DEEPGEMM")
_ENABLE_JIT_DEEPGEMM = True
We might import it directly.
So, do you mean we should just replace EPMOE_USE_DEEPGEMM with _ENABLE_JIT_DEEPGEMM?
Yes. Enabling _ENABLE_JIT_DEEPGEMM will make DeepGEMM the default configuration in EPMoE.
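For context on the flag being discussed, here is a minimal sketch of a module-level boolean read from the environment once at import time. `read_bool_env` is a hypothetical stand-in written for illustration; it is not SGLang's `get_bool_env_var`, and the set of accepted truthy strings is an assumption.

```python
import os


def read_bool_env(name: str, default: str = "false") -> bool:
    """Hypothetical helper (not SGLang's get_bool_env_var).

    Assumption: "1" and "true" (case-insensitive) count as enabled.
    """
    value = os.environ.get(name, default)
    return value.strip().lower() in ("1", "true")


# Evaluated once when the module is imported, like the flag in the diff above.
epmoe_use_deepgemm = read_bool_env("EPMOE_USE_DEEPGEMM")
```

Because the value is captured at import time, changing the environment variable afterwards has no effect on the running process.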
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
    if use_deep_gemm and epmoe_use_deepgemm:
Why disable EPMOE DeepGEMM when use_deep_gemm is enabled?
Maybe forward_deepgemm is called when use_deep_gemm is enabled.
Are there any cases where Triton GEMM in forward_normal outperforms DeepGEMM?
So far, I haven't found any case where Triton GEMM in forward_normal outperforms DeepGEMM, but DeepGEMM may use more GPU memory.
We could remove epmoe_use_deepgemm and the corresponding environment variable EPMOE_USE_DEEPGEMM for the sake of clarity.
OK, done
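The outcome of this thread can be sketched as a single-flag dispatch. This is an illustrative toy, not the PR's code: `EPMoESketch` is a hypothetical class, the two forward methods return placeholder strings instead of running GEMMs, and `use_deep_gemm` is modeled as an instance attribute, whereas in the diff it appears as a module-level name.

```python
class EPMoESketch:
    """Toy stand-in for the EPMoE layer; only the backend dispatch is modeled."""

    def __init__(self, use_deep_gemm: bool):
        # Assumption: one flag selects the backend once EPMOE_USE_DEEPGEMM is removed.
        self.use_deep_gemm = use_deep_gemm

    def forward(self, hidden_states, router_logits):
        if self.use_deep_gemm:
            return self.forward_deepgemm(hidden_states, router_logits)
        return self.forward_normal(hidden_states, router_logits)

    def forward_deepgemm(self, hidden_states, router_logits):
        # Placeholder for the DeepGEMM grouped-GEMM path.
        return "deepgemm"

    def forward_normal(self, hidden_states, router_logits):
        # Placeholder for the Triton GEMM path.
        return "normal"
```

With one flag there is no combination in which DeepGEMM is globally enabled but silently skipped for EPMoE, which is the ambiguity the reviewers raised.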
@xutizhou Could you please help me merge this?
Sure, I need some time to review and test.
Hi @TianQiLin666666, could you help fix the conflicts? Thanks!
def exp2_upper(num: int) -> int:
    for i in range(2, 31):
Why does num start from 2**2 = 4?
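Only the first two lines of exp2_upper are visible in the diff. One plausible reading, sketched below under that assumption, is that it rounds num up to the nearest power of two between 2**2 and 2**30. The loop body and the fallback for larger inputs are guesses, and the function is named `exp2_upper_sketch` to mark it as illustrative.

```python
def exp2_upper_sketch(num: int) -> int:
    """Return the smallest power of two in [2**2, 2**30] that is >= num."""
    for i in range(2, 31):
        value = 1 << i  # 2**i
        if num <= value:
            return value
    # Assumption: inputs above 2**30 are returned unchanged.
    return num
```

Under this reading, any input of 4 or less maps to 4, so 4 acts as a lower bound on the result.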
Hi @TianQiLin666666, thanks for the great work! Because this PR has not been updated for a while and we are very interested in this feature, I have asked @xutizhou to make some fixes and optimizations based on your work in a new PR, #6821. We will add you as a co-author in the new PR. Thank you for your understanding and help.
Motivation
For normal EPMoE (no DeepEP), integrate DeepGEMM as an option.
Modifications
Add forward_deepgemm in EPMoE. Use the environment variable EPMOE_USE_DEEPGEMM to enable it.
[...] forward_deepgemm.
Evaluation
Speed
With 2 × H20-96G × 8 for EP16, enabling EPMOE_USE_DEEPGEMM yields a 14% throughput gain.
Accuracy
MMLU test with mmlu/bench_sglang.py
Checklist