
Conversation

izhuhaoran
Contributor

@izhuhaoran izhuhaoran commented Jun 25, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

I noticed that in the V1 engine, the default cuda-graph-sizes is a fixed value of 512. This can cause a performance problem if a user sets max_num_seqs to a larger value, because batch sizes above 512 will not be covered by CUDA graphs.

To fix this, this PR makes the default cuda_graph_sizes align with max_num_seqs when the user doesn't provide it.
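As a rough illustration of the coverage problem (the helper and capture schedule below are simplified stand-ins, not vLLM's actual capture logic): a batch runs under a CUDA graph only if it can be padded up to one of the captured sizes.

```python
from typing import Optional

def padded_graph_size(batch_size: int, capture_sizes: list[int]) -> Optional[int]:
    """Smallest captured size that fits the batch, or None if the batch
    exceeds every captured size and must fall back to eager execution."""
    for size in sorted(capture_sizes):
        if batch_size <= size:
            return size
    return None

# Illustrative capture schedule capped at a fixed 512: 1, 2, 4, then multiples of 8.
captured = [1, 2, 4] + list(range(8, 512 + 1, 8))

print(padded_graph_size(3, captured))    # 4: padded up to a captured size
print(padded_graph_size(800, captured))  # None: with max_num_seqs=1024, this batch runs eagerly
```

With a fixed cap of 512, every batch between 513 and max_num_seqs misses the CUDA-graph fast path, which is the gap this PR targets.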

Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @izhuhaoran, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request implements a bugfix aimed at improving the performance of the V1 engine, specifically concerning CUDA graph capture. The core change ensures that the default cuda_graph_sizes configuration automatically scales with the max_num_seqs setting, thereby optimizing batch processing and preventing scenarios where certain batch sizes might not be efficiently handled by CUDA graphs.

Highlights

  • Default CUDA Graph Sizes: The default initialization of cuda_graph_sizes in the SchedulerConfig has been changed. Previously, it defaulted to a fixed [512]. Now, if not explicitly provided by the user, it will dynamically default to [max_num_seqs].
  • Performance Improvement / Bugfix: This change addresses a potential performance bottleneck in the V1 engine. By aligning the default cuda_graph_sizes with max_num_seqs, it ensures that CUDA graph capture covers all relevant batch sizes, preventing performance degradation when max_num_seqs is larger than the previous fixed default.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request modifies the cuda_graph_sizes configuration in vllm/config.py to default to max_num_seqs for the V1 engine when no value is provided by the user. This change aims to address a potential performance issue where a fixed cuda-graph-sizes value might not cover all batch sizes when max_num_seqs is set to a larger value. The code changes include updating the default value of cuda_graph_sizes to None and setting it to [self.max_num_seqs] in the __post_init__ method if it's None.
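A minimal sketch of the change as described above (the real SchedulerConfig has many more fields; this only models the None default plus the __post_init__ fallback):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SchedulerConfig:
    max_num_seqs: int = 128
    # Previously a fixed default of [512]; None now means "not set by the user".
    cuda_graph_sizes: Optional[list[int]] = None

    def __post_init__(self) -> None:
        # Align the default capture size with max_num_seqs when unset.
        if self.cuda_graph_sizes is None:
            self.cuda_graph_sizes = [self.max_num_seqs]

print(SchedulerConfig(max_num_seqs=1024).cuda_graph_sizes)  # [1024]
print(SchedulerConfig(max_num_seqs=1024, cuda_graph_sizes=[256]).cuda_graph_sizes)  # [256]
```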

@aarnphm
Collaborator

aarnphm commented Jun 25, 2025

cc @ProExpertProg for visibility

Signed-off-by: izhuhaoran <izhuhaoran@qq.com>

Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
@izhuhaoran
Contributor Author

@aarnphm @ProExpertProg Could you please take a final look at this PR to ensure it's ready for auto-merge?

@aarnphm aarnphm enabled auto-merge (squash) June 25, 2025 16:52
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 25, 2025
@aarnphm aarnphm merged commit 9f0608f into vllm-project:main Jun 25, 2025
68 checks passed
@ProExpertProg
Collaborator

I am realizing that this effectively doubled the cudagraph capture time (from 512 to 1024). CUDA Graphs are less important at larger batch sizes and it's not clear to me that the performance benefits warrant longer collection time. @izhuhaoran do you have any performance numbers that show otherwise?

@izhuhaoran
Contributor Author

I am realizing that this effectively doubled the cudagraph capture time (from 512 to 1024). CUDA Graphs are less important at larger batch sizes and it's not clear to me that the performance benefits warrant longer collection time. @izhuhaoran do you have any performance numbers that show otherwise?

Thank you for the feedback. Sorry that I haven't specifically benchmarked the performance impact for larger batches (e.g., >512), but I wanted to address a usability concern with the previous default behavior.

In my view, the previous default value of [512] introduced two key issues:

  • Users setting max_num_seqs > 512 unknowingly miss CUDA Graph optimizations for their actual target batch size.
  • For smaller max_num_seqs values (e.g., 256), it would still capture unnecessarily large graphs (up to 512), increasing startup overhead without much benefit.

This PR only aligns the default cuda_graph_sizes with max_num_seqs to make the behavior more intuitive and robust for typical use cases. While this may not be the perfect tradeoff either, I think a fixed default of 512 is more problematic. Do you have alternative suggestions for the default size configuration?

For users sensitive to startup overhead or seeking optimal performance, I'd recommend explicitly setting --cuda-graph-sizes to match their specific workload requirements rather than relying on defaults.

m-misiura pushed a commit to m-misiura/vllm that referenced this pull request Jun 26, 2025
@mgoin
Member

mgoin commented Jun 26, 2025

First, I don't think this is a clear decision given chunked prefill is on by default. You also need to consider max_num_batched_tokens, right?
Second, it doesn't make sense to compile CUDA graphs for such large sizes by default, like @ProExpertProg said above will happen now.
Can we revert this in the meantime to discuss?

@ProExpertProg
Collaborator

I think also the actual shape we cudagraph is not the number of sequences (limited by max_num_seqs), but rather the number of tokens. So even at max_num_seqs=512 we might run size 512 if seq_len is >1 for some of the requests. Someone correct me if I'm wrong here.
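A toy example of this point (all numbers made up): with chunked prefill, the shape that matters is the total number of scheduled tokens in a step, which can exceed the number of sequences in the batch.

```python
# Hypothetical scheduler output: (request_id, tokens scheduled this step).
# Each decode contributes 1 token; a chunked prefill can contribute many.
step = [("decode-0", 1), ("decode-1", 1), ("prefill-0", 510)]

num_seqs = len(step)                  # 3 sequences in the batch
num_tokens = sum(n for _, n in step)  # 512 tokens: the shape the graph must cover

print(num_seqs, num_tokens)  # 3 512
```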

@izhuhaoran
Contributor Author

izhuhaoran commented Jun 26, 2025

First, I don't think this is a clear decision given chunked prefill is on by default. You also need to consider max_num_batched_tokens, right? Second, it doesn't make sense to compile CUDA graphs for such large sizes by default, like @ProExpertProg said above will happen now. Can we revert this in the meantime to discuss?

Thank you for the feedback. I completely agree with your points; while working on this PR, I also recognized the limitations you mentioned. The current implementation aligns cuda_graph_sizes with max_num_seqs as a simple default, and, as I mentioned earlier ("While this may not be the perfect tradeoff either"), I think a fixed [512] is arguably even worse.

I fully support further discussion on better strategies for this setting. At this moment, I don't have a more optimal solution to propose, but I'd be happy to explore alternatives (e.g., heuristic-based sizing, adaptive capture ranges) that balance performance and overhead more effectively.

@izhuhaoran
Contributor Author

I think also the actual shape we cudagraph is not the number of sequences (limited by max_num_seqs), but rather the number of tokens. So even at max_num_seqs=512 we might run size 512 if seq_len is >1 for some of the requests. Someone correct me if I'm wrong here.

I agree with you. Aligning to max_num_seqs and hardcoding [512] are both suboptimal choices; we need to discuss this further.

@ProExpertProg
Collaborator

I agree with you. Aligning to max_num_seqs and hardcoding [512] are both suboptimal choices; we need to discuss this further.

Yes but one of them has doubled the default capture time, increasing startup time. Also, I don't think we can reach a better solution without comprehensive benchmarking with and without cudagraphs. I agree with @mgoin that we should revert this until we have those numbers and reach a decision.

It would be great if you could help with some of the benchmarking! Do you have bandwidth to do so?

@izhuhaoran
Contributor Author

Yes but one of them has doubled the default capture time, increasing startup time. Also, I don't think we can reach a better solution without comprehensive benchmarking with and without cudagraphs. I agree with @mgoin that we should revert this until we have those numbers and reach a decision.

It would be great if you could help with some of the benchmarking! Do you have bandwidth to do so?

I'd be glad to do some benchmarking. I will test the performance difference with and without CUDA graphs for max_num_seqs > 512 and with large batch sizes.

While my bandwidth is limited, I'll make time to run these tests to help improve the current setting. After all, a reasonable capture-size setting is crucial for performance.

@ProExpertProg
Collaborator

ProExpertProg commented Jun 26, 2025

Can you try running for a few different model sizes? I assume cudagraphs are more important for smaller models.

@izhuhaoran
Contributor Author

izhuhaoran commented Jun 26, 2025

Can you try running for a few different model sizes? I assume cudagraphs are more important for smaller models.

I plan to test Qwen2.5-7B and Qwen2.5-72B, both with TP 4 and different max_num_seqs values. Any suggestions for this plan? I would also appreciate it if you could run some tests when you have time, as my bandwidth is limited.

gmarinho2 pushed a commit to gmarinho2/vllm that referenced this pull request Jun 26, 2025
@izhuhaoran
Contributor Author

Here are some benchmark results, cc @ProExpertProg -- any thoughts?
BTW, sorry for my limited bandwidth: I haven't tested the 72B model or additional input/output length settings yet.

Qwen2.5-0.5B + TP1 + Nvidia H20 (inlen=1, num_prompts=4096)

| outlen | max-concurrency | output throughput (tok/s), no CUDA graph | mean TPOT (ms), no CUDA graph | output throughput (tok/s), with CUDA graph | mean TPOT (ms), with CUDA graph |
|---|---|---|---|---|---|
| 512 | 512 | 22874.03 | 21.31 | 25085.01 | 19.53 |
| 512 | 640 | 23745.44 | 24.39 | 26071.75 | 22.44 |
| 512 | 768 | 24346.01 | 28.22 | 25862.20 | 25.61 |
| 512 | 896 | 24475.21 | 32.34 | 26121.76 | 29.75 |
| 512 | 1024 | 25478.02 | 36.79 | 26432.65 | 34.91 |
| 1024 | 512 | 20700.46 | 24.35 | 22338.37 | 22.56 |
| 1024 | 640 | 21009.73 | 28.64 | 22363.75 | 27.12 |
| 1024 | 768 | 20928.19 | 34.15 | 22067.09 | 32.69 |
| 1024 | 896 | 21519.75 | 38.41 | 22336.83 | 37.20 |
| 1024 | 1024 | 21926.26 | 44.23 | 22724.86 | 42.56 |
| 2048 | 512 | 16882.78 | 30.17 | 17980.14 | 28.30 |
| 2048 | 640 | 17398.15 | 35.00 | 18382.80 | 33.28 |
| 2048 | 768 | 17241.59 | 42.09 | 18065.37 | 40.38 |
| 2048 | 896 | 17459.67 | 48.16 | 18361.21 | 45.86 |
| 2048 | 1024 | 18037.11 | 54.84 | 18643.86 | 52.86 |

Qwen2.5-7B + TP4 + Nvidia H20 (inlen=1, num_prompts=4096)

| outlen | max-concurrency | output throughput (tok/s), no CUDA graph | mean TPOT (ms), no CUDA graph | output throughput (tok/s), with CUDA graph | mean TPOT (ms), with CUDA graph |
|---|---|---|---|---|---|
| 512 | 512 | 13762.88 | 36.75 | 14708.59 | 34.38 |
| 512 | 640 | 13857.21 | 43.65 | 14721.48 | 41.11 |
| 512 | 768 | 13870.09 | 51.94 | 14704.59 | 49.18 |
| 512 | 896 | 14214.34 | 58.86 | 14952.40 | 55.85 |
| 512 | 1024 | 14282.70 | 70.16 | 14916.21 | 67.19 |
| 1024 | 512 | 13185.95 | 38.60 | 14192.27 | 35.86 |
| 1024 | 640 | 13219.01 | 45.93 | 14118.08 | 43.16 |
| 1024 | 768 | 13270.73 | 54.69 | 14154.43 | 51.51 |
| 1024 | 896 | 13603.18 | 61.81 | 14337.43 | 58.75 |
| 1024 | 1024 | 13765.58 | 73.64 | 14368.21 | 70.52 |
| 2048 | 512 | 12470.96 | 40.94 | 13322.81 | 38.32 |
| 2048 | 640 | 12447.39 | 48.94 | 13258.93 | 46.18 |
| 2048 | 768 | 12560.98 | 57.94 | 13234.20 | 55.21 |
| 2048 | 896 | 12812.86 | 65.83 | 13413.84 | 62.95 |
| 2048 | 1024 | 12960.92 | 78.64 | 13515.25 | 75.41 |
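For quick reference, the relative gains in these tables can be computed directly; the snippet below uses the first row of the Qwen2.5-0.5B table (outlen=512, max-concurrency=512):

```python
def pct_gain(baseline: float, candidate: float) -> float:
    """Percent change of candidate relative to baseline (positive = larger)."""
    return (candidate / baseline - 1.0) * 100.0

# First row of the Qwen2.5-0.5B table above.
thpt = pct_gain(22874.03, 25085.01)  # throughput: higher is better
tpot = -pct_gain(21.31, 19.53)       # TPOT: lower is better
print(f"throughput +{thpt:.1f}%, TPOT -{tpot:.1f}%")  # throughput +9.7%, TPOT -8.4%
```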

xjpang pushed a commit to xjpang/vllm that referenced this pull request Jun 30, 2025
wseaton pushed a commit to wseaton/vllm that referenced this pull request Jun 30, 2025
wseaton pushed a commit to wseaton/vllm that referenced this pull request Jun 30, 2025
wwl2755-google pushed a commit to wwl2755-google/vllm that referenced this pull request Jul 1, 2025
@yeqcharlotte
Collaborator

yeqcharlotte commented Jul 5, 2025

@mgoin @ProExpertProg -- the concerns about loading time, and that we may actually chunk based on max_num_batched_tokens, are valid. Do you feel min(max_num_seqs * 2, 512) is an acceptable compromise, better than the status quo? More advanced requirements can go through config customization.

We are running into similar issues only when running at very tight memory and very low batch sizes: it still OOMs while trying to capture CUDA graphs for high batch sizes we don't need. Usability should be more important than the potential perf gap from capturing slightly fewer batch sizes due to chunked prefill.
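The proposed compromise is easy to sketch (illustrative only; the actual default change landed in a follow-up PR):

```python
def default_cuda_graph_sizes(max_num_seqs: int) -> list[int]:
    """Proposed default: cap the capture size at min(max_num_seqs * 2, 512)."""
    return [min(max_num_seqs * 2, 512)]

print(default_cuda_graph_sizes(64))    # [128]: small deployments capture far less than 512
print(default_cuda_graph_sizes(1024))  # [512]: the cap avoids doubling capture time
```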

@mgoin
Member

mgoin commented Jul 5, 2025

@yeqcharlotte I think that is a reasonable compromise for now. I understand the advantage of reducing the cudagraphs for small batches

@izhuhaoran
Contributor Author

izhuhaoran commented Jul 5, 2025

@yeqcharlotte I think that is a reasonable compromise for now. I understand the advantage of reducing the cudagraphs for small batches

I agree with using "min(max_num_seqs * 2, 512)" as the default. @mgoin, shall I update the commit and resubmit as a new PR? also cc @ProExpertProg

@yeqcharlotte
Collaborator

@yeqcharlotte I think that is a reasonable compromise for now. I understand the advantage of reducing the cudagraphs for small batches

I agree with using "min(max_num_seqs * 2, 512)" as the default. @mgoin, shall I update the commit and resubmit as a new PR? also cc @ProExpertProg

just submit a new PR, or revert, whichever is easier. also cc: @houseroad @zou3519

@izhuhaoran
Contributor Author

Just submitted a new PR: #20628. cc @mgoin @yeqcharlotte

avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025