
Conversation

@yukavio (Contributor) commented Jun 9, 2025

Motivation

Optimize the performance of speculative decoding in high-throughput scenarios.

Modifications

Integrate the target verify stage and the draft stage into a single CUDA graph to reduce the overhead of speculative decoding.
This PR is a collaboration with @josephydu.
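
For readers unfamiliar with the fused-capture idea, here is a minimal, hypothetical PyTorch sketch (not the PR's actual NaiveEAGLECudaGraphRunner code; target_model, draft_model, and the shapes are placeholders): the target "verify" forward and the draft "extend" forward are recorded into one CUDA graph, so each decode step costs a single graph replay instead of two separate launches.

```python
import torch

# Placeholder models standing in for the target and draft forward passes.
target_model = torch.nn.Linear(4096, 4096).cuda()
draft_model = torch.nn.Linear(4096, 4096).cuda()

# Static input buffer reused across decode steps (required for CUDA graph replay).
static_input = torch.zeros(8, 4096, device="cuda")

# Warm up on a side stream before capture, as recommended for CUDA graphs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    verify_out = target_model(static_input)
    draft_out = draft_model(verify_out)
torch.cuda.current_stream().wait_stream(s)

# Capture both stages in one graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    verify_out = target_model(static_input)
    draft_out = draft_model(verify_out)

# Per decode step: refill the static buffer, then a single replay covers verify + draft.
static_input.copy_(torch.randn(8, 4096, device="cuda"))
graph.replay()
```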

Checklist

Performance:

target_model: meta-llama/Llama-2-70b-chat-hf
draft_model: lmsys/sglang-EAGLE-llama2-chat-70B
H20 / TP2 with 500 requests, input/output/ratio 1024/1024/0.7

Without speculative decoding:
throughput: 0.62 req/s

EAGLE:
1. --speculative-num-steps 2 --speculative-eagle-topk 4 --speculative-num-draft-tokens 4
   throughput: 0.52 req/s
2. --speculative-num-steps 1 --speculative-eagle-topk 2 --speculative-num-draft-tokens 2
   throughput: 0.56 req/s

Naive spec (equivalent to --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2):
throughput: 0.68 req/s
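
For reference, a hypothetical offline-engine launch roughly matching the setup above; the keyword names are assumptions that mirror the CLI flags discussed in this PR and should be checked against server_args.py before use.

```python
import sglang as sgl

# Hypothetical sketch: kwargs assumed to mirror the --speculative-* CLI flags.
llm = sgl.Engine(
    model_path="meta-llama/Llama-2-70b-chat-hf",
    speculative_algorithm="NAIVE_EAGLE",
    speculative_draft_model_path="lmsys/sglang-EAGLE-llama2-chat-70B",
    tp_size=2,
)
out = llm.generate("Hello, my name is", {"temperature": 0.0, "max_new_tokens": 32})
print(out)
llm.shutdown()
```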

@gemini-code-assist bot (Contributor) left a comment


Hello @yukavio, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello team, gemini-code-assist here with a summary of this pull request. This PR introduces a new speculative decoding algorithm called "Naive EAGLE". The primary goal is to optimize the performance of speculative decoding, particularly for high throughput scenarios, by integrating the target model's verification stage and the draft model's generation stage into a single CUDA graph. This aims to reduce overhead and improve overall efficiency compared to previous speculative decoding implementations like the standard EAGLE algorithm. The description includes benchmark results showing improved throughput for the "naive spec" approach compared to the original EAGLE implementation under specific conditions.

Highlights

  • New Speculative Algorithm: Introduces a new speculative decoding algorithm named "Naive EAGLE".
  • CUDA Graph Integration: Integrates the target model's verification step and the draft model's generation step into a single CUDA graph to minimize overhead and improve throughput.
  • New Components: Adds NaiveEagleWorker and NaiveEAGLECudaGraphRunner classes to manage the new speculative decoding logic and CUDA graph execution.
  • Forward Mode Extensions: Adds new ForwardMode states (NAIVE_DRAFT_EXTEND, NAIVE_TARGET_VERIFY) to distinguish the forward passes specific to the Naive EAGLE algorithm.
  • Triton Kernel for KV Cache: Includes a new Triton kernel (create_draft_kv_indices) to efficiently manage KV cache indices for the draft model in the naive approach.
  • Performance Benchmarks and Tests: Includes performance benchmarks in the description and adds new unit tests (test_eagle_infer_naive.py) to verify the correctness and functionality of the Naive EAGLE implementation, including batching, EOS handling, acceptance length, logprobs, and penalties.

Changelog

  • python/sglang/srt/layers/attention/base_attn_backend.py
    • Modified forward method to include is_naive_verify() in the condition for calling forward_decode (lines 67-70).
  • python/sglang/srt/layers/attention/flashinfer_backend.py
    • Updated init_forward_metadata to include is_naive_verify() in the decode/idle condition and added logic to skip attention backend init if naive_skip_attn_backend_init is true during naive verify (lines 193-201).
    • Updated init_forward_metadata_capture_cuda_graph to include is_naive_verify() in the decode/idle condition (line 299).
    • Modified init_forward_metadata_capture_cuda_graph to include is_naive_draft() in the target verify condition and handle custom_mask_buf accordingly (lines 331-337).
    • Updated init_forward_metadata_capture_cuda_graph to use the potentially modified custom_mask_buf (line 348).
    • Updated init_forward_metadata_replay_cuda_graph to include is_naive_verify() in the decode/idle condition (line 379).
    • Updated init_forward_metadata_replay_cuda_graph to include is_naive_draft() in the target verify condition (line 388).
  • python/sglang/srt/layers/logits_processor.py
    • Added and not forward_batch.forward_mode.is_naive_draft() to the condition for determining extend_return_top_logprob (line 122).
    • Included is_naive_draft() and is_naive_verify() in the condition for pruning hidden states in the forward method (lines 235-236).
  • python/sglang/srt/managers/scheduler.py
    • Added logic to instantiate NaiveEagleWorker if the speculative algorithm is NAIVE_EAGLE (lines 270-280).
  • python/sglang/srt/model_executor/forward_batch_info.py
    • Added new ForwardMode enums: NAIVE_DRAFT_EXTEND and NAIVE_TARGET_VERIFY (lines 68-69).
    • Updated is_extend method to include NAIVE_DRAFT_EXTEND (line 84).
    • Updated is_draft_extend method to include NAIVE_DRAFT_EXTEND (lines 100-101).
    • Updated is_extend_or_draft_extend_or_mixed method to include NAIVE_DRAFT_EXTEND (line 109).
    • Added new helper methods is_naive_draft and is_naive_verify (lines 125-128).
    • Added naive_skip_attn_backend_init boolean field to ForwardBatch (line 259).
  • python/sglang/srt/model_executor/model_runner.py
    • Added a condition to skip the default init_cuda_graphs() if the speculative algorithm is NAIVE_EAGLE (lines 213-216).
    • Included NAIVE_EAGLE in the condition for initializing plan_stream_for_flashinfer (line 878).
    • Updated the condition for calling forward_decode to include is_naive_verify() (lines 1035-1037).
    • Added a check to skip the default cuda graph replay if the speculative algorithm is NAIVE_EAGLE (line 1027).
  • python/sglang/srt/server_args.py
    • Added requests_all_greedy boolean argument with a default value (line 143).
    • Included NAIVE_EAGLE in the list of speculative algorithms checked for setting max_running_requests and added a condition to not set max_running_requests if the algorithm is NAIVE_EAGLE (lines 337-340).
    • Modified the auto-choosing of speculative parameters to set fixed values for NAIVE_EAGLE (lines 351-364).
    • Added NAIVE_EAGLE to the choices for the --speculative-algorithm CLI argument (line 881).
    • Added the --requests-all-greedy CLI argument (lines 926-930).
  • python/sglang/srt/speculative/eagle_utils.py
    • Added prepare_extend_after_decode_for_naive_eagle method, which includes a Triton kernel call, to prepare batch info for draft extend after decode in the naive approach (lines 80-107).
    • Added create_draft_kv_indices Triton kernel to create KV cache indices for the draft model (lines 835-870).
  • python/sglang/srt/speculative/naive_eagle.py
    • Added a new file implementing the NaiveEagleWorker class, which handles the logic for the Naive EAGLE speculative decoding algorithm, including draft and verify steps, CUDA graph integration, and state management.
  • python/sglang/srt/speculative/navie_eagle_cuda_graph_runner.py
    • Added a new file implementing the NaiveEAGLECudaGraphRunner class, responsible for capturing and replaying CUDA graphs that combine the target verify and draft extend steps for the Naive EAGLE algorithm.
  • python/sglang/srt/speculative/spec_info.py
    • Added NAIVE_EAGLE to the SpeculativeAlgorithm enum (line 8).
    • Updated is_eagle method to include NAIVE_EAGLE (lines 14-18).
    • Added is_naive_eagle helper method (line 23).
    • Updated from_string method to map "NAIVE_EAGLE" to the new enum value (line 31); a rough sketch of these enum additions follows the changelog.
  • test/srt/test_eagle_infer_naive.py
    • Added a new file containing unit tests specifically for the NAIVE_EAGLE implementation, covering correctness, batching, EOS, acceptance length, aborts, max_tokens=1, GSM8k, logprobs, and penalties.
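
As a quick orientation for the enum-related entries above (forward_batch_info.py and spec_info.py), here is a rough sketch of how NAIVE_EAGLE and the naive forward modes could slot into the existing enums; member values and method bodies are assumptions, not the PR's exact code.

```python
from enum import IntEnum, auto

# Rough sketch only; the real definitions live in spec_info.py and
# forward_batch_info.py and may differ in members and ordering.
class SpeculativeAlgorithm(IntEnum):
    NONE = auto()
    EAGLE = auto()
    NAIVE_EAGLE = auto()

    def is_eagle(self):
        return self in (SpeculativeAlgorithm.EAGLE, SpeculativeAlgorithm.NAIVE_EAGLE)

    def is_naive_eagle(self):
        return self == SpeculativeAlgorithm.NAIVE_EAGLE

    @staticmethod
    def from_string(name):
        name_map = {
            "EAGLE": SpeculativeAlgorithm.EAGLE,
            "NAIVE_EAGLE": SpeculativeAlgorithm.NAIVE_EAGLE,
            None: SpeculativeAlgorithm.NONE,
        }
        return name_map[name]


class ForwardMode(IntEnum):
    EXTEND = auto()
    DECODE = auto()
    NAIVE_DRAFT_EXTEND = auto()
    NAIVE_TARGET_VERIFY = auto()

    def is_naive_draft(self):
        return self == ForwardMode.NAIVE_DRAFT_EXTEND

    def is_naive_verify(self):
        return self == ForwardMode.NAIVE_TARGET_VERIFY
```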

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a new 'NAIVE_EAGLE' speculative decoding algorithm, aiming to optimize performance by integrating the target verification and draft stages into a single CUDA graph. The changes are quite extensive, involving new forward modes, a new worker (NaiveEagleWorker), and a dedicated CUDA graph runner (NaiveEAGLECudaGraphRunner).

The core idea of fusing these stages into one CUDA graph is promising for reducing overhead. The provided performance numbers show that "naive spec" (presumably this new implementation) achieves higher throughput (0.68 req/s) compared to both "without speculative decoding" (0.62 req/s) and the existing "eagle" implementation (0.52-0.56 req/s in the tested H20 / TP2 setup).

Overall, the changes look well-structured. I have a few points regarding potential improvements and clarifications, mainly concerning a filename typo, a potentially unused Triton kernel, and a confusing comment in another kernel. Additionally, there are a couple of TODO items in the new code that mark current limitations (e.g., page size > 1 support).

It would also be beneficial to update the documentation to reflect this new speculative decoding algorithm and its associated parameters once the PR is finalized.

Summary of Findings

  • Filename Typo: The filename python/sglang/srt/speculative/navie_eagle_cuda_graph_runner.py and its import in naive_eagle.py contain a typo ('navie' instead of 'naive'). This should be corrected for consistency.
  • Triton Kernel Clarity/Redundancy: There's a potentially confusing comment in the create_draft_kv_indices Triton kernel regarding expand_factor. Additionally, an unused create_extend_spec_info Triton kernel definition was found in naive_eagle.py.
  • Known Limitations (TODOs): The new naive_eagle.py file includes NotImplementedError for page size > 1 and a TODO for aligning evict mask to page size, indicating current limitations of the implementation.

Merge Readiness

The pull request introduces a significant optimization for speculative decoding. The core logic seems sound, and the performance improvements are promising. However, there are a few medium-severity issues that should be addressed before merging:

  1. A filename typo for navie_eagle_cuda_graph_runner.py.
  2. A potentially unused Triton kernel in naive_eagle.py.
  3. A confusing comment in the create_draft_kv_indices Triton kernel in eagle_utils.py.

Addressing these points will improve code clarity and maintainability. Additionally, acknowledging the TODO items regarding page size > 1 support would be helpful. I am unable to approve pull requests, so please have another team member review and approve these changes after the suggested modifications are made.

seq_len_data = tl.load(seq_lens + batch_offset, mask=batch_offset < bid)
seq_len = tl.load(seq_lens + bid)
cum_seq_len = tl.sum(seq_len_data)
kv_offset = cum_seq_len * 2 - bid # only for expand_factor==1 currently

Severity: medium

The comment # only for expand_factor==1 currently for the kv_offset calculation seems inconsistent with the expand_factor: tl.constexpr (line 842) which has a comment # only support ==2 currently and the overall kernel logic which appears tailored for expand_factor=2 (e.g., writing 2 * seq_len - 1 entries per request).

Could you clarify if this comment is outdated or if there's a specific scenario where expand_factor=1 is used with this offset calculation? If the kernel is indeed hardcoded or primarily designed for expand_factor=2, updating the comment here would improve clarity.
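
A quick plain-Python check of the reviewer's arithmetic (illustration only, with made-up sequence lengths): if each earlier request writes 2 * seq_len - 1 index entries, the starting offset for request bid equals 2 * cum_seq_len - bid, matching the formula in the kernel and supporting the reading that this path is the expand_factor == 2 one.

```python
# Hypothetical per-request sequence lengths for illustration.
seq_lens = [5, 3, 7]
bid = 2  # the third request in the batch
cum_seq_len = sum(seq_lens[:bid])  # analogous to tl.sum(seq_len_data) in the kernel

# Each earlier request contributes 2 * seq_len - 1 entries when expand_factor == 2,
# so both expressions below give the same starting offset (14 here).
assert 2 * cum_seq_len - bid == sum(2 * s - 1 for s in seq_lens[:bid])
```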

assign_req_to_token_pool,
create_draft_kv_indices,
)
from sglang.srt.speculative.navie_eagle_cuda_graph_runner import (

Severity: medium

There seems to be a typo in the imported module name: navie_eagle_cuda_graph_runner. Should this be naive_eagle_cuda_graph_runner to match the naming convention used elsewhere (e.g., NaiveEagleWorker, NAIVE_EAGLE algorithm)?

If so, the filename python/sglang/srt/speculative/navie_eagle_cuda_graph_runner.py would also need to be corrected.

Comment on lines 60 to 75
@triton.jit
def create_extend_spec_info(
seq_len,
accept_len,
accept_len_cum,
positions,
accept_len_upper: tl.constexpr, # 1
):
pid = tl.program_id(axis=0)
offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1)
seq_length = tl.load(seq_len + pid)
accept_length = tl.load(accept_len + pid) # 1
positions_ptr = positions + offset
data = tl.arange(0, accept_len_upper)
mask = data < accept_length
tl.store(positions_ptr + data, seq_length - accept_length + data, mask)

Severity: medium

This Triton kernel create_extend_spec_info appears to be unused within this NaiveEagleWorker class. The methods prepare_extend_after_decode and prepare_extend_after_decode_for_naive_eagle are imported from sglang.srt.speculative.eagle_utils and they use a different version of create_extend_spec_info defined in that file.

Is this local definition of create_extend_spec_info intended for a different purpose, or could it be removed if it's indeed dead code?

next_power_of_2(num_seqs),
)
else:
raise NotImplementedError("TODO: Page size > 1 not supported yet")

Severity: medium

The NotImplementedError for page size > 1 is a current limitation. It's good that this is explicitly marked. Could you briefly mention if there are plans to support page size > 1 for this naive eagle approach in the near future, or if it's a longer-term item?

Comment on lines 523 to 524
# TODO: align_evict_mask_to_page_size, see eagle_utils.py/align_evict_mask_to_page_size
pass

Severity: medium

This TODO points to eagle_utils.py/align_evict_mask_to_page_size. It's good to track this. Does this imply that for page_size != 1, the KV cache eviction might not be optimal or could have issues with the naive eagle approach currently?

@zhyncs (Member) commented Jun 9, 2025

Hi @yukavio, could you fix the conflicts?

@merrymercy (Contributor) left a comment


Todo

  • Support step > 1
  • Support the overlap scheduler. Reorganize the code into three parts so we can reuse the event loop logic (see the sketch below this list):
    • prepare_one_batch (allocate kv cache locations)
    • forward (non-blocking, just kernel launch)
    • process_results (free kv cache locations)
  • Support page size > 1
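
A minimal sketch of the three-phase split suggested above (class and attribute names, and the toy slot pool, are placeholders rather than the PR's code); the point is that only forward needs to sit on the GPU critical path, so the scheduler's event loop can overlap the CPU-side phases with GPU work.

```python
class NaiveSpecWorkerSketch:
    def __init__(self):
        self.free_slots = list(range(1024))  # toy KV-cache slot pool

    def prepare_one_batch(self, num_tokens):
        # Phase 1: allocate KV cache locations before any GPU work.
        return [self.free_slots.pop() for _ in range(num_tokens)]

    def forward(self, slots):
        # Phase 2: non-blocking kernel launches only (e.g. one CUDA graph replay).
        # Here just a placeholder recording which slots were "written".
        return {"written": slots}

    def process_results(self, result, accepted):
        # Phase 3: free the KV cache locations of rejected draft tokens.
        for slot in result["written"][accepted:]:
            self.free_slots.append(slot)


worker = NaiveSpecWorkerSketch()
slots = worker.prepare_one_batch(num_tokens=2)
out = worker.forward(slots)
worker.process_results(out, accepted=1)
```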

# verify
indices = torch.arange(num_seqs, device="cuda", dtype=torch.int32)
accept_index[:, 0] = indices * 2
if forward_batch.sampling_info.is_all_greedy:

some code redundancy

)

if is_cuda():
from sgl_kernel import top_k_renorm_prob, top_p_renorm_prob

try to reuse some code from cuda_graph_runner.py to reduce redundancy.

@@ -289,7 +289,18 @@ def __init__(
)

# Launch a draft worker for speculative decoding
if self.spec_algorithm.is_eagle():
if self.spec_algorithm.is_naive_eagle():

Does this feature support DeepSeek-R1?


Not right now, but it will in the future.
However, you can use MTP with DeepSeek-R1; it is also based on speculative decoding.
See https://docs.sglang.ai/references/deepseek.html#multi-token-prediction

@ch-wan (Collaborator) commented Jul 11, 2025

@yukavio I managed to launch your code without encountering any errors, but the current implementation is not compatible with --disable-cuda-graph. Could you please support this minor feature, as disabling the CUDA graph is usually necessary for debugging? Thanks!

@Z-NAVY commented Jul 14, 2025

There are currently some parameter shape issues in the MLA backend. Will this be supported in the future?

@josephydu (Contributor)

@yukavio I managed to launch your code without encountering any errors, but the current implementation is not compatible with --disable-cuda-graph. Could you please support this minor feature, as disabling the CUDA graph is usually necessary for debugging? Thanks!

Fixed in the new commit.

@josephydu (Contributor)

There are currently some parameter shape issues in the MLA backend. Will this be supported in the future?

It will.


@Z-NAVY commented Jul 16, 2025

I see that the code has speculative-num-draft-tokens=2 hardcoded by self.num_tokens_per_bs = 2 in NaiveEAGLECudaGraphRunner. I'm not sure if I understand it correctly. Will other numbers be supported in the future? Thank you in advance.

@josephydu (Contributor)

I see that the code has speculative-num-draft-tokens=2 hardcoded by self.num_tokens_per_bs = 2. I'm not sure if I understand it correctly. Will other numbers be supported in the future? Thank you in advance.

We evaluated it with num_tokens_per_bs > 2, but there was almost no performance improvement.

@Z-NAVY commented Jul 17, 2025

I see that the code has speculative-num-draft-tokens=2 hardcoded by self.num_tokens_per_bs = 2. I'm not sure if I understand it correctly. Will other numbers be supported in the future? Thank you in advance.

We evaluated it with num_tokens_per_bs > 2, but there was almost no performance improvement.

However, a larger draft_num_tokens will result in a longer accept length, which is equivalent to reducing the inference time per token. Additionally, when we complete MTP inference for the entire CUDA graph, can we adopt the zero-overhead scheduler to further reduce end-to-end inference latency? I'm currently testing this locally.
123

@josephydu (Contributor) commented Jul 17, 2025

I see that the code has speculative-num-draft-tokens=2 hardcoded by self.num_tokens_per_bs = 2. I'm not sure if I understand it correctly. Will other numbers be supported in the future? Thank you in advance.

We evaluated it with num_tokens_per_bs > 2, but there was almost no performance improvement.

However, a larger draft_num_tokens will result in a longer accept length, which is equivalent to reducing the inference time per token. Additionally, when we complete MTP inference for the entire CUDA graph, can we adopt the zero-overhead scheduler to further reduce end-to-end inference latency? I'm currently testing this locally. 123

Good questions!

  1. We tested the accept length, and it barely increased when we increased draft_num_tokens. You can view the table below for details.
image
  2. It will support the overlap scheduler in the future, but not in this PR.

@Z-NAVY commented Jul 17, 2025

I see that the code has speculative-num-draft-tokens=2 hardcoded by self.num_tokens_per_bs = 2. I'm not sure if I understand it correctly. Will other numbers be supported in the future? Thank you in advance.

We evaluated it with num_tokens_per_bs > 2, but there was almost no performance improvement.

However, a larger draft_num_tokens will result in a longer accept length, which is equivalent to reducing the inference time per token. Additionally, when we complete MTP inference for the entire CUDA graph, can we adopt the zero-overhead scheduler to further reduce end-to-end inference latency? I'm currently testing this locally. 123

Good questions!

  1. We tested the accept length, and it barely increased when we increased draft_num_tokens. You can view the table below for details.

image 2. It will support the overlap scheduler in the future, but not in this PR.

The accept length is related to the training effectiveness of the draft model. In our internal scenarios, different draft_num_tokens lead to significant variations in accept length. A higher accept length can notably reduce the average generation time per token.
I will continue testing this PR locally and attempt to apply the zero-overhead scheduler.

@josephydu (Contributor)

I see that the code has speculative-num-draft-tokens=2 hardcoded by self.num_tokens_per_bs = 2. I'm not sure if I understand it correctly. Will other numbers be supported in the future? Thank you in advance.

We evaluated it with num_tokens_per_bs > 2, but there was almost no performance improvement.

However, a larger draft_num_tokens will result in a longer accept length, which is equivalent to reducing the inference time per token. Additionally, when we complete MTP inference for the entire CUDA graph, can we adopt the zero-overhead scheduler to further reduce end-to-end inference latency? I'm currently testing this locally. 123

Good questions!

  1. We tested the accept length, and it barely increased when we increased draft_num_tokens. You can view the table below for details.

image 2. It will support the overlap scheduler in the future, but not in this PR.

The accept length is related to the training effectiveness of the draft model. In our internal scenarios, different draft_num_tokens lead to significant variations in accept length. A higher accept length can notably reduce the average generation time per token.

You're right.
You can try EAGLE first in your scenarios. We currently only plan to support draft_num_tokens=2, and we will consider other values after further evaluation. Thank you for your advice.

@zhanxxxxxxx

May I ask if this supports Qwen?
