
Conversation

@mickaelseznec (Contributor) commented Jul 17, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

For MLA models that have a q_lora_rank: fuse the q_lora and kv_lora down-projections into the same matrix (avoids some memory traffic and saves one kernel call).
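As a rough illustration of the fusion (a minimal PyTorch sketch with DeepSeek-V3-like dimensions, not the actual MergedReplicatedLinear implementation):

```python
import torch

# Illustrative dimensions for an MLA layer that has a q_lora_rank.
hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 7168, 1536, 512, 64
x = torch.randn(8, hidden_size)

# Before: two separate down-projections -> two GEMMs, two reads of x.
w_q_a = torch.randn(q_lora_rank, hidden_size)
w_kv_a = torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size)
q_a = x @ w_q_a.t()
kv_a = x @ w_kv_a.t()

# After: the two weights are concatenated into one matrix -> a single GEMM,
# and the result is split back into views (no copy).
w_fused = torch.cat([w_q_a, w_kv_a], dim=0)
fused = x @ w_fused.t()
q_a_fused, kv_a_fused = fused.split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)

assert torch.allclose(q_a, q_a_fused, rtol=1e-4, atol=1e-5)
assert torch.allclose(kv_a, kv_a_fused, rtol=1e-4, atol=1e-5)
```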

Also adds a layernorm implementation that operates on strided input, which avoids a memory copy.
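A minimal sketch of why a strided kernel helps, assuming the fused projection output is sliced into per-branch views (pure PyTorch for illustration; the actual change is in the CUDA RMSNorm kernels):

```python
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reference RMSNorm; works on any (possibly strided) input.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

fused = torch.randn(8, 1536 + 512 + 64, dtype=torch.float16)
q_slice = fused[:, :1536]          # view into the fused output, row stride 2112
assert not q_slice.is_contiguous()

# A kernel that only accepts contiguous input would force a copy first:
#   q_norm = rms_norm_kernel(q_slice.contiguous(), weight)   # extra memory traffic
# A strided kernel instead reads the slice in place using its row stride.
weight = torch.ones(1536, dtype=torch.float16)
q_norm = rms_norm_ref(q_slice, weight)
```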

Test Plan

Unit tests added for the strided layernorm. E2E testing and benchmark results are included in this PR.
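For illustration, a test along these lines (a sketch only; it checks a reference RMSNorm on a strided view against a contiguous copy rather than calling the actual vLLM kernel):

```python
import pytest
import torch

@pytest.mark.parametrize("rows,cols,total_cols", [(16, 1536, 2112)])
def test_rms_norm_on_strided_slice(rows, cols, total_cols):
    # Normalizing a strided view must match normalizing a contiguous copy.
    full = torch.randn(rows, total_cols)
    weight = torch.randn(cols)
    strided = full[:, :cols]            # non-contiguous view
    contiguous = strided.contiguous()

    def ref(x):
        var = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(var + 1e-6) * weight

    torch.testing.assert_close(ref(strided), ref(contiguous))
```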

Test Result

Accuracy

main (20149d8)

vllm (pretrained=deepseek-ai/DeepSeek-V3-0324,tensor_parallel_size=8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9469|±  |0.0062|
|     |       |strict-match    |     5|exact_match|↑  |0.9454|±  |0.0063|

This PR:

vllm (pretrained=deepseek-ai/DeepSeek-V3-0324,add_bos_token=true,tensor_parallel_size=8), gen_kwargs: (None), limit: 250.0, num_fewshot: None, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.952|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.952|±  |0.0135|

Performance

main (20149d8)

venv ❯ python benchmarks/benchmark_serving.py --model deepseek-ai/DeepSeek-V3-0324 --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json
INFO 07-15 17:16:08 [__init__.py:253] Automatically detected platform cuda.
Namespace(backend='vllm', base_url=None, host='127.0.0.1', port=8000, endpoint='/v1/completions', dataset_name='sharegpt', dataset_path='ShareGPT_V3_unfiltered_cleaned_split.json', no_stream=False, max_concurrency=None, model='deepseek-ai/DeepSeek-V3-0324', tokenizer=None, use_beam_search=False, num_prompts=1000, logprobs=None, request_rate=inf, burstiness=1.0, seed=0, trust_remote_code=False, disable_tqdm=False, profile=False, save_result=False, save_detailed=False, append_result=False, metadata=None, result_dir=None, result_filename=None, ignore_eos=False, percentile_metrics='ttft,tpot,itl', metric_percentiles='99', goodput=None, custom_output_len=256, custom_skip_chat_template=False, sonnet_input_len=550, sonnet_output_len=150, sonnet_prefix_len=200, sharegpt_output_len=None, random_input_len=1024, random_output_len=128, random_range_ratio=0.0, random_prefix_len=0, hf_subset=None, hf_split=None, hf_output_len=None, top_p=None, top_k=None, min_p=None, temperature=None, tokenizer_mode='auto', served_model_name=None, lora_modules=None, ramp_up_strategy=None, ramp_up_start_rps=None, ramp_up_end_rps=None)
Starting initial single prompt test run...
Initial test run completed. Starting main benchmark run...
Traffic request rate: inf RPS.
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
100%|██████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:58<00:00, 17.10it/s]
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  58.46
Total input tokens:                      219171
Total generated tokens:                  164272
Request throughput (req/s):              17.10
Output token throughput (tok/s):         2809.81
Total Token throughput (tok/s):          6558.65
---------------Time to First Token----------------
Mean TTFT (ms):                          8290.64
Median TTFT (ms):                        7975.92
P99 TTFT (ms):                           14349.76
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          177.57
Median TPOT (ms):                        115.76
P99 TPOT (ms):                           434.24
---------------Inter-token Latency----------------
Mean ITL (ms):                           98.84
Median ITL (ms):                         66.80
P99 ITL (ms):                            435.74
==================================================

This PR:

venv ❯ python benchmarks/benchmark_serving.py --model deepseek-ai/DeepSeek-V3-0324 --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json
INFO 07-17 10:27:38 [__init__.py:253] Automatically detected platform cuda.
Namespace(backend='vllm', base_url=None, host='127.0.0.1', port=8000, endpoint='/v1/completions', dataset_name='sharegpt', dataset_path='ShareGPT_V3_unfiltered_cleaned_split.json', no_stream=False, max_concurrency=None, model='deepseek-ai/DeepSeek-V3-0324', tokenizer=None, use_beam_search=False, num_prompts=1000, logprobs=None, request_rate=inf, burstiness=1.0, seed=0, trust_remote_code=False, disable_tqdm=False, profile=False, save_result=False, save_detailed=False, append_result=False, metadata=None, result_dir=None, result_filename=None, ignore_eos=False, percentile_metrics='ttft,tpot,itl', metric_percentiles='99', goodput=None, custom_output_len=256, custom_skip_chat_template=False, sonnet_input_len=550, sonnet_output_len=150, sonnet_prefix_len=200, sharegpt_output_len=None, random_input_len=1024, random_output_len=128, random_range_ratio=0.0, random_prefix_len=0, hf_subset=None, hf_split=None, hf_output_len=None, top_p=None, top_k=None, min_p=None, temperature=None, tokenizer_mode='auto', served_model_name=None, lora_modules=None, ramp_up_strategy=None, ramp_up_start_rps=None, ramp_up_end_rps=None)
Starting initial single prompt test run...
Initial test run completed. Starting main benchmark run...
Traffic request rate: inf RPS.
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:56<00:00, 17.63it/s]
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  56.72
Total input tokens:                      219171
Total generated tokens:                  165898
Request throughput (req/s):              17.63
Output token throughput (tok/s):         2925.10
Total Token throughput (tok/s):          6789.51
---------------Time to First Token----------------
Mean TTFT (ms):                          6917.92
Median TTFT (ms):                        6629.26
P99 TTFT (ms):                           12941.51
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          171.18
Median TPOT (ms):                        108.68
P99 TPOT (ms):                           461.18
---------------Inter-token Latency----------------
Mean ITL (ms):                           95.07
Median ITL (ms):                         67.52
P99 ITL (ms):                            431.03
==================================================

(Optional) Documentation Update


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which starts a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify bot added the deepseek (Related to DeepSeek models) label on Jul 17, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces two significant optimizations: fusing the QKV projection for MLA models and implementing a strided LayerNorm kernel. The changes are well-implemented and should provide the performance benefits described.

The fusion of Q-LoRA and KV-LoRA projections into a single matrix operation for DeepSeek-V2 models is a smart optimization that reduces kernel launch overhead and memory traffic. The introduction of MergedReplicatedLinear to handle this fusion is a clean way to extend the existing linear layer infrastructure.

The addition of a strided layernorm implementation is crucial for the fusion to be effective, as it avoids expensive .contiguous() calls on tensor slices. The CUDA kernels have been updated correctly to handle the input_stride, and the PyTorch bindings are adjusted accordingly.

The test suite has been properly extended to cover the new strided input case for the layernorm kernels, ensuring the correctness of the new implementation.

Overall, this is a high-quality contribution that improves performance while maintaining code clarity and correctness. I have no major concerns.

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
@mickaelseznec force-pushed the mseznec/merged-qkv-and-strided-layernorm branch from 75b3d50 to e3962ab on July 17, 2025 10:38
@mgoin requested a review from LucasWilkinson on July 17, 2025 12:06
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
@LucasWilkinson (Collaborator) commented Jul 17, 2025

Nice, thanks for the contribution! Clean, simple, and gives perf; the trifecta haha. Overall looks pretty good to me, but I think one of the weight-loading experts, i.e. @dsikka or @mgoin, should take a look to make sure we don't break 4-bit quantized models.

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
@yewentao256 (Collaborator) left a comment

Thanks for the work!

Comment on lines +423 to +424
from vllm.model_executor.layers.quantization.fp8 import (
    Fp8LinearMethod, Fp8MoEMethod)
Collaborator:
Could we refactor the code so that we can put the import at the top of the file, without worrying about a circular import here?

Contributor Author:
Well it's tricky, because FP8Linear already depends on Linear (which makes sense). I don't know how you'd like to proceed.

I lazily copy/pasted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py#L787-L791

Collaborator:
Yeah, I am thinking: if A imports B and B imports A, we can have a base file C, move the shared things into C, and have both A and B import C.
We don't need to do it in this PR if you don't wish; it could be done in a future refactor.
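For illustration, the shape of the suggested refactor (module names below are hypothetical, not existing vLLM files):

```python
# quant_base.py: shared definitions, importing nothing heavy.
class QuantizeMethodBase:
    """Base interface that both linear.py and fp8.py can depend on."""

# linear.py would do `from quant_base import QuantizeMethodBase` and dispatch
# through the base interface instead of importing fp8.py for isinstance checks.

# fp8.py would likewise do `from quant_base import QuantizeMethodBase`; since
# neither module imports the other, the import cycle disappears.
```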

Contributor Author:
Sure! Here, the best way would probably be to rely on inheritance by defining (and overriding) methods like: QuantizeMethodBase.supports_block_quantization()

However, I don't have a complete overview of all the supported cases and potential edge cases, and it might make this PR heavier than needed for now.

Happy to help with a following PR though :)
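A minimal sketch of that inheritance-based idea (supports_block_quantization is the method name proposed in this thread; needs_block_quant_path is a hypothetical helper, not existing vLLM API):

```python
class QuantizeMethodBase:
    def supports_block_quantization(self) -> bool:
        # Conservative default; quant methods that support it override this.
        return False


class Fp8LinearMethod(QuantizeMethodBase):
    def supports_block_quantization(self) -> bool:
        return True


def needs_block_quant_path(quant_method: QuantizeMethodBase) -> bool:
    # The linear layer asks the quant method directly instead of importing
    # concrete implementations for isinstance checks.
    return quant_method.supports_block_quantization()
```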

Collaborator:
Sounds great, you can certainly do that in another PR.

@mickaelseznec changed the title from "feat: add fused MLA QKV + strided layernorm" to "[perf] Add fused MLA QKV + strided layernorm" on Jul 18, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
@mgoin (Member) left a comment

Nice work!

@mgoin enabled auto-merge (squash) on July 21, 2025 18:38
@github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Jul 21, 2025
@vllm-bot merged commit 4fb5691 into vllm-project:main on Jul 22, 2025
106 of 108 checks passed
yeqcharlotte pushed a commit to yeqcharlotte/vllm that referenced this pull request Jul 23, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
zixi-qi pushed a commit to zixi-qi/vllm that referenced this pull request Jul 23, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: qizixi <qizixi@meta.com>
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
wenscarl pushed a commit to wenscarl/vllm that referenced this pull request Aug 4, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: shuw <shuw@nvidia.com>
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: x22x22 <wadeking@qq.com>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>