Conversation

@trevor-m (Collaborator) commented on Jul 31, 2025

Motivation

By performing layernorm before the all-gather, each rank operates on only 1/DP of the tokens, reducing the computation time of the normalization.

Modifications

Perform layernorm before the DP gather in the layer communicator; a sketch of the reordering follows below.
Currently this path is only enabled when DP == TP.
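
Because layernorm normalizes each token independently along the hidden dimension, normalizing the local DP shard before the gather produces the same result as normalizing the full gathered tensor; only the per-rank work changes. Below is a minimal sketch of the reordering, assuming a PyTorch DP process group; the function and variable names are illustrative and are not the actual `_gather_hidden_states_and_residual` code.

```python
# Illustrative sketch only, not the actual SGLang communicator code.
import torch
import torch.distributed as dist


def gather_then_norm(local_hidden, norm, dp_group):
    """Old order: all-gather first, then every rank normalizes DP * local tokens."""
    world = dist.get_world_size(dp_group)
    gathered = torch.empty(world * local_hidden.shape[0], local_hidden.shape[1],
                           dtype=local_hidden.dtype, device=local_hidden.device)
    dist.all_gather_into_tensor(gathered, local_hidden, group=dp_group)
    return norm(gathered)  # norm runs over all gathered tokens


def norm_then_gather(local_hidden, norm, dp_group):
    """New order: normalize only the local 1/DP shard, then all-gather the result."""
    world = dist.get_world_size(dp_group)
    normed = norm(local_hidden)  # norm runs over only the local tokens
    gathered = torch.empty(world * normed.shape[0], normed.shape[1],
                           dtype=normed.dtype, device=normed.device)
    dist.all_gather_into_tensor(gathered, normed, group=dp_group)
    return gathered
```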

Accuracy Test

python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 8 --enable-flashinfer-cutlass-moe --enable-ep-moe --ep-size 8 --dp 8 --enable-dp-attention
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319 --port=30000
Accuracy: 0.959
Invalid: 0.000
Latency: 22.890 s
Output throughput: 6335.267 token/s

Benchmark & Profiling

Speedup: 3.79% end to end (total token throughput 27310.02 → 28345.37 tok/s, a 1.0379x ratio)

python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 8 --enable-flashinfer-cutlass-moe --enable-ep-moe --ep-size 8 --dp 8 --enable-dp-attention
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1024 --random-input 1024 --random-output 1024 --random-range-ratio 1 --max-concurrency 1024

BEFORE

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 1024
Successful requests:                     1024
Benchmark duration (s):                  76.79
Total input tokens:                      1048576
Total generated tokens:                  1048576
Total generated tokens (retokenized):    1046065
Request throughput (req/s):              13.33
Input token throughput (tok/s):          13655.01
Output token throughput (tok/s):         13655.01
Total token throughput (tok/s):          27310.02
Concurrency:                             1021.18
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   76579.19
Median E2E Latency (ms):                 76556.41
---------------Time to First Token----------------
Mean TTFT (ms):                          11978.29
Median TTFT (ms):                        11876.27
P99 TTFT (ms):                           21513.72
---------------Inter-Token Latency----------------
Mean ITL (ms):                           63.15
Median ITL (ms):                         53.14
P95 ITL (ms):                            69.91
P99 ITL (ms):                            77.50
Max ITL (ms):                            18880.99
==================================================

AFTER

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 1024
Successful requests:                     1024
Benchmark duration (s):                  73.99
Total input tokens:                      1048576
Total generated tokens:                  1048576
Total generated tokens (retokenized):    1044921
Request throughput (req/s):              13.84
Input token throughput (tok/s):          14172.69
Output token throughput (tok/s):         14172.69
Total token throughput (tok/s):          28345.37
Concurrency:                             1021.65
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   73816.06
Median E2E Latency (ms):                 73829.85
---------------Time to First Token----------------
Mean TTFT (ms):                          10411.04
Median TTFT (ms):                        10263.18
P99 TTFT (ms):                           19516.86
---------------Inter-Token Latency----------------
Mean ITL (ms):                           61.98
Median ITL (ms):                         52.59
P95 ITL (ms):                            67.60
P99 ITL (ms):                            74.52
Max ITL (ms):                            17244.45
==================================================

@gemini-code-assist (bot, Contributor) left a comment

Summary of Changes

Hello @trevor-m, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on a critical performance optimization within the data parallel communication layer. By strategically moving the layer normalization step to occur before the all-gather operation, the system can perform computations on smaller, local data chunks, leading to reduced overhead and improved overall throughput. The changes are validated by significant benchmark improvements without any reported accuracy regressions.

Highlights

  • Performance Optimization: Reordered the Layer Normalization (layernorm) operation to occur before the Data Parallel (DP) gather operation within the _gather_hidden_states_and_residual function in python/sglang/srt/layers/communicator.py.
  • Efficiency Improvement: This change allows layernorm to operate on a smaller, local subset of tokens (specifically, 1/DPth of the tokens), which significantly reduces the computational time required for the normalization step.
  • Code Refinement: Removed an unnecessary dp_scatter operation, simplifying the data flow within the communicator layer.
  • Benchmark Results: Achieved a notable 3.79% end-to-end speedup, with total token throughput increasing from approximately 27310 tokens/s to 28345 tokens/s, demonstrating the effectiveness of the optimization.

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request aims to improve performance by moving the Layer Normalization operation before the all-gather in data parallel settings. The benchmark results show a speedup.

My review identifies a critical correctness issue. The change, combined with existing code, leads to inconsistent tensors across tensor parallel ranks when tp_size is not equal to dp_size. This will cause silent numerical errors in such configurations.
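
For reference, the DP == TP restriction mentioned in the PR description is one way to avoid this. A hypothetical sketch of such a guard follows (placeholder names, not the actual communicator code):

```python
# Hypothetical guard, illustrative only (dp_gather / layernorm are placeholder callables).
def maybe_norm_before_gather(hidden_states, layernorm, dp_gather, dp_size, tp_size):
    if dp_size == tp_size:
        # Each TP rank owns exactly one DP shard, so the per-rank norm is safe.
        hidden_states = layernorm(hidden_states)
        hidden_states = dp_gather(hidden_states)
    else:
        # Fall back to the original order to keep all TP ranks consistent.
        hidden_states = dp_gather(hidden_states)
        hidden_states = layernorm(hidden_states)
    return hidden_states
```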

@trevor-m trevor-m changed the title Do layernorm before allgather for DP Draft: Do layernorm before allgather for DP Jul 31, 2025
@trevor-m trevor-m changed the title Draft: Do layernorm before allgather for DP Do layernorm before allgather for DP Jul 31, 2025
@trevor-m trevor-m changed the title Do layernorm before allgather for DP Do layernorm before allgather for DP attention Jul 31, 2025
@kaixih (Collaborator) commented on Aug 1, 2025

LGTM! Thx.

@kushanam (Collaborator) commented on Aug 1, 2025

@ch-wan could you please take a look? tnx

@zhyncs zhyncs self-assigned this Aug 1, 2025
@ch-wan (Collaborator) left a comment

LGTM

@ch-wan ch-wan merged commit 32f2815 into sgl-project:main Aug 3, 2025
60 of 64 checks passed
htiennv pushed a commit to htiennv/sglang that referenced this pull request Aug 5, 2025
narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 17, 2025
narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 18, 2025