
HaiShaw
Collaborator

@HaiShaw HaiShaw commented Dec 30, 2024

Motivation

Together with core changes from: #2637

# python3 -m sglang.launch_server --model /data2/DeepSeek-V3/ --tp 8 --trust-remote-code

# python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [02:47<00:00,  7.87it/s]
Accuracy: 0.920
Invalid: 0.000
Latency: 173.540 s
Output throughput: 726.016 token/s

Modifications

Several places in deepseek_v2.py need extra handling: the FP8 weights are normalized to e4m3fnuz, the scale factors are adjusted to match, and the final bmm results are rescaled.
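Some background on the e4m3fnuz normalization: MI300X natively uses the FP8 E4M3FNUZ format. It has the same bit layout as E4M3FN (1 sign, 4 exponent, 3 mantissa bits), but the exponent bias is 8 instead of 7, it has no negative zero, and 0x80 is its only NaN. The same byte therefore decodes to half the value under FNUZ, and the E4M3FN byte for -0.0 becomes NaN. The pure-Python sketch below works on raw byte values. It shows one common normalization: map 0x80 to 0x00 and double the scale so that the dequantized values are unchanged. The helper names are hypothetical and this is not the code from this PR. It also assumes the weights contain no E4M3FN NaNs (0x7F/0xFF).

```python
import math


def decode_e4m3fn(b: int) -> float:
    """Decode one byte as FP8 E4M3FN (bias 7, NaN at 0x7F/0xFF, has -0.0)."""
    s, e, m = b >> 7, (b >> 3) & 0xF, b & 0x7
    if e == 0xF and m == 0x7:
        return math.nan
    # Normal numbers use an implicit leading 1; e == 0 means subnormal.
    mag = 2.0 ** (e - 7) * (1 + m / 8) if e else 2.0 ** -6 * (m / 8)
    return -mag if s else mag


def decode_e4m3fnuz(b: int) -> float:
    """Decode one byte as FP8 E4M3FNUZ (bias 8, single NaN at 0x80, no -0.0)."""
    if b == 0x80:
        return math.nan
    s, e, m = b >> 7, (b >> 3) & 0xF, b & 0x7
    mag = 2.0 ** (e - 8) * (1 + m / 8) if e else 2.0 ** -7 * (m / 8)
    return -mag if s else mag


def normalize_e4m3fn_to_e4m3fnuz(weight_bits: list[int], scale: float) -> tuple[list[int], float]:
    """Hypothetical helper: reinterpret E4M3FN weight bytes as E4M3FNUZ."""
    # 0x80 is -0.0 in E4M3FN but NaN in E4M3FNUZ, so map it to +0.0.
    new_bits = [0x00 if b == 0x80 else b for b in weight_bits]
    # Under FNUZ the same bits decode to half the value (bias 8 vs 7),
    # so doubling the per-tensor scale keeps weight * scale unchanged.
    return new_bits, scale * 2.0
```

On real tensors the idea is the same: view the FP8 weight as int8, zero the -128 entries, reinterpret the result as `torch.float8_e4m3fnuz`, and double the weight scale. The two formats also differ in range (max 448 for E4M3FN vs 240 for E4M3FNUZ), which is one reason the scaling logic needs to change.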

MI300X baseline performance (nothing tuned yet):

# python -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 1024 --model /data2/DeepSeek-V3/ --tp 8 --trust-remote-code

Benchmark ...
Prefill. latency: 2.46543 s, throughput:  13291.01 token/s
Decode.  latency: 0.09925 s, throughput:    322.41 token/s
Decode.  latency: 0.10268 s, throughput:    311.66 token/s
Decode.  latency: 0.10196 s, throughput:    313.85 token/s
Decode.  latency: 0.10094 s, throughput:    317.03 token/s
Decode.  latency: 0.10277 s, throughput:    311.37 token/s
Decode.  median latency: 0.10719 s, median throughput:    298.54 token/s
Total. latency: 112.018 s, throughput:    585.05 token/s
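You can cross-check the reported throughputs against the latencies. The sketch below assumes that bench_one_batch divides the tokens processed by the wall-clock latency. The numbers are consistent with this, but it has not been confirmed from the sglang source. Under that assumption, prefill processes batch_size × input_len tokens, each decode step produces batch_size tokens, and the total covers batch_size × (input_len + output_len) tokens. Last-digit differences come from the latencies being printed rounded.

```python
# Values copied from the bench_one_batch output above.
batch_size, input_len, output_len = 32, 1024, 1024
prefill_latency = 2.46543
decode_latencies = [0.09925, 0.10268, 0.10196, 0.10094, 0.10277]  # printed steps only
median_decode_latency = 0.10719
total_latency = 112.018


def throughput(tokens: int, latency_s: float) -> float:
    """Tokens per second, assuming throughput = tokens processed / latency."""
    return tokens / latency_s


prefill_tps = throughput(batch_size * input_len, prefill_latency)
decode_tps = [throughput(batch_size, lat) for lat in decode_latencies]
median_decode_tps = throughput(batch_size, median_decode_latency)
total_tps = throughput(batch_size * (input_len + output_len), total_latency)

print(f"prefill: {prefill_tps:.2f} token/s")
print("decode:", ", ".join(f"{t:.2f}" for t in decode_tps), "token/s")
print(f"median decode: {median_decode_tps:.2f} token/s")
print(f"total: {total_tps:.2f} token/s")
```

The median decode latency is higher than each of the five printed steps, which suggests it is computed over all decode steps and not just the ones shown.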

Checklist

  • [+] Format your code according to the Contributor Guide.
  • [+] Add unit tests as outlined in the Contributor Guide.
  • [+] Update documentation as needed, including docstrings or example tutorials.

@zhyncs
Member

zhyncs commented Dec 30, 2024

Do we still need #2601?

@HaiShaw: Yes. We will have one final PR to glue everything together (wait and see), and there is still an imaging/packing one left, so keep #2601.

@HaiShaw HaiShaw enabled auto-merge (squash) December 30, 2024 11:02
@HaiShaw HaiShaw disabled auto-merge December 30, 2024 11:02
@zhyncs
Member

zhyncs commented Dec 30, 2024

hold on

@zhyncs
Member

zhyncs commented Dec 30, 2024

[2024-12-30 11:23:43 TP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/managers/scheduler.py", line 1573, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/managers/scheduler.py", line 194, in __init__
    self.tp_worker = TpWorkerClass(
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in __init__
    self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/managers/tp_worker.py", line 63, in __init__
    self.model_runner = ModelRunner(
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/model_executor/model_runner.py", line 184, in __init__
    self.init_cuda_graphs()
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/model_executor/model_runner.py", line 643, in init_cuda_graphs
    self.cuda_graph_runner = CudaGraphRunner(self)
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 209, in __init__
    self.capture()
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 275, in capture
    ) = self.capture_one_batch_size(bs, forward)
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 339, in capture_one_batch_size
    run_once()
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 332, in run_once
    logits_output = forward(input_ids, forward_batch.positions, forward_batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/models/deepseek_v2.py", line 859, in forward
    hidden_states = self.model(input_ids, positions, forward_batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/models/deepseek_v2.py", line 820, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/models/deepseek_v2.py", line 758, in forward
    hidden_states = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/models/deepseek_v2.py", line 515, in forward
    return self.forward_absorb(positions, hidden_states, forward_batch)
  File "/actions-runner/_work/sglang/sglang/python/sglang/srt/models/deepseek_v2.py", line 590, in forward_absorb
    q_nope_out = bmm_fp8(
  File "/usr/local/lib/python3.10/dist-packages/flashinfer/gemm.py", line 266, in bmm_fp8
    _kernels.bmm_fp8(A, B, out, A_scale, B_scale)
TypeError: bmm_fp8(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: torch.Tensor, arg4: torch.Tensor) -> None

@zhyncs zhyncs merged commit c5210df into sgl-project:main Dec 30, 2024
11 of 15 checks passed
XiaotongJiang pushed a commit to XiaotongJiang/sglang that referenced this pull request Jan 3, 2025
timethink pushed a commit to timethink/sglang that referenced this pull request Mar 9, 2025