Checklist
- 1. I have searched related issues but cannot get the expected help.
- 2. The bug has not been fixed in the latest version.
- 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- 4. If the issue you raised is not a bug but a question, please open a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
- 5. Please use English, otherwise it will be closed.
Describe the bug
I hit an error when running DeepSeek R1 on the latest code. Both the TP+EP and TP+EP+DP configurations crash. The steps to reproduce are below.
Environment setup
# Install sglang; see https://docs.sglang.ai/start/install.html
# Use the latest release branch
git clone -b v0.4.6.post5 https://github.com/sgl-project/sglang.git
cd sglang
pip install --upgrade pip
cd sgl-kernel
python setup_rocm.py install
cd ..
pip install -e "python[all_hip]"
# Install aiter; see https://github.com/ROCm/aiter
git clone --recursive https://github.com/ROCm/aiter.git
cd aiter
python3 setup.py develop
Reproduction
- Without --dp-size 8 --enable-dp-attention --moe-dense-tp-size 1
python3 -m sglang.launch_server --trust-remote-code --chunked-prefill-size 131072 --disable-cuda-graph --disable-custom-all-reduce --tp-size 8 --model /apps/data/models/DSR1 --ep-size 8 --enable-ep-moe
- Error log
[2025-06-10 13:50:12 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-10 13:50:14 TP3] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=2112,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
[2025-06-10 13:50:14 TP1] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=2112,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
[2025-06-10 13:50:14 TP0] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=2112,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
[2025-06-10 13:50:14 TP2] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=2112,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
[2025-06-10 13:50:14 TP4] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=2112,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
[2025-06-10 13:50:14 TP6] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=2112,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
[2025-06-10 13:50:14 TP5] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=2112,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
[2025-06-10 13:50:14 TP7] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=2112,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json
[2025-06-10 13:50:15 TP0] Using configuration from /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json for W8A8 Block FP8 kernel.
[2025-06-10 13:50:15 TP0] Using configuration from /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json for W8A8 Block FP8 kernel.
[2025-06-10 13:50:17 TP0] Using configuration from /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json for W8A8 Block FP8 kernel.
[2025-06-10 13:50:17 TP0] Using configuration from /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json for W8A8 Block FP8 kernel.
[2025-06-10 13:50:17 TP0] Using configuration from /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json for W8A8 Block FP8 kernel.
[2025-06-10 13:50:20 TP0] Using configuration from /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json for W8A8 Block FP8 kernel.
[2025-06-10 13:50:20 TP0] Using configuration from /apps/fangyuan/sglang/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json for W8A8 Block FP8 kernel.
/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/kernels.py:612:42: error: Unsupported conversion from 'f8E4M3FN' to 'f16'
accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
^
/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/kernels.py:612:42: error: failed to legalize operation 'tt.fp_to_fp' that was explicitly marked illegal
accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
^
/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/kernels.py:612:42: note: see current operation: %14339 = "tt.fp_to_fp"(%8960) : (tensor<128x32xf8E4M3FN, #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>>) -> tensor<128x32xf32, #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>>
/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/kernels.py:612:42: error: Unsupported conversion from 'f8E4M3FN' to 'f16'
accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
^
[2025-06-10 13:50:23 TP0] TpModelWorkerClient hit an exception: Traceback (most recent call last):
File "/apps/fangyuan/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 118, in forward_thread_func
self.forward_thread_func_()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 151, in forward_thread_func_
self.worker.forward_batch_generation(
File "/apps/fangyuan/sglang/python/sglang/srt/managers/tp_worker.py", line 202, in forward_batch_generation
logits_output, can_run_cuda_graph = self.model_runner.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/model_executor/model_runner.py", line 1172, in forward
output = self._forward_raw(
^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/model_executor/model_runner.py", line 1201, in _forward_raw
ret = self.forward_extend(
File "/apps/fangyuan/sglang/python/sglang/srt/model_executor/model_runner.py", line 1140, in forward_extend
return self.model.forward(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 1475, in forward
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 1390, in forward
hidden_states, residual = layer(
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 1255, in forward
return execute_operations(
^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/operations.py", line 18, in execute_operations
executor.next()
File "/apps/fangyuan/sglang/python/sglang/srt/operations.py", line 53, in next
self._stage_output = op.fn(
^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 413, in op_experts
state.hidden_states_experts_output = self.experts(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/layer.py", line 292, in forward
gateup_output = self.grouped_gemm_runner(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/layer.py", line 114, in forward
c = grouped_gemm_triton(
^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/kernels.py", line 696, in grouped_gemm_triton
grouped_gemm_triton_kernel[grid](
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 330, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 623, in run
kernel = self.compile(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 286, in compile
next_module = compile_ir(module, metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/backends/amd/compiler.py", line 382, in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/backends/amd/compiler.py", line 303, in make_llir
pm.run(mod)
RuntimeError: PassManager::run failed
[2025-06-10 13:50:23] Received sigquit from a child process. It usually means the child failed.
/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/kernels.py:612:42: error: failed to legalize operation 'tt.fp_to_fp' that was explicitly marked illegal
accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
^
/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/kernels.py:612:42: note: see current operation: %14339 = "tt.fp_to_fp"(%8960) : (tensor<128x32xf8E4M3FN, #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>>) -> tensor<128x32xf32, #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>>
[2025-06-10 13:50:23 TP3] TpModelWorkerClient hit an exception: Traceback (most recent call last):
File "/apps/fangyuan/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 118, in forward_thread_func
self.forward_thread_func_()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 151, in forward_thread_func_
self.worker.forward_batch_generation(
File "/apps/fangyuan/sglang/python/sglang/srt/managers/tp_worker.py", line 202, in forward_batch_generation
logits_output, can_run_cuda_graph = self.model_runner.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/model_executor/model_runner.py", line 1172, in forward
output = self._forward_raw(
^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/model_executor/model_runner.py", line 1201, in _forward_raw
ret = self.forward_extend(
^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/model_executor/model_runner.py", line 1140, in forward_extend
return self.model.forward(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 1475, in forward
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 1390, in forward
hidden_states, residual = layer(
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 1255, in forward
return execute_operations(
^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/operations.py", line 18, in execute_operations
executor.next()
File "/apps/fangyuan/sglang/python/sglang/srt/operations.py", line 53, in next
self._stage_output = op.fn(
^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 413, in op_experts
state.hidden_states_experts_output = self.experts(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/layer.py", line 292, in forward
gateup_output = self.grouped_gemm_runner(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/layer.py", line 114, in forward
c = grouped_gemm_triton(
^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/layers/moe/ep_moe/kernels.py", line 696, in grouped_gemm_triton
grouped_gemm_triton_kernel[grid](
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 330, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 623, in run
kernel = self.compile(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 286, in compile
next_module = compile_ir(module, metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/backends/amd/compiler.py", line 382, in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/backends/amd/compiler.py", line 303, in make_llir
pm.run(mod)
RuntimeError: PassManager::run failed
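For context on the failing conversion: the f8E4M3FN values that grouped_gemm_triton_kernel feeds to tl.dot use the standard OCP FP8 e4m3fn layout (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, NaN only at exponent=15/mantissa=7), which the AMD Triton backend here cannot lower directly to f16. A minimal pure-Python decoder of that layout, purely illustrative and not part of sglang:

```python
def decode_e4m3fn(byte: int) -> float:
    """Decode one e4m3fn (OCP FP8) byte into a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF   # 4 exponent bits, bias 7
    mant = byte & 0x7         # 3 mantissa bits
    if exp == 0xF and mant == 0x7:
        return float("nan")   # the only NaN encoding; e4m3fn has no inf
    if exp == 0:
        return sign * (mant / 8) * 2.0 ** -6        # subnormal
    return sign * (1 + mant / 8) * 2.0 ** (exp - 7)  # normal

# 0x38 -> 1.0, 0x40 -> 2.0, 0x7E -> 448.0 (largest finite value)
```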
- With --dp-size 8 --enable-dp-attention --moe-dense-tp-size 1
python3 -m sglang.launch_server --trust-remote-code --chunked-prefill-size 131072 --disable-cuda-graph --disable-custom-all-reduce --tp-size 8 --model /apps/data/models/DSR1 --ep-size 8 --enable-ep-moe --dp-size 8 --enable-dp-attention --moe-dense-tp-size 1
- Error log
[2025-06-10 14:09:44 DP7 TP7] KV Cache is allocated. #tokens: 1230416, KV size: 80.53 GB
[2025-06-10 14:09:45 DP0 TP0] max_total_num_tokens=1230416, chunked_prefill_size=16384, max_prefill_tokens=16384, max_running_requests=3846, context_len=163840
[2025-06-10 14:09:46] INFO: Started server process [50708]
[2025-06-10 14:09:46] INFO: Waiting for application startup.
[2025-06-10 14:09:46] INFO: Application startup complete.
[2025-06-10 14:09:46] INFO: Uvicorn running on http://127.0.0.1:200 (Press CTRL+C to quit)
[2025-06-10 14:09:47] INFO: 127.0.0.1:50444 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-06-10 14:09:47 DP0 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-10 14:09:47 DP3 TP3] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-10 14:09:47 DP1 TP1] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-10 14:09:47 DP6 TP6] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-10 14:09:47 DP7 TP7] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-10 14:09:47 DP2 TP2] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-10 14:09:47 DP5 TP5] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-10 14:09:47 DP4 TP4] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-10 14:09:49 DP6 TP6] TpModelWorkerClient hit an exception: Traceback (most recent call last):
File "/apps/fangyuan/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 118, in forward_thread_func
self.forward_thread_func_()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 151, in forward_thread_func_
self.worker.forward_batch_generation(
File "/apps/fangyuan/sglang/python/sglang/srt/managers/tp_worker.py", line 202, in forward_batch_generation
logits_output, can_run_cuda_graph = self.model_runner.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/model_executor/model_runner.py", line 1172, in forward
output = self._forward_raw(
^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/model_executor/model_runner.py", line 1207, in _forward_raw
ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/model_executor/model_runner.py", line 1153, in forward_idle
return self.model.forward(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 1475, in forward
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 1390, in forward
hidden_states, residual = layer(
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 1255, in forward
return execute_operations(
File "/apps/fangyuan/sglang/python/sglang/srt/operations.py", line 18, in execute_operations
executor.next()
File "/apps/fangyuan/sglang/python/sglang/srt/operations.py", line 53, in next
self._stage_output = op.fn(
^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/models/deepseek_v2.py", line 1296, in op_comm_prepare_mlp
self.layer_communicator.prepare_mlp(
File "/apps/fangyuan/sglang/python/sglang/srt/layers/communicator.py", line 179, in prepare_mlp
return _communicate_with_all_reduce_and_layer_norm(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/layers/communicator.py", line 389, in _communicate_with_all_reduce_and_layer_norm
_inner(), (hidden_states_output_mode, residual_output_mode)
^^^^^^^^
File "/apps/fangyuan/sglang/python/sglang/srt/layers/communicator.py", line 381, in _inner
raise NotImplementedError(
NotImplementedError: hidden_states_input_mode=<ScatterMode.TP_ATTN_FULL: 2> residual_input_mode=<ScatterMode.SCATTERED: 1> residual_output_mode=<ScatterMode.TP_ATTN_FULL: 2> residual_output_mode=<ScatterMode.TP_ATTN_FULL: 2>
Environment
Python: 3.12.8 (main, Dec 4 2024, 08:54:12) [GCC 11.4.0]
ROCM available: True
GPU 0,1,2,3,4,5,6,7: AMD Radeon Graphics
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.4
ROCM_HOME: /opt/rocm
HIPCC: HIP version: 6.3.42131-fa1d09cbd
ROCM Driver Version: 6.10.5
PyTorch: 2.6.0a0+git8d4926e
sglang: 0.4.6.post5
sgl_kernel: 0.0.5.post3
flashinfer_python: Module Not Found
triton: 3.2.0+gitd8057ea4
transformers: 4.51.1
torchao: 0.9.0
numpy: 1.26.4
aiohttp: 3.11.11
fastapi: 0.115.6
hf_transfer: 0.1.9
huggingface_hub: 0.32.4
interegular: 0.3.3
modelscope: 1.22.3
orjson: 3.10.15
outlines: 0.1.11
packaging: 24.2
psutil: 6.1.1
pydantic: 2.10.5
python-multipart: 0.0.20
pyzmq: 26.2.0
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: 0.6.7.dev2+g113274a0.rocm630
xgrammar: 0.1.19
openai: 1.61.1
tiktoken: 0.7.0
anthropic: 0.45.2
litellm: 1.60.6
decord: 0.6.0
AMD Topology:
============================ ROCm System Management Interface ============================
=============================== Link Type between two GPUs ===============================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 XGMI XGMI XGMI XGMI XGMI XGMI XGMI
GPU1 XGMI 0 XGMI XGMI XGMI XGMI XGMI XGMI
GPU2 XGMI XGMI 0 XGMI XGMI XGMI XGMI XGMI
GPU3 XGMI XGMI XGMI 0 XGMI XGMI XGMI XGMI
GPU4 XGMI XGMI XGMI XGMI 0 XGMI XGMI XGMI
GPU5 XGMI XGMI XGMI XGMI XGMI 0 XGMI XGMI
GPU6 XGMI XGMI XGMI XGMI XGMI XGMI 0 XGMI
GPU7 XGMI XGMI XGMI XGMI XGMI XGMI XGMI 0
================================== End of ROCm SMI Log ===================================
ulimit soft: 1048576