liz-badada
Copy link
Contributor

@liz-badada liz-badada commented Mar 9, 2025

Motivation

Integrate DeepEP into the SGLang framework. This is still WIP, but '--enable-dp-attention --enable-deepep-moe' can already be used to trigger DeepEP intranode / internode communication. Please follow the install guide for the NVSHMEM dependency; a Dockerfile.deepep based on the SGLang image is also provided.

Co-author @xutizhou

Note:

  • Currently --disable-cuda-graph is needed as a workaround for the error: "Exception: Capture cuda graph failed: CUDA error: capturing stream has unjoined work"

Single node:

  • command
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code \
  --tp 8 --dp 8 --host 0.0.0.0 --port 30000 \
  --enable-dp-attention --enable-deepep-moe \
  --disable-cuda-graph
  • test_openai.py
import openai
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response)
  • python3 test_openai.py
    ChatCompletion(id='c72b3edaf08f4145a53d497428c534d9', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='1. France - Capital: Paris \n2. Japan - Capital: Tokyo \n3. Brazil - Capital: Brasília', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=1)], created=1741594490, model='default', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=38, prompt_tokens=17, total_tokens=55, completion_tokens_details=None, prompt_tokens_details=None))
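
The raw ChatCompletion repr above is hard to scan. A minimal sketch of pulling out just the reply text and the token counts follows. `summarize` is a hypothetical helper name, not part of SGLang or the `openai` package, and the stand-in object only mimics the attributes visible in the output above so that the snippet runs without a server.

```python
from types import SimpleNamespace


def summarize(response):
    """Return (reply_text, prompt_tokens, completion_tokens) from a chat completion."""
    message = response.choices[0].message
    usage = response.usage
    return message.content, usage.prompt_tokens, usage.completion_tokens


# Stand-in mirroring the fields printed above (assumption: same attribute layout
# as the object returned by client.chat.completions.create).
fake_response = SimpleNamespace(
    choices=[
        SimpleNamespace(
            message=SimpleNamespace(
                content="1. France - Capital: Paris \n2. Japan - Capital: Tokyo \n3. Brazil - Capital: Brasília"
            )
        )
    ],
    usage=SimpleNamespace(prompt_tokens=17, completion_tokens=38, total_tokens=55),
)

text, prompt_tokens, completion_tokens = summarize(fake_response)
print(text)
print(f"prompt={prompt_tokens} completion={completion_tokens}")
```

In test_openai.py the same helper could be called on `response` instead of the stand-in.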

Multi node:

  • command
# node 0
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code \
  --tp 16 --dp 16  --dist-init-addr 10.6.131.5:5000 --nnodes 2 --node-rank 0 \
  --enable-dp-attention --enable-deepep-moe \
  --disable-cuda-graph
# node 1
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code \
  --tp 16 --dp 16  --dist-init-addr 10.6.131.5:5000 --nnodes 2 --node-rank 1 \
  --enable-dp-attention --enable-deepep-moe \
  --disable-cuda-graph
  • python3 test_openai.py
    ChatCompletion(id='72a8328a7ca14e98b2c10604dfbee7ee', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Sure! Here are three countries and their capitals:\n\n1. France - Paris \n2. Japan - Tokyo \n3. Brazil - Brasília', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=1)], created=1741605820, model='default', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=35, prompt_tokens=17, total_tokens=52, completion_tokens_details=None, prompt_tokens_details=None))

Performance (Current performance is below expectations as token permutation is not yet optimized. Due to some bugs in the permute triton kernel, we have temporarily fallen back to using PyTorch's native permute function):

  • command
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code --tp 8 --dp 8 --host 0.0.0.0 --port 30000 --enable-dp-attention --enable-deepep-moe --max-running-requests 128 --disable-radix-cache --mem-fraction-static 0.9 --stream-output --disable-cuda-graph

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 512 --random-input 1000 --random-output 1000 --random-range-ratio 1 --host 127.0.0.1 --port 30000 --max-concurrency 128
  • on H20
| Parallel | Concurrency | Input | Output | Num Requests | Input Throughput (tok/s) | Output Throughput (tok/s) | Total Throughput (tok/s) |
|---|---|---|---|---|---|---|---|
| DP Attn + DeepEP | 127.97 | 1000 | 1000 | 512 | 436.69 | 436.69 | 873.38 |
| DP Attn + EP | 127.97 | 1000 | 1000 | 512 | 617.90 | 617.90 | 1235.79 |
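
To put the two rows side by side, the short calculation below works out the relative throughput from the table. It also illustrates why input and output throughput coincide in this run, under the assumption that each throughput is simply total tokens divided by the benchmark's wall-clock duration: with 1000 input and 1000 output tokens per request, the two totals are equal.

```python
# Output throughput (tok/s) copied from the table above.
deepep_tok_s = 436.69
ep_tok_s = 617.90

ratio = deepep_tok_s / ep_tok_s
print(f"DeepEP / EP throughput ratio: {ratio:.3f}")
print(f"DeepEP throughput is {(1 - ratio) * 100:.1f}% lower in this run")

# Benchmark settings from the bench_serving command above.
num_requests = 512
input_len = 1000
output_len = 1000

# Assumption: throughput = total tokens / wall-clock duration, so equal input
# and output lengths give identical input and output throughput.
duration_s = num_requests * output_len / deepep_tok_s
input_tok_s = num_requests * input_len / duration_s
print(f"implied duration: {duration_s:.0f} s, input throughput: {input_tok_s:.2f} tok/s")
```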

Modifications

Checklist

@zhyncs
Copy link
Member

zhyncs commented Mar 9, 2025

@ch-wan

@liz-badada liz-badada force-pushed the Integrate_DeepEP_into_SGLang branch from aba1c69 to 3223a15 Compare March 10, 2025 03:36
@zhyncs
Copy link
Member

zhyncs commented Mar 10, 2025

@liz-badada Could you rebase onto the latest main?

@liz-badada
Copy link
Contributor Author

done

@MiterV1
Copy link
Contributor

MiterV1 commented Mar 11, 2025

Hi liz-badada,

I want to reproduce the code locally.
Please help me list the CUDA version, GPU driver version, and other system information you are using.

Thank you.

@liz-badada
Copy link
Contributor Author

liz-badada commented Mar 11, 2025

Hi, please make sure GDRCopy is already set up on the host. Here is some system information:

  • CUDA Version: 12.4
  • Driver Version: 550.127.08
  • GPU: H20-3e

@liz-badada liz-badada changed the title Integrate deepEP into SGLang Integrate DeepEP into SGLang Mar 12, 2025
@DeepTecher
Copy link

DeepTecher commented Mar 13, 2025

Great job~ Are there any throughput/latency benchmark results?

@Xiaofei-fei
Copy link

Thank you and the amazing open-source community! However, while working with your repository, I encountered the following CUDA out-of-memory error. I’ve tried multiple approaches without success and would appreciate any suggestions you might have.

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.73 GiB. GPU 0 has a total capacity of 95.10 GiB of which 612.19 MiB is free. Process 777271 has 94.49 GiB memory in use. Of the allocated memory 91.27 GiB is allocated by PyTorch, and 8.74 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

@MiterV1
Copy link
Contributor

MiterV1 commented Mar 13, 2025

Hi liz-badada,

Please help me to check some system information.
Pls run these commands:

  1. cat /etc/lsb-release
  2. uname -a
  3. ofed_info
  4. dpkg -l | grep NVIDIA

thanks a lot.

@MiterV1
Copy link
Contributor

MiterV1 commented Mar 13, 2025

Hi Xiaofei-fei,

Try setting --mem-fraction-static 0.9, which expands the memory available to the KV cache memory pool and helps both prefill and decoding.

@Xiaofei-fei
Copy link

Thank you for your reply. I’m using an H20 GPU with 96GB of VRAM. Below are the solutions I’ve tried:

--mem-fraction-static 0.9
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"
torch.cuda.empty_cache()

However, the problem remains unresolved.

@liz-badada
Copy link
Contributor Author

Please check here: sys_info.zip

@liz-badada
Copy link
Contributor Author

For 1 node, could you try a smaller value such as '--mem-fraction-static 0.8'? Also, please use V3 instead of R1.
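
As a rough illustration of why a smaller --mem-fraction-static can help here, the arithmetic below uses the 95.10 GiB capacity from the error message. It assumes the flag sets the fraction of GPU memory reserved for static allocations (model weights plus the KV cache pool), with the remainder left for activations and communication buffers; the exact accounting inside SGLang may differ.

```python
total_gib = 95.10  # GPU capacity reported in the OOM message

# Memory left outside the static pool for each candidate fraction.
headroom = {}
for fraction in (0.9, 0.8):
    static_gib = total_gib * fraction
    headroom[fraction] = total_gib - static_gib
    print(
        f"--mem-fraction-static {fraction}: "
        f"static={static_gib:.2f} GiB, headroom={headroom[fraction]:.2f} GiB"
    )
```

Under this assumption, going from 0.9 to 0.8 roughly doubles the memory left for non-static allocations, at the cost of a smaller KV cache pool.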

@Sun1Plus
Copy link

Sun1Plus commented Mar 19, 2025

Has anyone encountered the following error? When I ran with "tp=16 dp=16", I got this:

[2025-03-19 19:58:45 TP13] TpModelWorkerClient hit an exception: Traceback (most recent call last):
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 112, in forward_thread_func
    self.forward_thread_func_()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 143, in forward_thread_func_
    logits_output, next_token_ids = self.worker.forward_batch_generation(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/managers/tp_worker.py", line 172, in forward_batch_generation
    logits_output = self.model_runner.forward(forward_batch)
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/model_executor/model_runner.py", line 976, in forward
    return self.forward_extend(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/model_executor/model_runner.py", line 937, in forward_extend
    return self.model.forward(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/models/deepseek_v2.py", line 1206, in forward
    hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/models/deepseek_v2.py", line 1166, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/models/deepseek_v2.py", line 1101, in forward
    hidden_states = self.mlp(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/models/deepseek_v2.py", line 126, in forward
    x, _ = self.down_proj(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/layers/linear.py", line 1277, in forward
    output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/layers/linear.py", line 172, in apply
    return F.linear(x, layer.weight, bias)
RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16BF, lda, b, CUDA_R_16BF, ldb, &fbeta, c, CUDA_R_16BF, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`

some frame error info:

[rank2]:[E319 19:58:45.837235025 ProcessGroupNCCL.cpp:1595] [PG ID 2 PG GUID 3 Rank 2] Process group watchdog thread terminated with exception: CUDA error: misaligned address
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f2a29f6c446 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f2a29f166e4 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f2a2a3d6a18 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f29dfe25726 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f29dfe2a3f0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f29dfe31b5a in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f29dfe3361d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x145c0 (0x7f2a2a4515c0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so)
frame #8: <unknown function> + 0x94ac3 (0x7f2acd9c2ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126850 (0x7f2acda54850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG ID 2 PG GUID 3 Rank 2] Process group watchdog thread terminated with exception: CUDA error: misaligned address
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

my exec command:

GLOO_SOCKET_IFNAME=bond0 TP_SOCKET_IFNAME=bond0 NCCL_SOCKET_IFNAME=bond0 TORCH_USE_CUDA_DSA=1 CUDA_LAUNCH_BLOCKING=1 python3 -m sglang.launch_server --model-path /DATA/disk1/sunyijia/models/DeepSeek-V2-Lite-Chat --trust-remote-code --tp 16 --dp 16 --dist-init-addr 10.42.1.126:5000 --nnodes 2 --node-rank 1 --mem-fraction-static 0.9 --disable-cuda-graph --enable-deepep-moe --enable-dp-attention 

I have tried running without "--enable-dp-attention" or "--enable-deepep-moe", and the error persists.

I have also checked the suggestion in issue 1479, but I installed sglang from source, so it may not apply to this situation.

I have encountered this error before. It disappeared when I merged the main branch, so I didn't dig into it, but this time it is blocking me. 😭

Has anyone encountered the same error?

@Huixxi
Copy link

Huixxi commented Mar 19, 2025

Are other --tp/--dp combinations supported so far? For example, --dp 2 --tp 8, which is actually DP2 with group TP4, would need a group all_reduce before the dispatch operation, right?

@Sun1Plus
Copy link

I tried several tp and dp combinations, with the following results:

  1. with --enable-dp-attention, without --enable-deepep-moe

| | tp=8 | tp=16 |
|---|---|---|
| dp=2 | ok | Error |
| dp=4 | Error | Error |

Every "Error" is:

RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16BF, lda, b, CUDA_R_16BF, ldb, &fbeta, c, CUDA_R_16BF, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`

  2. with both --enable-dp-attention and --enable-deepep-moe

| | tp=8 | tp=16 |
|---|---|---|
| dp=2 | otherError | Error |
| dp=4 | Error | Error |

"otherError" is:

  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/managers/scheduler.py", line 1809, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/managers/scheduler.py", line 227, in __init__
    self.tp_worker = TpWorkerClass(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in __init__
    self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/managers/tp_worker.py", line 74, in __init__
    self.model_runner = ModelRunner(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/model_executor/model_runner.py", line 168, in __init__
    self.initialize(min_per_gpu_memory)
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/model_executor/model_runner.py", line 178, in initialize
    self.load_model()
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/model_executor/model_runner.py", line 383, in load_model
    self.model = get_model(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
    return loader.load_model(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/model_loader/loader.py", line 366, in load_model
    model = _initialize_model(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/model_loader/loader.py", line 147, in _initialize_model
    return model_class(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/models/deepseek_v2.py", line 1185, in __init__
    self.model = DeepseekV2Model(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/models/deepseek_v2.py", line 1136, in __init__
    [
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/models/deepseek_v2.py", line 1137, in <listcomp>
    DeepseekV2DecoderLayer(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/models/deepseek_v2.py", line 1026, in __init__
    self.mlp = DeepseekV2MoE(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/models/deepseek_v2.py", line 234, in __init__
    self.deepep_dispatcher = DeepEPDispatcher(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py", line 215, in __init__
    self.buffer_normal = get_buffer_normal(
  File "/DATA/disk1/sunyijia/sglang/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py", line 52, in get_buffer_normal
    _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
  File "/usr/local/lib/python3.10/dist-packages/deep_ep-1.0.0+c4b8ffc-py3.10-linux-x86_64.egg/deep_ep/buffer.py", line 89, in __init__
    self.runtime.sync(device_ids, ipc_handles, root_unique_id)
RuntimeError: Failed: CUDA error /DATA/disk1/sunyijia/DeepEP/csrc/deep_ep.cpp:178 'invalid resource handle'

I can't figure out why this happens.

@zhyncs zhyncs merged commit f44db16 into sgl-project:main Mar 19, 2025
32 of 36 checks passed
@ch-wan
Copy link
Collaborator

ch-wan commented Mar 19, 2025

@Huixxi We recommend dp=tp at the moment. We are going to implement reduce_scatter to adapt DeepEP to a broader range of scenarios.
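
The relationship discussed above can be sketched as a small check. It assumes that with --enable-dp-attention each DP rank's attention runs over `tp // dp` GPUs, so that --tp 8 --dp 2 gives two groups of TP4 as described in the question. `attention_tp_size` and `is_recommended_for_deepep` are hypothetical helper names, not SGLang APIs.

```python
def attention_tp_size(tp: int, dp: int) -> int:
    """Size of each attention TP group, assuming tp GPUs are split evenly across dp ranks."""
    if tp % dp != 0:
        raise ValueError(f"tp={tp} must be divisible by dp={dp}")
    return tp // dp


def is_recommended_for_deepep(tp: int, dp: int) -> bool:
    """True when the config matches the dp=tp recommendation given in this thread."""
    return tp == dp


for tp, dp in [(8, 8), (16, 16), (8, 2)]:
    print(
        f"tp={tp} dp={dp}: attention TP group size={attention_tp_size(tp, dp)}, "
        f"recommended={is_recommended_for_deepep(tp, dp)}"
    )
```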

@Kiokana
Copy link

Kiokana commented Mar 20, 2025

@liz-badada I'm confused about the performance table. It shows that the throughput of DP Attn + DeepEP is lower than that of the original DP+EP, and that Input Throughput (tok/s) == Output Throughput (tok/s).
Does it mean DeepEP is slower?
Also, have you tried a two-node deployment and evaluated internode performance?

@liz-badada
Copy link
Contributor Author

No. As I commented above, the token permutation mechanism has not yet been optimized. We ran into issues with the permute Triton kernel, which forced a temporary fallback to PyTorch's native permute function. Additionally, the low-latency dispatch for decoding is still disabled. Significant optimization work is still needed to reach our target performance.

@Huixxi
Copy link

Huixxi commented Mar 21, 2025

We tested enable_deepep_moe vs. enable_ep_moe on a single node with 8*H800 under the same conditions (identical input length, output length, max concurrency, etc.), using a smaller model. During the prefill stage, performance is nearly the same. However, in the decode stage, deepep performs significantly worse. Profiling results show bubbles around deepep's dispatch and combine operations in both the prefill and decode stages. What causes this? It doesn't seem to be related to the permute and unpermute operations.
[Figure 1: deepep_decode]
[Figure 2: deepep_prefill]
[Figure 3: ep_decode] (Also, why do ep decode's compute and communication run sequentially in one stream?)
[Figure 4: ep_prefill]

@ch-wan
Copy link
Collaborator

ch-wan commented Mar 21, 2025

@Huixxi Thank you for profiling our integration. One recent PR #4643 is trying to accelerate permute and unpermute. We welcome further evaluation and profiling. Also, we plan to fuse the permute kernel with GroupedGeMM in the future to avoid D2H sync.

@Huixxi
Copy link

Huixxi commented Mar 22, 2025

Yes, and FYI: according to the profiling file, DeepEP has far too many HtoD and DtoD operations, 10x to 100x more than ep moe, which has fewer than 100 HtoD + DtoD ops.

@xwz-ol
Copy link

xwz-ol commented Mar 23, 2025

It seems to be stuck before loading weights on 2*8 H20

node 0

GLOO_SOCKET_IFNAME=eth1 TP_SOCKET_IFNAME=eth1 NCCL_SOCKET_IFNAME=eth1 TORCH_USE_CUDA_DSA=1 CUDA_LAUNCH_BLOCKING=1 python3 -m sglang.launch_server --model-path model/DeepSeek-V3 --trust-remote-code --tp 16 --dp 16 --dist-init-addr ip:port --nnodes 2 --node-rank 0 --enable-dp-attention --enable-deepep-moe --disable-cuda-graph
INFO 03-23 06:23:31 __init__.py:190] Automatically detected platform cuda.

node 1

GLOO_SOCKET_IFNAME=eth1 TP_SOCKET_IFNAME=eth1 NCCL_SOCKET_IFNAME=eth1 TORCH_USE_CUDA_DSA=1 CUDA_LAUNCH_BLOCKING=1 python3 -m sglang.launch_server --model-path model/DeepSeek-V3 --trust-remote-code --tp 16 --dp 16 --dist-init-addr ip:port --nnodes 2 --node-rank 1 --enable-dp-attention --enable-deepep-moe --disable-cuda-graph
INFO 03-23 06:23:40 __init__.py:190] Automatically detected platform cuda.

@xle97
Copy link

xle97 commented Mar 23, 2025

@liz-badada When do you plan to integrate the low_latency mode?

@liz-badada
Copy link
Contributor Author

low latency is WIP.

@Huixxi
Copy link

Huixxi commented Mar 24, 2025

Update: after PR #4643 (Optimize Permute Kernel in DeepEP), the number of HtoD and DtoD operations has decreased to the same order of magnitude as ep moe, just slightly more.

@CUHKSZzxy
Copy link
Contributor

May I ask how to deploy on 4 nodes with TP/DP/EP? Should we use the following commands with --tp 32 --dp 32? Is this a reasonable config for the best inference efficiency?

# node 0
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code \
  --tp 32 --dp 32  --dist-init-addr xxx:xxx --nnodes 4 --node-rank 0 \
  --enable-dp-attention --enable-deepep-moe \
  --disable-cuda-graph

# node 1
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code \
  --tp 32 --dp 32  --dist-init-addr xxx:xxx --nnodes 4 --node-rank 1 \
  --enable-dp-attention --enable-deepep-moe \
  --disable-cuda-graph

# node 2
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code \
  --tp 32 --dp 32  --dist-init-addr xxx:xxx --nnodes 4 --node-rank 2 \
  --enable-dp-attention --enable-deepep-moe \
  --disable-cuda-graph

# node 3
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code \
  --tp 32 --dp 32  --dist-init-addr xxx:xxx --nnodes 4 --node-rank 3 \
  --enable-dp-attention --enable-deepep-moe \
  --disable-cuda-graph
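
Since the four commands differ only in --node-rank, they can be generated instead of copied by hand. The sketch below only reproduces the flags already shown in this thread; `launch_commands` is a hypothetical helper, `gpus_per_node=8` is an assumption, and nothing here implies that this configuration is supported or optimal.

```python
def launch_commands(nnodes, gpus_per_node, dist_init_addr,
                    model_path="deepseek-ai/DeepSeek-V3"):
    """Build one sglang.launch_server command per node, with tp == dp == total GPUs."""
    world_size = nnodes * gpus_per_node
    commands = []
    for node_rank in range(nnodes):
        parts = [
            "python3 -m sglang.launch_server",
            f"--model-path {model_path} --trust-remote-code",
            f"--tp {world_size} --dp {world_size}",
            f"--dist-init-addr {dist_init_addr} --nnodes {nnodes} --node-rank {node_rank}",
            "--enable-dp-attention --enable-deepep-moe",
            "--disable-cuda-graph",
        ]
        commands.append(" ".join(parts))
    return commands


# "xxx:xxx" is the placeholder address used in the question above.
cmds = launch_commands(nnodes=4, gpus_per_node=8, dist_init_addr="xxx:xxx")
for rank, cmd in enumerate(cmds):
    print(f"# node {rank}\n{cmd}\n")
```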

@liz-badada
Copy link
Contributor Author

Hi, please check this: #4836

@CUHKSZzxy
Copy link
Contributor

CUHKSZzxy commented Apr 10, 2025

Thanks for your reply!
I have checked the PR you mentioned. It says that "4x8xH100 cannot run without this PR." Does that mean I need to wait for the merge of #4836 in order to deploy with 4x8 Hopper GPU?

@CSEEduanyu
Copy link

Is the data in the table incorrect?

@liz-badada
Copy link
Contributor Author

Hi @CSEEduanyu, these perf data are quite out of date; they are from several weeks ago. I suggest benchmarking on the main branch.

@CSEEduanyu
Copy link

What I want to ask is why the throughput of DeepEP is worse than EP? Could it be that the data in the table is reversed for the two?

@nannaer
Copy link

nannaer commented May 13, 2025

Thank you very much for your contribution! However, I have two questions regarding the implementation of DP+EP:

(1) When deploying with 16DP+16EP, why do we have to set both TP Size and DP Size to 16?
(2) During the 16DP+16EP deployment, I noticed that after completing the MLA computation, the functions dp_gather_partial and dp_scatter are called before invoking the expert modules. What are the roles of dp_gather_partial and dp_scatter here? Could you please explain this in detail?
