Conversation

@xutizhou (Collaborator) commented Mar 21, 2025

Motivation

The current performance of DeepEP is suboptimal due to the low efficiency of PyTorch's native permute function, which is used to format data before and after DeepEP communication. To address this limitation, we have implemented high-efficiency Triton kernels that significantly improve overall performance.

Co-authored-by: @zhou9402
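For context, the "formatting" step this PR accelerates can be sketched in plain PyTorch: tokens are reordered into expert-major order so each expert sees a contiguous slice. The sketch below is illustrative only (the names and shapes are assumptions, not the PR's exact API); the Triton kernels fuse this sequence of eager ops into single GPU kernels.

```python
import torch

def permute_by_expert(hidden_states: torch.Tensor, topk_ids: torch.Tensor):
    """Illustrative eager-mode permute: group token copies by expert id.

    hidden_states: (num_tokens, hidden)
    topk_ids:      (num_tokens, topk) expert assignment per token
    """
    num_tokens, topk = topk_ids.shape
    # Stable sort keeps token order within each expert's segment.
    sorted_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    # Flat slot i came from token reorder_ids[i] // topk.
    src_token = reorder_ids // topk
    # Each token row is duplicated once per routed expert.
    permuted = hidden_states[src_token]
    return permuted, sorted_ids, reorder_ids
```

Each eager op here (sort, div, gather) is a separate kernel launch with intermediate tensors, which is the overhead the fused Triton version avoids.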

Performance on H20

Single Node

Command

python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code --tp 8 --dp 8 --host 0.0.0.0 --port 30000 --enable-dp-attention --enable-deepep-moe --max-running-requests 128 --disable-radix-cache --mem-fraction-static 0.9 --stream-output --disable-cuda-graph

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 512 --random-input 1000 --random-output 1000 --random-range-ratio 1 --host 127.0.0.1 --port 30000 --max-concurrency 128
| Version | Concurrency | Input | Output | Num Requests | Input Throughput (tok/s) | Output Throughput (tok/s) | Total Throughput (tok/s) |
|---|---|---|---|---|---|---|---|
| DeepEP (original) | 127.97 | 1000 | 1000 | 512 | 436.69 | 436.69 | 873.38 |
| DeepEP (current) | 127.97 | 1000 | 1000 | 512 | 581.94 | 581.94 | 1163.87 |

Multi Node

Command

# node 0
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code \
  --tp 16 --dp 16  --dist-init-addr 10.6.131.5:5000 --nnodes 2 --node-rank 0 \
  --enable-dp-attention --enable-deepep-moe \
  --disable-cuda-graph
# node 1
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code \
  --tp 16 --dp 16  --dist-init-addr 10.6.131.5:5000 --nnodes 2 --node-rank 1 \
  --enable-dp-attention --enable-deepep-moe \
  --disable-cuda-graph
| Version | Concurrency | Input | Output | Num Requests | Input Throughput (tok/s) | Output Throughput (tok/s) | Total Throughput (tok/s) |
|---|---|---|---|---|---|---|---|
| DeepEP (current) | 255.93 | 1000 | 1000 | 512 | 956.36 | 956.36 | 1912.71 |
| DeepEP (current) | 511.31 | 1000 | 1000 | 1024 | 1711.54 | 1711.54 | 3423.09 |
| DeepEP (current) | 1023.17 | 1000 | 1000 | 2048 | 2974.21 | 2974.21 | 5948.42 |
| DeepEP (current) | 2046.18 | 1000 | 1000 | 4096 | 3929.73 | 3929.73 | 7859.46 |
| EPMoE | 255.55 | 1000 | 1000 | 512 | 868.55 | 868.55 | 1737.10 |
| EPMoE | 511.85 | 1000 | 1000 | 1024 | 1694.59 | 1694.59 | 3389.18 |
| EPMoE | 1022.27 | 1000 | 1000 | 2048 | 2735.53 | 2735.53 | 5471.06 |
| EPMoE | 2045.90 | 1000 | 1000 | 4096 | 3489.57 | 3489.57 | 6979.15 |

Modifications

Checklist

@xutizhou xutizhou marked this pull request as ready for review March 21, 2025 05:01

def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
Collaborator:

It can be initialized using torch.empty.
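The suggestion rests on a standard PyTorch optimization: torch.zeros launches an extra fill kernel, while torch.empty only allocates, so empty is cheaper whenever every element is overwritten before it is read. A minimal sketch (the counts and sizes are made up for illustration):

```python
import torch

num_experts = 4
# torch.zeros = allocation + a memset-style fill kernel.
seg_indptr_zeros = torch.zeros(num_experts + 1, dtype=torch.int64)
# torch.empty = allocation only; contents are arbitrary until written.
seg_indptr_empty = torch.empty(num_experts + 1, dtype=torch.int64)

# If the subsequent kernel writes every slot anyway, the zero-fill was wasted.
# Here we simulate that with a prefix sum over per-expert token counts.
counts = torch.tensor([2, 0, 3, 1])
seg_indptr_empty[0] = 0
seg_indptr_empty[1:] = torch.cumsum(counts, dim=0)
```

The trade-off: empty is only safe when full initialization by later code is guaranteed, which is presumably why the reviewer phrases it as a question.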

deepep_compute_src2dst_triton_kernel[grid](
    reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
)
# src2dst -= num_minus_one
Collaborator:

Is this commented-out line leftover debugging code?

@@ -17,6 +17,116 @@
logger = logging.getLogger(__name__)


@triton.jit
def compute_src2dst_triton_kernel(
Collaborator:

compute_src2dst_triton_kernel and deepep_compute_src2dst_triton_kernel are each defined twice.



@triton.jit
def deepep_compute_src2dst_triton_kernel(
Collaborator:

Why is developing a Triton kernel necessary here? Is it faster than the PyTorch equivalent?
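To make the question concrete, here is a plain-PyTorch reference for what a src2dst mapping computes (an assumed reading of the kernel's semantics, not the PR's code): reorder_ids[i] is the source slot that lands at destination slot i, so src2dst is its inverse permutation. A Triton kernel can compute this in one launch, whereas the eager version below needs an arange plus a scatter.

```python
import torch

def compute_src2dst(reorder_ids: torch.Tensor) -> torch.Tensor:
    """Inverse permutation: src2dst[reorder_ids[i]] == i."""
    src2dst = torch.empty_like(reorder_ids)
    src2dst[reorder_ids] = torch.arange(
        reorder_ids.numel(), dtype=reorder_ids.dtype, device=reorder_ids.device
    )
    return src2dst
```

Whether the Triton version is actually faster is an empirical question (kernel-launch overhead vs. the cost of the extra intermediate tensor), which is presumably what the reviewer is asking the authors to justify with numbers.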



@triton.jit
def deepep_permute_triton_kernel(
Collaborator:

This kernel is defined twice.



@triton.jit
def deepep_post_reorder_triton_kernel(
Collaborator:

This kernel is defined twice.

Comment on lines 351 to 355
output = torch.zeros(
    (num_tokens, hidden_states.shape[1]),
    device=hidden_states.device,
    dtype=hidden_states.dtype,
)
Contributor:

Could this use torch.empty instead?

@@ -294,7 +294,7 @@ def forward_deepep(
correction_bias=self.correction_bias,
)
if self.tp_size > 1:
-            recv_hidden_states, topk_idx, topk_weights, tokens_per_expert = (
+            recv_hidden_states, reorder_topk_ids, seg_indptr = (
Contributor:

Should we add short comments on the meaning of reorder_topk_ids and seg_indptr (with examples) for readability?
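A tiny worked example of the kind of comment the reviewer is asking for (the semantics below are inferred from the preprocess code shown earlier, so treat them as an assumption): reorder_topk_ids is the flattened topk_ids sorted ascending, and seg_indptr[e]..seg_indptr[e+1] delimits expert e's contiguous segment in the permuted token buffer.

```python
import torch

# 4 tokens, top-2 routing over 2 experts.
topk_ids = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]])
num_experts = 2

# Sorted flat expert ids: all of expert 0's slots, then expert 1's.
reorder_topk_ids, _ = torch.sort(topk_ids.view(-1), stable=True)

# seg_indptr is a CSR-style prefix sum of per-expert slot counts, so
# expert e owns permuted rows seg_indptr[e]:seg_indptr[e + 1].
counts = torch.bincount(reorder_topk_ids, minlength=num_experts)
seg_indptr = torch.zeros(num_experts + 1, dtype=torch.int64)
seg_indptr[1:] = torch.cumsum(counts, dim=0)
```

Here each expert receives 4 token copies, so expert 0 owns rows 0:4 and expert 1 owns rows 4:8 of the permuted buffer.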

@zhyncs zhyncs merged commit c2bd094 into sgl-project:main Mar 22, 2025
0 of 16 checks passed
@xutizhou xutizhou deleted the optimize_permute_kernel branch March 23, 2025 03:21
@xutizhou xutizhou restored the optimize_permute_kernel branch March 23, 2025 04:43
@Huixxi commented Mar 24, 2025

Will there be further optimization plans for this permute kernel?

@xutizhou (Collaborator, Author) replied:

> Will there be further optimization plans for this permute kernel?

We will continue to optimize the permute kernel, but it is not our top priority at the moment.

@ch-wan ch-wan mentioned this pull request Mar 24, 2025
18 tasks
@Huixxi commented Mar 26, 2025

# node 0
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code \
  --tp 16 --dp 16 --dist-init-addr 10.6.131.5:5000 --nnodes 2 --node-rank 0 \
  --enable-dp-attention --enable-deepep-moe \
  --disable-cuda-graph
# node 1
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code \
  --tp 16 --dp 16 --dist-init-addr 10.6.131.5:5000 --nnodes 2 --node-rank 1 \
  --enable-dp-attention --enable-deepep-moe \
  --disable-cuda-graph

But it seems I can't reproduce the reported DeepSeek performance on 2 nodes of 8x H800 with RoCE RDMA. I don't know why.
[screenshot: benchmark results]

@xutizhou (Collaborator, Author) replied:

> But it seems I can't reproduce the reported DeepSeek performance on 2 nodes of 8x H800 with RoCE RDMA. I don't know why.

The observed issue could be attributable to the RoCE network configuration. To verify this hypothesis, we recommend running the inter-node communication test from DeepEP's validation suite, specifically the internode connectivity check.
