Conversation

@ispobock (Collaborator) commented May 26, 2025

Motivation

This PR adds CUDA graph support for the EAGLE draft extend phase. The FA3 backend is supported; other backends will be supported in follow-up PRs.

Co-authored-by: @kssteven418

Benchmark

DSV3

11% per_user_throughput improvement for bs=1 and 5% for bs=32.

ref: #6606 (comment) and #6606 (comment)
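For reference, the per_user_throughput column in the tables below is simply output_throughput divided by max_concurrency; a quick sanity check in Python against the table values:

```python
def per_user_throughput(output_throughput: float, max_concurrency: float) -> float:
    # Aggregate output token rate split evenly across concurrent users.
    return output_throughput / max_concurrency

# bs=4 row of the main-branch DSV3 table below: 376.367 tok/s at concurrency 4.
print(round(per_user_throughput(376.367, 4), 3))  # → 94.092
```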

python3 -m sglang.launch_server --model /dev/shm/DeepSeek-V3-0324 --tp 8 --trust-remote-code --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --speculative-algorithm EAGLE

curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 5 --max-concurrency 1 --output-file dsv3.jsonl
curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 20 --max-concurrency 4 --output-file dsv3.jsonl
curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 80 --max-concurrency 16 --output-file dsv3.jsonl
curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 160 --max-concurrency 32 --output-file dsv3.jsonl

main:

+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|    |   max_concurrency |   input_throughput |   output_throughput |   mean_ttft_ms |   median_ttft_ms |   p99_ttft_ms |   mean_tpot_ms |   median_tpot_ms |   p99_tpot_ms |   per_user_throughput |
+====+===================+====================+=====================+================+==================+===============+================+==================+===============+=======================+
|  0 |             1.000 |            123.947 |             123.947 |        138.884 |          139.747 |       147.518 |          7.936 |            7.755 |         9.210 |               123.947 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  1 |             4.000 |            376.367 |             376.367 |        330.610 |          163.647 |      1278.265 |          9.913 |            9.713 |        12.412 |                94.092 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  2 |            16.000 |            982.446 |             982.446 |        504.908 |          179.340 |      1873.442 |         14.922 |           14.710 |        18.929 |                61.403 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  3 |            32.000 |           1660.563 |            1660.563 |        359.882 |          183.473 |      1569.568 |         18.120 |           18.424 |        23.038 |                51.893 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+

This PR:

+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|    |   max_concurrency |   input_throughput |   output_throughput |   mean_ttft_ms |   median_ttft_ms |   p99_ttft_ms |   mean_tpot_ms |   median_tpot_ms |   p99_tpot_ms |   per_user_throughput |
+====+===================+====================+=====================+================+==================+===============+================+==================+===============+=======================+
|  0 |             1.000 |            139.708 |             139.708 |        154.290 |          149.956 |       176.988 |          7.009 |            6.866 |         8.121 |               139.708 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  1 |             4.000 |            413.854 |             413.854 |        346.852 |          163.505 |      1390.936 |          9.030 |            8.802 |        11.247 |               103.464 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  2 |            16.000 |           1038.688 |            1038.688 |        524.877 |          174.334 |      1913.235 |         13.883 |           13.858 |        17.762 |                64.918 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  3 |            32.000 |           1740.612 |            1740.612 |        357.374 |          178.539 |      1563.505 |         17.306 |           17.497 |        22.053 |                54.394 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+

Llama-3-8B

python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 2 --speculative-eagle-topk 1 --speculative-num-draft-tokens 3 --trust-remote-code --dtype float16 --attention-backend fa3

main:

+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|    |   max_concurrency |   input_throughput |   output_throughput |   mean_ttft_ms |   median_ttft_ms |   p99_ttft_ms |   mean_tpot_ms |   median_tpot_ms |   p99_tpot_ms |   per_user_throughput |
+====+===================+====================+=====================+================+==================+===============+================+==================+===============+=======================+
|  0 |             1.000 |            210.866 |             210.866 |         41.137 |           35.311 |        61.458 |          4.705 |            4.516 |         5.600 |               210.866 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  1 |             4.000 |            773.532 |             773.532 |         55.434 |           42.459 |       133.661 |          4.859 |            4.857 |         6.040 |               193.383 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  2 |            16.000 |           2371.987 |            2371.987 |        100.885 |           44.230 |       418.474 |          5.754 |            5.661 |         7.655 |               148.249 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  3 |            32.000 |           4313.602 |            4313.602 |        168.698 |           44.833 |       953.043 |          6.773 |            6.650 |         9.297 |               134.800 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+

This PR:

+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|    |   max_concurrency |   input_throughput |   output_throughput |   mean_ttft_ms |   median_ttft_ms |   p99_ttft_ms |   mean_tpot_ms |   median_tpot_ms |   p99_tpot_ms |   per_user_throughput |
+====+===================+====================+=====================+================+==================+===============+================+==================+===============+=======================+
|  0 |             1.000 |            221.192 |             221.192 |         44.887 |           40.369 |        75.675 |          4.479 |            4.292 |         5.315 |               221.192 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  1 |             4.000 |            805.923 |             805.923 |         80.402 |           42.115 |       301.355 |          4.640 |            4.716 |         5.763 |               201.481 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  2 |            16.000 |           2518.878 |            2518.878 |        103.171 |           42.926 |       430.251 |          5.417 |            5.338 |         7.140 |               157.430 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  3 |            32.000 |           4473.389 |            4473.389 |        165.690 |           43.676 |       945.616 |          6.514 |            6.396 |         8.787 |               139.793 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+

Profile

Llama-3-8B

main: (profile screenshot: no-cg)

This PR: (profile screenshot: cg)

DSV3

main: (profile screenshot: dsv3-no-cg)

This PR: (profile screenshot: dsv3-cg)

ispobock and others added 2 commits May 25, 2025 17:46
Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
@ispobock ispobock requested a review from zhaochenyang20 as a code owner May 26, 2025 03:46
@zhyncs zhyncs self-assigned this May 26, 2025
@zhyncs (Member) commented May 26, 2025

+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|    |   max_concurrency |   input_throughput |   output_throughput |   mean_ttft_ms |   median_ttft_ms |   p99_ttft_ms |   mean_tpot_ms |   median_tpot_ms |   p99_tpot_ms |   per_user_throughput |
+====+===================+====================+=====================+================+==================+===============+================+==================+===============+=======================+
|  0 |             1.000 |            123.947 |             123.947 |        138.884 |          139.747 |       147.518 |          7.936 |            7.755 |         9.210 |               123.947 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  1 |             4.000 |            376.367 |             376.367 |        330.610 |          163.647 |      1278.265 |          9.913 |            9.713 |        12.412 |                94.092 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  2 |            16.000 |            982.446 |             982.446 |        504.908 |          179.340 |      1873.442 |         14.922 |           14.710 |        18.929 |                61.403 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  3 |            32.000 |           1660.563 |            1660.563 |        359.882 |          183.473 |      1569.568 |         18.120 |           18.424 |        23.038 |                51.893 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  4 |             1.000 |            137.814 |             137.814 |        148.209 |          150.231 |       162.498 |          7.113 |            7.011 |         8.425 |               137.814 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  5 |             4.000 |            401.329 |             401.329 |        383.695 |          166.587 |      1617.817 |          9.327 |            9.155 |        11.889 |               100.332 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  6 |            16.000 |            950.423 |             950.423 |        665.985 |          173.532 |      2647.285 |         15.058 |           14.964 |        19.418 |                59.401 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  7 |            32.000 |           1592.793 |            1592.793 |        424.593 |          180.001 |      1891.864 |         18.747 |           18.875 |        25.483 |                49.775 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
python3 -m sglang.launch_server --model /dev/shm/DeepSeek-V3-0324 --tp 8 --trust-remote-code --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --speculative-algorithm EAGLE
curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 5 --max-concurrency 1 --output-file dsv3.jsonl
curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 20 --max-concurrency 4 --output-file dsv3.jsonl
curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 80 --max-concurrency 16 --output-file dsv3.jsonl
curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 160 --max-concurrency 32 --output-file dsv3.jsonl

From these results, performance is better than main at batch sizes 1 and 4, but worse at 16 and 32.

from sglang.srt.speculative.eagle_worker import EAGLEWorker


class EAGLEDraftExtendCudaGraphRunner:

BTW, when enabling the extended CUDA graph, we should also adjust mem_fraction_static to avoid running out of memory, e.g. for the 5/4/8 (num-steps / topk / num-draft-tokens) configuration.
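A hypothetical launch invocation for that case; the --mem-fraction-static value here is illustrative only and must be tuned for the actual GPU memory budget:

```shell
# Lower the static memory fraction so the extra CUDA-graph capture memory
# for the 5/4/8 speculative config fits (0.8 is a placeholder value).
python3 -m sglang.launch_server --model /dev/shm/DeepSeek-V3-0324 --tp 8 \
  --trust-remote-code --speculative-algorithm EAGLE \
  --speculative-num-steps 5 --speculative-eagle-topk 4 \
  --speculative-num-draft-tokens 8 \
  --mem-fraction-static 0.8
```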


python3 -m sglang.launch_server --model /dev/shm/DeepSeek-V3-0324 --tp 8 --trust-remote-code --speculative-num-steps 5 --speculative-eagle-topk 4 --speculative-num-draft-tokens 8 --speculative-algorithm EAGLE

lm_eval --model local-chat-completions --model_args model=/dev/shm/DeepSeek-V3-0324,base_url=http://127.0.0.1:30000/v1/chat/completions,num_concurrent=128,timeout=999999,max_gen_toks=2048 --tasks gsm8k --batch_size 128 --apply_chat_template --num_fewshot 8

This will OOM.

@ispobock (Collaborator, Author) commented May 27, 2025

Performance improved after 9055a49.

+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|    |   max_concurrency |   input_throughput |   output_throughput |   mean_ttft_ms |   median_ttft_ms |   p99_ttft_ms |   mean_tpot_ms |   median_tpot_ms |   p99_tpot_ms |   per_user_throughput |
+====+===================+====================+=====================+================+==================+===============+================+==================+===============+=======================+
|  0 |            16.000 |            996.351 |             996.351 |        355.755 |          173.577 |      1274.461 |         14.697 |           14.686 |        18.097 |                62.272 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  1 |            32.000 |           1617.969 |            1617.969 |        398.944 |          177.661 |      1734.676 |         18.596 |           18.751 |        25.125 |                50.562 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+

@ispobock (Collaborator, Author)

There is an accept-rate issue at large batch sizes. It should be fixed before merging.
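For context, the acceptance rate here is roughly the fraction of proposed draft tokens that the target model accepts per verify step. A minimal sketch of the metric, using hypothetical per-step counts (this is not SGLang's actual API):

```python
def accept_rate(accepted_per_step, proposed_per_step):
    """Fraction of proposed draft tokens that the target model accepted."""
    return sum(accepted_per_step) / sum(proposed_per_step)

# With num-draft-tokens=4, each verify step proposes 3 draft tokens beyond the
# root; suppose three consecutive steps accepted 3, 1, and 2 of them.
print(accept_rate([3, 1, 2], [3, 3, 3]))  # ≈ 0.667
```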

@zhyncs (Member) commented May 27, 2025

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces CUDA graph support for EAGLE draft extend, which is a significant step towards improving performance for speculative decoding. The changes look promising, with performance gains demonstrated in the PR description.

I've identified a couple of areas for improvement, primarily related to maintainability and a potential correctness issue in attention backend initialization. Overall, the new EAGLEDraftExtendCudaGraphRunner class seems well-structured for its purpose.

Summary of Findings

  • Magic Numbers for Memory Reservation: In python/sglang/srt/server_args.py, the memory reservation values (e.g., 1024 * 18, 1024 * 20) are magic numbers. Defining them as named constants would improve code readability and maintainability.
  • Attention Backend Initialization: In python/sglang/srt/speculative/eagle_worker.py (lines 696-699), when CUDA graph is not used for draft extend, init_forward_metadata appears to be called on the main attention backend of the draft model runner, while the subsequent forward pass uses self.draft_extend_attn_backend (via forward_batch.attn_backend). This could lead to inconsistencies. init_forward_metadata should likely be called on self.draft_extend_attn_backend.

Merge Readiness

The pull request introduces valuable performance enhancements with CUDA graph support for EAGLE draft extend. The core logic for the new CUDA graph runner and the integration into the EAGLE worker seem mostly correct.

However, there is one high-severity issue regarding attention backend initialization that should be addressed before merging to ensure correctness when CUDA graph is not used. Additionally, there's a medium-severity suggestion for improving maintainability by refactoring magic numbers related to memory reservation.

I am unable to approve this pull request. Please have another reviewer approve this code after addressing the identified issues.

Comment on lines +696 to +699
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
logits_output = self.draft_model_runner.model.forward(
    forward_batch.input_ids, forward_batch.positions, forward_batch
)
Severity: high

It seems init_forward_metadata is being called on self.draft_model_runner.attn_backend (the main attention backend of the draft model runner), but the forward pass (lines 697-699) will use forward_batch.attn_backend.

forward_batch.attn_backend is set to self.draft_extend_attn_backend (via batch.get_model_worker_batch() which sets attn_backend appropriately for DRAFT_EXTEND mode).

Shouldn't init_forward_metadata be called on self.draft_extend_attn_backend to ensure consistency with the backend used in the subsequent model.forward call when CUDA graph is not used?

Suggested change

-self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
+self.draft_extend_attn_backend.init_forward_metadata(forward_batch)
 logits_output = self.draft_model_runner.model.forward(
     forward_batch.input_ids, forward_batch.positions, forward_batch
 )

Comment on lines +266 to +269
reserve_mem = 1024 * 18
# need reserve more memory for spec cuda graph
if self.speculative_algorithm is not None:
    reserve_mem = 1024 * 20
Severity: medium

The values 1024 * 18 and 1024 * 20 for reserve_mem are magic numbers. Could these be defined as named constants at the module level or within the class for better readability and maintainability? For example:

_DEFAULT_RESERVE_MEM_MB = 18 * 1024
_SPEC_GRAPH_ADDITIONAL_RESERVE_MEM_MB = 2 * 1024 # Additional for speculative graph

# ... later in __post_init__
reserve_mem = _DEFAULT_RESERVE_MEM_MB
if self.speculative_algorithm is not None:
    reserve_mem += _SPEC_GRAPH_ADDITIONAL_RESERVE_MEM_MB

Or, keeping your current structure, just name the literals:

_BASE_RESERVE_MEM_FOR_CUDA_GRAPH_MB = 18 * 1024
_SPEC_RESERVE_MEM_FOR_CUDA_GRAPH_MB = 20 * 1024

reserve_mem = _BASE_RESERVE_MEM_FOR_CUDA_GRAPH_MB
if self.speculative_algorithm is not None:
    reserve_mem = _SPEC_RESERVE_MEM_FOR_CUDA_GRAPH_MB

@ispobock (Collaborator, Author) commented May 27, 2025

The accept rate issue should be fixed in 6707ea9.
Current benchmark results (11% per_user_throughput improvement for bs=1 and 5% for bs=32):

+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|    |   max_concurrency |   input_throughput |   output_throughput |   mean_ttft_ms |   median_ttft_ms |   p99_ttft_ms |   mean_tpot_ms |   median_tpot_ms |   p99_tpot_ms |   per_user_throughput |
+====+===================+====================+=====================+================+==================+===============+================+==================+===============+=======================+
|  0 |             1.000 |            139.708 |             139.708 |        154.290 |          149.956 |       176.988 |          7.009 |            6.866 |         8.121 |               139.708 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  1 |             4.000 |            413.854 |             413.854 |        346.852 |          163.505 |      1390.936 |          9.030 |            8.802 |        11.247 |               103.464 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  2 |            16.000 |           1038.688 |            1038.688 |        524.877 |          174.334 |      1913.235 |         13.883 |           13.858 |        17.762 |                64.918 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  3 |            32.000 |           1740.612 |            1740.612 |        357.374 |          178.539 |      1563.505 |         17.306 |           17.497 |        22.053 |                54.394 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+

@zhyncs zhyncs merged commit 6319502 into main May 27, 2025
1 check failed
@zhyncs zhyncs deleted the draft-extend-cg branch May 27, 2025 09:35
@Z-NAVY commented May 29, 2025

The three parameters speculative-num-steps, speculative-eagle-topk, and speculative-num-draft-tokens have a significant impact on performance. Is trial-and-error tuning the only practical way to select them?
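In practice, a small grid search over the three knobs driven by a benchmark run is one common approach. A sketch under stated assumptions: the `measure` callable (standing in for an actual bench_serving run) and the validity filter are hypothetical, not SGLang logic:

```python
import itertools

# Candidate values for the three speculative-decoding knobs.
STEPS = [2, 3, 5]
TOPK = [1, 2, 4]
DRAFT_TOKENS = [3, 4, 8]

def best_config(measure):
    """Pick the (steps, topk, draft_tokens) triple with the highest measured
    throughput. `measure` maps a config to tokens/s (e.g. one benchmark run)."""
    grid = [
        (s, k, d)
        for s, k, d in itertools.product(STEPS, TOPK, DRAFT_TOKENS)
        # rough validity filter: the root token plus at least one token per step
        if d >= s + 1
    ]
    return max(grid, key=measure)
```

Each `measure` call is expensive (server restart plus a benchmark), so coarse grids or successive halving are usually preferable to exhaustive sweeps.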

Layssy pushed a commit to Layssy/sglang-iaas that referenced this pull request Jun 9, 2025
Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
xwu-intel pushed a commit to xwu-intel/sglang that referenced this pull request Jun 17, 2025
Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>