Optimize Permute Kernel in DeepEP #4643
Conversation
```python
def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
```
It can be initialized using `torch.empty`.
```python
    deepep_compute_src2dst_triton_kernel[grid](
        reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
    )
    # src2dst -= num_minus_one
```
Is this commented-out line leftover debugging code?
```diff
@@ -17,6 +17,116 @@
 logger = logging.getLogger(__name__)


+@triton.jit
+def compute_src2dst_triton_kernel(
```
`compute_src2dst_triton_kernel` and `deepep_compute_src2dst_triton_kernel` are defined twice.
```python
@triton.jit
def deepep_compute_src2dst_triton_kernel(
```
Why is developing a Triton kernel necessary here? Is it faster?
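For reference, a plain-PyTorch sketch of what a src-to-dst mapping computes (illustrative names, not the PR's code): for each flattened (token, expert) slot, the destination row after a stable sort by expert id. A fused Triton kernel can outperform this eager version by avoiding intermediate tensors and extra kernel launches, which is presumably the motivation.

```python
import torch

topk_ids = torch.tensor([[2, 0], [1, 2], [0, 3]])  # 3 tokens, top-2 experts

# reorder_ids[d] = source slot that lands at destination row d after sorting.
_, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)

# src2dst is the inverse permutation: destination row of each source slot.
src2dst = torch.empty_like(reorder_ids)
src2dst[reorder_ids] = torch.arange(reorder_ids.numel())
print(src2dst.tolist())  # → [3, 0, 2, 4, 1, 5]
```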
```python
@triton.jit
def deepep_permute_triton_kernel(
```
It is defined twice.
```python
@triton.jit
def deepep_post_reorder_triton_kernel(
```
It is defined twice.
```python
    output = torch.zeros(
        (num_tokens, hidden_states.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
```
Use torch.empty?
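A hedged sketch of why zero-initialization may be load-bearing here: if the post-reorder kernel skips tokens whose top-k experts are all invalid (id -1), those output rows are never written, and `torch.empty` would leave garbage in them. The shapes and masking below are illustrative, not the PR's kernel.

```python
import torch

hidden = torch.arange(8.0).reshape(4, 2)  # 4 tokens, hidden size 2
topk_ids = torch.tensor([0, -1, 1, -1])   # tokens 1 and 3 hit no valid expert

# zeros (not empty): rows for fully-invalid tokens are never written below,
# so they must come out as zero rather than uninitialized memory.
out = torch.zeros_like(hidden)
valid = topk_ids >= 0
out[valid] = hidden[valid]
print(out.tolist())  # → [[0.0, 1.0], [0.0, 0.0], [4.0, 5.0], [0.0, 0.0]]
```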
```diff
@@ -294,7 +294,7 @@ def forward_deepep(
         correction_bias=self.correction_bias,
     )
     if self.tp_size > 1:
-        recv_hidden_states, topk_idx, topk_weights, tokens_per_expert = (
+        recv_hidden_states, reorder_topk_ids, seg_indptr = (
```
Should we add some short comments on the meaning of `reorder_topk_ids` and `seg_indptr`, with examples, for readability?
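A worked example of the kind of comment the reviewer is asking for (the tensors and values here are illustrative, not taken from the PR): `reorder_topk_ids` is the flattened expert assignment sorted so that each expert's tokens are contiguous, and `seg_indptr[e]:seg_indptr[e+1]` then slices expert `e`'s segment of the permuted token buffer.

```python
import torch

topk_ids = torch.tensor([[1, 0], [0, 1]])  # 2 tokens, top-2 experts
num_experts = 2

reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
print(reorder_topk_ids.tolist())  # → [0, 0, 1, 1], grouped by expert

# seg_indptr[e] = first row of expert e's segment in the sorted order.
seg_indptr = torch.searchsorted(reorder_topk_ids, torch.arange(num_experts + 1))
print(seg_indptr.tolist())  # → [0, 2, 4]: expert 0 owns rows 0..2, expert 1 rows 2..4
```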
Are there further optimization plans for this permute kernel?
We will continue to optimize the permute kernel, but it is not our top priority at the moment.
The observed issue could potentially be attributed to the RoCE network configuration. To verify this hypothesis, we recommend running the inter-node communication test from DeepEP's validation suite, specifically the internode connectivity check.
Motivation
The current performance of DeepEP is suboptimal due to the low efficiency of PyTorch's native permute function, which is used for formatting data before and after DeepEP communication. To address this limitation, we have implemented high-efficiency Triton kernels that significantly improve overall performance.
Performance on H20
Single Node
Command
Multi Node
Command
Modifications
Checklist