Conversation

@farazkh80 (Contributor) commented Jul 10, 2025

Motivation

This PR integrates TRTLLM-GEN MLA Decode kernel from flashinfer to sglang.

Modifications

Introduced a new MLA backend option, TRTLLMMLABackend, in python/sglang/srt/layers/attention/trtllm_mla_backend.py.
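As a rough sketch of what wiring in a new backend option looks like (the registry and helper names below are hypothetical illustrations, not sglang's actual API; only the backend string `trtllm_mla` and the class name TRTLLMMLABackend come from this PR):

```python
# Hypothetical sketch of mapping an --attention-backend string to a backend
# class. Only "trtllm_mla" / TRTLLMMLABackend are real names from this PR;
# everything else here is illustrative.
class FlashInferMLABackend:
    name = "flashinfer"

class TRTLLMMLABackend:
    name = "trtllm_mla"

ATTENTION_BACKENDS = {
    cls.name: cls for cls in (FlashInferMLABackend, TRTLLMMLABackend)
}

def get_backend(name: str):
    # Fail loudly on an unknown backend string instead of silently falling back.
    try:
        return ATTENTION_BACKENDS[name]()
    except KeyError:
        raise ValueError(f"unknown attention backend: {name}")
```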

Benchmarking

Low Concurrency Results TP=4 (4xB200)

Server Command: python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1 --trust-remote-code --attention-backend trtllm_mla/flashinfer/cutlass_mla --page-size 32/64/128 --tp-size 4 --max-running-requests 1 --cuda-graph-max-bs 1 --mem-fraction-static 0.90

Client Command: python -m sglang.bench_serving --backend sglang --host 0.0.0.0 --port 30000 --dataset-name random --random-input-len 1024 --random-output-len 8192 --num-prompts 1 --max-concurrency 1

| Backend | Page Size | Requests | Max Concurrency | Achieved Concurrency | Output Throughput (tok/s) | Total Throughput (tok/s) |
|---|---|---|---|---|---|---|
| trtllm_mla | 32 | 1 | 1 | 1.0 | 51.68 | 60.03 |
| flashinfer | 32 | 1 | 1 | 1.0 | 52.56 | 61.06 |
| trtllm_mla | 64 | 1 | 1 | 1.0 | 52.40 | 60.88 |
| flashinfer | 64 | 1 | 1 | 1.0 | 51.88 | 60.27 |
| cutlass_mla | 128 | 1 | 1 | 1.0 | 49.60 | 57.62 |

Note: we don't observe any considerable perf gain at low concurrency because the kernel accounts for only about 7% of e2e latency (23 µs of kernel time out of ~300 µs for one layer's forward path). The trtllm_mla kernel itself is ~40% faster than the flashinfer MLA kernel (17 µs vs. 23 µs in the page-size-32 case, a 6 µs saving).
However, there is an extra q_rope and q_nope concatenation step before calling trtllm_batch_decode_with_kv_cache_mla, plus an extra void flashinfer::zero_gmem_semaphore<int>(T1 *, int) launch inside flashinfer. Together these add roughly 5 µs, which cancels out most of the 6 µs gain from the trtllm_batch_decode_with_kv_cache_mla kernel itself (all of this can be seen in the kernel-wise comparison snapshot below).
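The extra concatenation step can be illustrated at the shape level (NumPy stands in for torch here; the dimensions below are the usual DeepSeek-R1 MLA decode shapes, kv_lora_rank=512 and qk_rope_head_dim=64, and should be read as illustrative assumptions rather than the PR's exact code):

```python
import numpy as np

# Assumed DeepSeek-R1 MLA decode shapes: 128 heads, 512-dim latent (nope)
# part, 64-dim rotary part. Batch size is arbitrary for illustration.
bs, num_heads = 4, 128
q_nope = np.zeros((bs, num_heads, 512), dtype=np.float16)  # latent query part
q_rope = np.zeros((bs, num_heads, 64), dtype=np.float16)   # rotary query part

# The extra step: fuse both parts into the single 576-dim query layout the
# TRTLLM kernel consumes. This copy is the µs-scale overhead noted above.
q = np.concatenate([q_nope, q_rope], axis=-1)
assert q.shape == (bs, num_heads, 576)
```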

[Image: kernel-wise nsys comparison snapshot]

High Concurrency Results

Server Command: python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1 --trust-remote-code --attention-backend trtllm_mla/flashinfer --page-size 32/64 --tp-size 8 --max-running-requests 512 --cuda-graph-max-bs 512 --mem-fraction-static 0.90

Client Command: python -m sglang.bench_serving --backend sglang --host 0.0.0.0 --port 30000 --dataset-name random --random-input-len 1024 --random-output-len 8192 --num-prompts 1024 --max-concurrency 512

| Backend | Page Size | Requests | Max Concurrency | Achieved Concurrency | Output Throughput (tok/s) | Total Throughput (tok/s) |
|---|---|---|---|---|---|---|
| trtllm_mla | 32 | 1024 | 512 | 393.25 | 4697.43 | 5273.36 |
| flashinfer | 32 | 1024 | 512 | 389.44 | 3311.30 | 3717.28 |
| trtllm_mla | 64 | 1024 | 512 | 390.85 | 4651.67 | 5221.99 |
| flashinfer | 64 | 1024 | 512 | 388.12 | 3355.26 | 3766.63 |
| cutlass_mla | 128 | 1024 | 512 | 389.99 | 3768.46 | 4230.50 |

Note: at high concurrency the attention kernel is the major bottleneck, so we observe the full ~40% improvement in e2e perf compared to the flashinfer MLA backend.
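The ~40% figure follows directly from the page-size-32 row of the table above:

```python
# Output throughput numbers copied from the page-size-32 rows above.
trtllm_tput = 4697.43      # tok/s, trtllm_mla backend
flashinfer_tput = 3311.30  # tok/s, flashinfer backend

# Relative e2e speedup of trtllm_mla over flashinfer.
speedup = trtllm_tput / flashinfer_tput - 1.0
print(f"{speedup:.1%}")  # prints 41.9%, i.e. the ~40% gain cited in the note
```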

Checklist

  • Format your code according to the Code Formatting with Pre-Commit.
  • Add unit tests as outlined in the Running Unit Tests.
  • E2E DeepSeek R1 server launch and generation sanity
  • Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
  • Please feel free to join our Slack channel at https://slack.sglang.ai to discuss your PR.

@gemini-code-assist (bot) left a comment:

Summary of Changes

Hello @farazkh80, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request is an initial work-in-progress commit to integrate and enable TensorRT-LLM (TRTLLM) Multi-Head Latent Attention (MLA) kernels within the system. The primary objective is to enhance attention computation performance, particularly on Blackwell architectures, by utilizing these specialized kernels from FlashInfer. This lays the foundational groundwork for future performance improvements in decode operations.

Highlights

  • New Attention Backend: Introduced a new TRTLLMMLABackend to integrate TensorRT-LLM (TRTLLM) Multi-Head Latent Attention (MLA) kernels, leveraging the FlashInfer library for optimized decode operations.
  • Blackwell Optimization: The system now intelligently prioritizes the use of the new trtllm_mla attention backend on Blackwell (SM100) architectures, provided the specific FlashInfer kernels are available, aiming for improved performance.
  • Dependency Update: The flashinfer_python dependency in pyproject.toml has been updated to point to a specific Git commit, likely to access the necessary TRTLLM MLA features that are not yet in a released version.
  • Unit Testing: A new comprehensive test suite (test_trtllm_mla_backend.py) has been added to validate the functionality and correctness of the TRTLLM MLA backend across various configurations.
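The Blackwell prioritization highlighted above might be sketched as follows (a hypothetical helper, not sglang's actual model_runner logic; on a live system the capability tuple would come from torch.cuda.get_device_capability(), where SM100 reports major version 10):

```python
def pick_attention_backend(capability, trtllm_kernels_available):
    """Hypothetical sketch: prefer trtllm_mla on SM100 (Blackwell) when the
    required FlashInfer kernels are present, else fall back to flashinfer."""
    major, _minor = capability
    if major == 10 and trtllm_kernels_available:
        return "trtllm_mla"
    return "flashinfer"

print(pick_attention_backend((10, 0), True))   # Blackwell + kernels -> trtllm_mla
print(pick_attention_backend((9, 0), True))    # Hopper -> flashinfer
```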

@gemini-code-assist (bot) left a comment:

Code Review

This pull request introduces support for TRT-LLM MLA kernels. My review has identified a few areas for improvement:

  • The project now depends on a personal fork of flashinfer, which should be addressed for better maintainability.
  • The new TRTLLMMLABackend has some hardcoded values for model dimensions and workspace size that could be made more flexible.
  • There's a critical bug in TRTLLMMLABackend where an assertion will fail for MLA use cases.
  • The backend selection logic in model_runner.py uses a bare except clause which should be more specific.

@farazkh80 farazkh80 changed the title [WIP] [DLFW-5721] trtllm gen mla initial commit [WIP] [DLFW-5721] trtllm gen mla integration Jul 10, 2025
@farazkh80 farazkh80 changed the title [WIP] [DLFW-5721] trtllm gen mla integration [WIP] trtllm gen mla integration Jul 14, 2025
@farazkh80 (Author) commented:

/gemini review

@gemini-code-assist (bot) left a comment:

Code Review

This pull request introduces the trtllm_mla attention backend, integrating TensorRT-LLM's Multi-Head Latent Attention kernels. The changes include the backend implementation, integration with the model runner, and a new test suite. Key areas for improvement include dependency management, KV cache preparation, and ensuring robustness for quantized models.

@farazkh80 farazkh80 marked this pull request as ready for review July 15, 2025 03:13
@farazkh80 farazkh80 changed the title [WIP] trtllm gen mla integration TRTLLM gen mla integration Jul 15, 2025
@farazkh80 farazkh80 changed the title TRTLLM gen mla integration TRTLLM Gen MLA Decode Kernel Integration Jul 15, 2025
@farazkh80 (Author) commented:

PR should be ready for an initial review. The only pending change is waiting for flashinfer-ai/flashinfer#1289 to deduplicate the kv-cache; kv-cache duplication is currently the main bottleneck for e2e perf, as seen in the nsys kernel-wise comparison capture below.

[Image: nsys kernel-wise comparison capture]

The left-hand side is flashinfer's BatchMLAPagedAttention (the kernel used by default for MLA), and the right-hand side is the new TRTLLM MLA kernel that this PR adds. Captured at high concurrency (512) with TP=8 on 8xB200.

@farazkh80 (Author) commented:

The kv-cache deduplication is now merged on the flashinfer side (flashinfer-ai/flashinfer#1289). I have reflected the changes in this PR, and at high concurrency we now see a 40% throughput improvement. This currently uses a bf16 kv-cache for MLA; a separate PR will follow to support fp8 kv-cache and query, which should allow us to further improve perf and concurrency.

| Backend | Page Size | Requests | Max Concurrency | Achieved Concurrency | Output Throughput (tok/s) | Total Throughput (tok/s) |
|---|---|---|---|---|---|---|
| trtllm_mla | 32 | 1024 | 512 | 393.25 | 4697.43 | 5273.36 |
| flashinfer | 32 | 1024 | 512 | 389.44 | 3311.30 | 3717.28 |

@merrymercy (Contributor) left a comment:

please fix the lint

@merrymercy (Contributor) left a comment:

If possible, can you also attach a torch profile, just to check whether the overlap scheduler works and there is no CPU-GPU sync.

@farazkh80 farazkh80 force-pushed the fkhoubsirat-trtllm_gen_mla_sglang branch from fd9b07f to 8f8c478 Compare July 24, 2025 20:57
farazkh80 and others added 2 commits July 29, 2025 16:40
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
@kushanam kushanam enabled auto-merge (squash) July 30, 2025 06:30
@kushanam kushanam self-requested a review July 30, 2025 06:49
@kushanam kushanam disabled auto-merge July 30, 2025 07:04
@pavanimajety (Collaborator) left a comment:

Please fix the file permissions for files that were modified to 755

Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
@pavanimajety (Collaborator) left a comment:

LGTM, thanks!

@zhyncs zhyncs self-assigned this Jul 31, 2025
@zhyncs zhyncs closed this Jul 31, 2025
zhyncs pushed a commit that referenced this pull request Jul 31, 2025
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>