Member

@tdoublep tdoublep commented Jul 18, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • Benchmarking
  • Correctness

Purpose

This is a fairly straightforward implementation of piecewise CUDA graphs for mamba layers in V1; it follows what is already done on main for attention layers. It roughly halves latency in the benchmark below (average 5.86 s in eager mode vs. 2.87 s with piecewise CUDA graphs).

There is a lot of CPU overhead in the mamba layer, so there might still be some performance gap to V0 until we implement full CUDA graphs. I have a working branch that does this for decode-only batches. Once this one is merged I will follow up quickly with that one.

Test Plan

I've removed --enforce-eager from test_hybrid, so it should now test piecewise CUDA graphs for all supported models.

Test Result

All tests are passing for me locally.

Benchmarking

$ VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER python benchmark_latency.py \
    --model ibm-granite/granite-4.0-tiny-preview \
    --input-len 31500 \
    --output-len 128 \
    --batch-size 1 \
    --num-iters-warmup 3 \
    --num-iters 3 \
    --enforce-eager

produces:

Avg latency: 5.859217131335754 seconds
10% percentile latency: 5.78308402079856 seconds
25% percentile latency: 5.82700428449607 seconds
50% percentile latency: 5.900204723991919 seconds
75% percentile latency: 5.911923774503521 seconds
90% percentile latency: 5.918955204810482 seconds
99% percentile latency: 5.923174062994658 seconds

and now, with piecewise CUDA graphs (no --enforce-eager):

$ VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER python benchmark_latency.py \
    --model ibm-granite/granite-4.0-tiny-preview \
    --input-len 31500 \
    --output-len 128 \
    --batch-size 1 \
    --num-iters-warmup 3 \
    --num-iters 3

produces:

Avg latency: 2.872378972329898 seconds
10% percentile latency: 2.77216162639088 seconds
25% percentile latency: 2.7751373929932015 seconds
50% percentile latency: 2.7800970039970707 seconds
75% percentile latency: 2.9234795675001806 seconds
90% percentile latency: 3.0095091056020467 seconds
99% percentile latency: 3.061126828463166 seconds
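
As a quick sanity check, the speedup can be worked out from the two average latencies reported above. This is only arithmetic on the pasted numbers; the variable names are illustrative and not part of the benchmark script.

```python
# Average latencies copied from the two benchmark_latency.py runs above.
eager_avg_s = 5.859217131335754      # with --enforce-eager
piecewise_avg_s = 2.872378972329898  # compile + piecewise CUDA graphs

# Ratio of the two averages, and the relative latency reduction in percent.
speedup = eager_avg_s / piecewise_avg_s
reduction_pct = 100.0 * (1.0 - piecewise_avg_s / eager_avg_s)

print(f"speedup: {speedup:.2f}x, latency reduction: {reduction_pct:.1f}%")
```

That works out to a speedup of roughly 2x at batch size 1. Each run only measures 3 iterations, so the exact figure should be read as indicative.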

Correctness

In eager mode:

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER lm_eval --model vllm --model_args pretrained=ibm-granite/granite-4.0-tiny-preview,enable_prefix_caching=False,enforce_eager=True --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

produces:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.594|±  |0.0220|
|     |       |strict-match    |     5|exact_match|↑  |0.568|±  |0.0222|

Using compile + piecewise CUDA graphs:

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER lm_eval --model vllm --model_args pretrained=ibm-granite/granite-4.0-tiny-preview,enable_prefix_caching=False --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

produces:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.580|±  |0.0221|
|     |       |strict-match    |     5|exact_match|↑  |0.556|±  |0.0222|
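
To judge whether the small drop in the scores is meaningful, a rough check is to compare the difference against the reported standard errors. The sketch below assumes the two runs can be treated as independent samples, which is only an approximation since both use the same 500 questions; the names are illustrative.

```python
import math

# (eager, piecewise) exact_match values and standard errors from the tables above.
results = {
    "flexible-extract": ((0.594, 0.0220), (0.580, 0.0221)),
    "strict-match": ((0.568, 0.0222), (0.556, 0.0222)),
}

z_scores = {}
for name, ((eager_acc, eager_se), (graph_acc, graph_se)) in results.items():
    diff = eager_acc - graph_acc
    # Standard error of the difference under the independence assumption.
    se_diff = math.sqrt(eager_se**2 + graph_se**2)
    z_scores[name] = diff / se_diff
    print(f"{name}: diff={diff:.3f}, se_diff={se_diff:.4f}, z={z_scores[name]:.2f}")
```

Under that assumption both differences are well below one standard error of the difference, which is consistent with the two modes producing equivalent results up to sampling noise.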

@mergify mergify bot added the v1 label Jul 18, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request enables piecewise CUDA Graph for mamba layers. The changes include modifications to several files to integrate mamba layers with CUDA graphs. The review identified several print statements that should be removed before merging, as well as a block of commented-out code that should be removed.

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
@tdoublep
Member Author

Ready for review @heheda12345 @tlrmchlsmth

Comment on lines +766 to +767
op_name="mamba_mixer2",
op_func=mamba_mixer2,
Collaborator

We should use a common op name here (unified_ssm_mixer?) so we can avoid adding a bunch of cases to splitting_ops (fine for this PR though)
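
For context, a toy model of why each new op name adds a case to splitting_ops: as I understand it, piecewise CUDA graphs split the compiled graph at the ops listed there, capture the pieces in between as CUDA graphs, and run the splitting ops themselves eagerly. The sketch below is not vLLM's implementation; `split_into_pieces` and every op name other than mamba_mixer2 are hypothetical.

```python
def split_into_pieces(ops, splitting_ops):
    """Split a flat op sequence at every op named in splitting_ops.

    Returns a list of (kind, ops) tuples: "graph" for a run of ops that could
    be captured together, "eager" for a single splitting op.
    """
    pieces, current = [], []
    for op in ops:
        if op in splitting_ops:
            if current:
                pieces.append(("graph", current))
                current = []
            pieces.append(("eager", [op]))
        else:
            current.append(op)
    if current:
        pieces.append(("graph", current))
    return pieces


# Hypothetical op sequence: one attention layer followed by one mamba layer.
ops = ["rms_norm", "qkv_proj", "unified_attention", "o_proj",
       "rms_norm", "in_proj", "mamba_mixer2", "out_proj"]

# One entry per custom op name; a shared name would keep this set from growing.
splitting_ops = {"unified_attention", "mamba_mixer2"}

pieces = split_into_pieces(ops, splitting_ops)
for kind, piece in pieces:
    print(kind, piece)
```

With a common name such as the suggested unified_ssm_mixer, new SSM-style layers could reuse a single entry instead of each adding their own.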

Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Glad this was so straightforward!

@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) July 19, 2025 14:05
@tlrmchlsmth tlrmchlsmth merged commit 881e3cb into vllm-project:main Jul 19, 2025
62 checks passed
Collaborator

@heheda12345 heheda12345 left a comment

Thanks for the great work. Does the performance match V0 now? If so, is it time to use V1 by default in supported cases? If not, do you know what else we need to do?

Comment on lines +126 to +127
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
Collaborator

Is it possible to change the code back to output = self.mamba(...)?

Member Author

I initially tried that but ran into problems, I think because the output tensor needs to reside in the CUDA graph. I believe I checked, and this is how it is done for attention.

Member Author

I tried to do it the other way (using return instead of mutating output) for linear_attention in this PR. It seems to work. I will revisit why it didn't work for the mamba_mixer.

@tdoublep
Member Author

Does the performance match V0 now? If so, is it time to use V1 by default in supported cases? If not, do you know what else we need to do?

@heheda12345 Based on my benchmarking, there is still a gap to V0 for small batch sizes. This is because the mamba layers themselves still do not run inside a CUDA graph in V1 (whereas in V0 they do). It is a simple change to enable full CUDA graphs for mamba layers for decode-only batches. I will create a PR, but it would be ideal if we can first merge #21367 to enable the same for FlashInfer. Otherwise we will need to add some logic to handle attention layers using piecewise and mamba layers using full CUDA graphs.
