Closed
Checklist
- 1. I have searched related issues but cannot get the expected help.
- 2. The bug has not been fixed in the latest version.
- 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
- 5. Please use English, otherwise it will be closed.
Describe the bug
I am using the latest main (d1112d8548eb13c842900b3a8d622345f9737759) to start the DeepSeek V2.5 bf16 model. When the --enable-torch-compile parameter is used, an error is reported.
Reproduction
python3 -m sglang.launch_server --model /path/to/DeepSeek-V2.5-1210 --trust-remote --dtype bfloat16 --host 0.0.0.0 --port 30000 --tp 8 --enable-torch-compile --torch-compile-max-bs 4
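For scripted comparisons, the reproduction command can be assembled as an argument list. The sketch below is a hypothetical helper: `build_launch_cmd` is not part of sglang. It rebuilds the command above and adds a switch that drops `--enable-torch-compile`/`--torch-compile-max-bs`, so the same launch can be compared with and without torch compile. It only builds the list and does not start the server.

```python
import shlex


def build_launch_cmd(
    model_path: str = "/path/to/DeepSeek-V2.5-1210",
    tp: int = 8,
    enable_torch_compile: bool = True,
    torch_compile_max_bs: int = 4,
) -> list[str]:
    """Hypothetical helper: rebuild the sglang.launch_server argv from the report.

    It only assembles the argument list; it does not launch anything.
    """
    cmd = [
        "python3", "-m", "sglang.launch_server",
        "--model", model_path,
        "--trust-remote",  # kept exactly as typed in the report
        "--dtype", "bfloat16",
        "--host", "0.0.0.0",
        "--port", "30000",
        "--tp", str(tp),
    ]
    if enable_torch_compile:
        # The error is only reported when --enable-torch-compile is set.
        cmd += ["--enable-torch-compile", "--torch-compile-max-bs", str(torch_compile_max_bs)]
    return cmd


# Failing configuration from the report, then the same launch without torch compile.
print(shlex.join(build_launch_cmd()))
print(shlex.join(build_launch_cmd(enable_torch_compile=False)))
```

Passing the list to `subprocess.run` would start the server. The second variant corresponds to solution 3 ("disable torch compile by not using --enable-torch-compile") in the error output.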
The exception information is as follows:
[2025-03-17 07:53:55 TP7] Registering 2420 cuda graph addresses
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] Error while creating guard:
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] Name: "G['__import_sglang_dot_srt_dot_layers_dot_moe_dot_topk'].grouped_topk.__defaults__[2]"
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] Source: global
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] Create Function: CONSTANT_MATCH
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] Guard Types: None
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] Code List: None
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] Object Weakref: None
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] Guarded Class Weakref: None
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] Traceback (most recent call last):
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 281, in create
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] return self.create_fn(builder, self)
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1576, in CONSTANT_MATCH
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] val = self.get(guard.name)
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1148, in get
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] return eval(name, self.scope, CLOSURE_VARS)
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] File "<string>", line 1, in <module>
[rank4]:E0317 07:53:56.045000 2265 torch/_guards.py:283] [14/1] TypeError: 'NoneType' object is not subscriptable
[2025-03-17 07:53:56 TP4] Registering 2420 cuda graph addresses
[2025-03-17 07:53:56 TP7] Scheduler hit an exception: Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 249, in __init__
self.capture()
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 333, in capture
) = self.capture_one_batch_size(bs, forward)
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 425, in capture_one_batch_size
run_once()
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 418, in run_once
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 40, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1080, in forward
hidden_states = self.model(input_ids, positions, forward_batch)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1042, in forward
hidden_states, residual = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 958, in forward
hidden_states = self.self_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
return self._torchdynamo_orig_callable(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
result = self._inner_convert(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
return _compile(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 952, in _compile
raise InternalTorchDynamoError(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 796, in _compile_inner
check_fn = CheckFunctionManager(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2261, in __init__
guard.create(builder)
File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 281, in create
return self.create_fn(builder, self)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1576, in CONSTANT_MATCH
val = self.get(guard.name)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1148, in get
return eval(name, self.scope, CLOSURE_VARS)
File "<string>", line 1, in <module>
torch._dynamo.exc.InternalTorchDynamoError: TypeError: 'NoneType' object is not subscriptable
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1807, in run_scheduler_process
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 226, in __init__
self.tp_worker = TpWorkerClass(
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in __init__
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 74, in __init__
self.model_runner = ModelRunner(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 167, in __init__
self.initialize(min_per_gpu_memory)
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 208, in initialize
self.init_cuda_graphs()
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 889, in init_cuda_graphs
self.cuda_graph_runner = CudaGraphRunner(self)
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 251, in __init__
raise Exception(
Exception: Capture cuda graph failed: TypeError: 'NoneType' object is not subscriptable
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Possible solutions:
1. disable cuda graph by --disable-cuda-graph
2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)
3. disable torch compile by not using --enable-torch-compile
4. set --cuda-graph-max-bs to a smaller value (e.g., 32)
Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose
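The guard error and the scheduler traceback both end in the same place. Dynamo evaluates the guard source `G['__import_sglang_dot_srt_dot_layers_dot_moe_dot_topk'].grouped_topk.__defaults__[2]` with `eval(name, self.scope, CLOSURE_VARS)`, and indexing fails because `grouped_topk.__defaults__` apparently evaluated to `None`.

The standalone sketch below uses plain Python, with no torch or sglang. It shows one way that exact message arises: a function defined without default arguments has `__defaults__` set to `None`, so `__defaults__[2]` raises `TypeError`. The names `grouped_topk_stub`, `evaluate_guard`, and the `topk_module` key are hypothetical. The sketch does not explain why sglang's `grouped_topk` has no defaults at guard-creation time.

```python
import types


def grouped_topk_stub(hidden_states, gating_output, topk):
    """Hypothetical stand-in with no default arguments (not sglang's grouped_topk)."""
    return hidden_states


def evaluate_guard(name: str, scope: dict):
    """Evaluate a guard source string with eval(), as in the traceback,
    returning the value or the TypeError text."""
    try:
        return eval(name, scope)
    except TypeError as exc:
        return f"TypeError: {exc}"


# A function defined without default arguments has __defaults__ set to None.
print(grouped_topk_stub.__defaults__)  # None

# Module-like object standing in for G['__import_sglang_dot_srt_dot_layers_dot_moe_dot_topk'].
fake_topk_module = types.SimpleNamespace(grouped_topk=grouped_topk_stub)
scope = {"G": {"topk_module": fake_topk_module}}

print(evaluate_guard("G['topk_module'].grouped_topk.__defaults__[2]", scope))
# TypeError: 'NoneType' object is not subscriptable
```

For contrast, a function with at least three defaults returns its third default from the same guard expression instead of raising.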
Environment
Python: 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA A800-SXM4-80GB
GPU 0,1,2,3,4,5,6,7 Compute Capability: 8.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 550.54.14
PyTorch: 2.5.1+cu124
sgl_kernel: 0.0.5.post2
flashinfer: 0.2.3+cu124torch2.5
triton: 3.1.0
transformers: 4.48.3
torchao: 0.9.0
numpy: 1.26.4
aiohttp: 3.11.13
fastapi: 0.115.11
hf_transfer: 0.1.9
huggingface_hub: 0.29.3
interegular: 0.3.3
modelscope: 1.23.2
orjson: 3.10.15
packaging: 24.2
psutil: 7.0.0
pydantic: 2.10.6
multipart: 0.0.20
zmq: 26.3.0
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: 0.7.2
openai: 1.66.3
tiktoken: 0.9.0
anthropic: 0.49.0
decord: 0.6.0
NVIDIA Topology:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 NIC4 NIC5 NIC6 NIC7 NIC8 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV8 NV8 NV8 NV8 NV8 NV8 NV8 PXB PXB SYS SYS SYS SYS SYS SYS SYS 0-31,64-95 0 N/A
GPU1 NV8 X NV8 NV8 NV8 NV8 NV8 NV8 PXB PXB SYS SYS SYS SYS SYS SYS SYS 0-31,64-95 0 N/A
GPU2 NV8 NV8 X NV8 NV8 NV8 NV8 NV8 SYS SYS PXB PXB SYS SYS SYS SYS SYS 0-31,64-95 0 N/A
GPU3 NV8 NV8 NV8 X NV8 NV8 NV8 NV8 SYS SYS PXB PXB SYS SYS SYS SYS SYS 0-31,64-95 0 N/A
GPU4 NV8 NV8 NV8 NV8 X NV8 NV8 NV8 SYS SYS SYS SYS PXB PXB SYS SYS SYS 32-63,96-127 1 N/A
GPU5 NV8 NV8 NV8 NV8 NV8 X NV8 NV8 SYS SYS SYS SYS PXB PXB SYS SYS SYS 32-63,96-127 1 N/A
GPU6 NV8 NV8 NV8 NV8 NV8 NV8 X NV8 SYS SYS SYS SYS SYS SYS PXB PXB SYS 32-63,96-127 1 N/A
GPU7 NV8 NV8 NV8 NV8 NV8 NV8 NV8 X SYS SYS SYS SYS SYS SYS PXB PXB SYS 32-63,96-127 1 N/A
NIC0 PXB PXB SYS SYS SYS SYS SYS SYS X PIX SYS SYS SYS SYS SYS SYS SYS
NIC1 PXB PXB SYS SYS SYS SYS SYS SYS PIX X SYS SYS SYS SYS SYS SYS SYS
NIC2 SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS X PIX SYS SYS SYS SYS SYS
NIC3 SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS PIX X SYS SYS SYS SYS SYS
NIC4 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS X PIX SYS SYS SYS
NIC5 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS PIX X SYS SYS SYS
NIC6 SYS SYS SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS X PIX SYS
NIC7 SYS SYS SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS PIX X SYS
NIC8 SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_2
NIC1: mlx5_3
NIC2: mlx5_4
NIC3: mlx5_5
NIC4: mlx5_6
NIC5: mlx5_7
NIC6: mlx5_8
NIC7: mlx5_9
NIC8: mlx5_bond_0
ulimit soft: 1048576