Integration SM100 FlashInfer fused allreduce RMSNorm #20691
Conversation
Summary of Changes
Hello @ilmarkov, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a significant performance optimization by integrating FlashInfer's fused allreduce RMSNorm kernel into vLLM's compilation pipeline. The new AllReduceFusionPass automatically detects and replaces common computation patterns involving tensor parallel all-reduce and RMSNorm operations with a highly optimized fused kernel, leading to notable speedups in token generation, particularly for large models in distributed environments.
Highlights
- Performance Optimization: This PR integrates FlashInfer's fused allreduce RMSNorm kernel to accelerate tensor parallel operations within vLLM. This is a significant optimization for distributed inference.
- New Compilation Pass: A new AllReduceFusionPass has been introduced. This pass identifies and replaces patterns of AllReduce followed by RMSNorm or FusedAddRMSNorm with the optimized FlashInfer kernel during compilation.
- Configurability: The new fusion pass can be enabled or disabled via the compilation-config by setting enable_flashinfer_allreduce_fusion to true within the pass_config (see the example after this list).
- Benchmarking Results: Initial benchmarks on Llama-3.1-70B-Instruct (TP=4 on B200 GPUs) demonstrate a tangible performance improvement, showing a 10-15% speedup in Time Per Output Token (TPOT).
- Test Coverage: A new comprehensive test suite (tests/compile/test_fusion_all_reduce.py) has been added to validate the correctness and functionality of the fusion pass in a multi-GPU setup.
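For reference, the flag is enabled exactly as in the benchmark command from the PR description further below (this simply repeats that command, it is not a new option):
vllm serve meta-llama/Llama-3.1-70B-Instruct -tp 4 --compilation-config='{"pass_config": {"enable_flashinfer_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3}'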
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request integrates FlashInfer's fused all-reduce RMSNorm, which is a great performance optimization. The implementation includes a new fusion pass, updates to the pass manager and configuration, and a comprehensive new test file.
The overall approach is solid, but I've found several critical issues in the pattern matching logic that need to be addressed. Additionally, there are some areas for improvement regarding resource management, test hygiene, and removal of unused code. Addressing these points will improve the robustness and maintainability of this new feature.
def test_all_reduce_fusion_pass_replace(test_model: str, batch_size: int,
                                        seq_len: int, hidden_size: int,
                                        dtype: torch.dtype):
The type hint for test_model is str, but it's being passed a class object (TestAllReduceRMSNormModel or TestAllReduceFusedAddRMSNormModel). This should be corrected to type for better code clarity and correctness.
Suggested change:
def test_all_reduce_fusion_pass_replace(test_model: str, batch_size: int,
                                        seq_len: int, hidden_size: int,
                                        dtype: torch.dtype):
def test_all_reduce_fusion_pass_replace(test_model: type, batch_size: int,
                                        seq_len: int, hidden_size: int,
                                        dtype: torch.dtype):
    compile_sizes=[2, 4, 8]))
vllm_config.compilation_config.pass_config = PassConfig(
    enable_flashinfer_allreduce_fusion=True,
    dump_graph_dir=Path("dump_graph"),
The test hardcodes dump_graph_dir=Path("dump_graph"), which creates artifacts in the project's working directory. This can interfere with other tests or pollute the source tree. It's best practice to use a temporary directory for test artifacts.
You can use pytest's tmp_path fixture (sketched below) by:
- Adding tmp_path: Path to the test_all_reduce_fusion_pass_replace function signature.
- Passing tmp_path through the args of torch.multiprocessing.spawn.
- Adding tmp_path: Path to the all_reduce_fusion_pass_on_test_model function signature.
- Using tmp_path for dump_graph_dir.
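A minimal sketch of that wiring, using the signatures already shown in this review (the worker body is elided):
from pathlib import Path

import torch


def test_all_reduce_fusion_pass_replace(test_model: type, batch_size: int,
                                        seq_len: int, hidden_size: int,
                                        dtype: torch.dtype, tmp_path: Path):
    num_processes = 2
    # Forward tmp_path so each spawned worker dumps graphs into the
    # pytest-managed temporary directory instead of the source tree.
    torch.multiprocessing.spawn(all_reduce_fusion_pass_on_test_model,
                                args=(num_processes, test_model, batch_size,
                                      seq_len, hidden_size, dtype, tmp_path),
                                nprocs=num_processes)


def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                                         test_model_cls: type,
                                         batch_size: int, seq_len: int,
                                         hidden_size: int, dtype: torch.dtype,
                                         tmp_path: Path):
    # ... build the PassConfig with
    # enable_flashinfer_allreduce_fusion=True, dump_graph_dir=tmp_path
    ...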
if flashinfer_comm is not None:
_FI_WORKSPACE_TENSOR = None
The use of the global variable _FI_WORKSPACE_TENSOR to pass the workspace to the custom op makes the code less modular and harder to reason about. While this is a common pattern for torch.compile custom ops, if possible, consider passing the workspace tensor as an argument to the custom op to avoid global state.
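A hedged sketch of that alternative; the op name and exact signature here are illustrative assumptions, not the PR's actual registration:
import torch


def fused_allreduce_rmsnorm(allreduce_in: torch.Tensor,
                            rmsnorm_weight: torch.Tensor,
                            workspace: torch.Tensor,
                            eps: float = 1e-6) -> torch.Tensor:
    # The workspace arrives as an explicit argument, so the op body never
    # needs to read the module-level _FI_WORKSPACE_TENSOR global before
    # calling into the FlashInfer fused allreduce + RMSNorm kernel.
    raise NotImplementedError("sketch only")
The trade-off is that the workspace tensor then has to be threaded through the matched graph pattern, which is presumably why the global is used here.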
vllm/envs.py (Outdated)
@@ -138,6 +138,7 @@
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
VLLM_USE_FLASHINFER_ALLREDUCE: bool = False
Ditto, it seems you just use the pass config now?
Force-pushed 6693e83 to d1068d8
Looks good overall, a few minor comments!
) -> None:
    pass

try:
What is this try-except? Seems like a noop
class AllReduceFusionPass(VllmInductorPass):
    MAX_TOKEN_NUM_INIT = 1024
Should we read this from config?
logger.warning(
    "Flashinfer is not installed, skipping allreduce fusion pass")
return
It feels like we should set self.disabled = False here to clarify that if we reach here, fusion is happening. If the workspace allocation can fail, wrap it into a try-except and handle the exception as well.
We do it at the end of the constructor.
I know, I'm saying we should do it here.
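A minimal sketch of the constructor flow being suggested, assuming the surrounding module's imports and the names visible in this review (the workspace/pattern helper is hypothetical):
class AllReduceFusionPass(VllmInductorPass):

    def __init__(self, config: VllmConfig, max_token_num: int):
        super().__init__(config)
        self.disabled = True
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed, skipping allreduce fusion pass")
            return
        try:
            # Hypothetical helper: allocate the FlashInfer workspace and
            # register the allreduce + RMSNorm replacement patterns.
            self._init_workspace_and_patterns(max_token_num)
        except Exception:
            logger.warning("Failed to initialize FlashInfer allreduce "
                           "fusion, skipping the pass")
            return
        # Flip the flag only after every prerequisite has succeeded.
        self.disabled = False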
vllm/config.py (Outdated)
@@ -3945,6 +3945,8 @@ class PassConfig:
"""Whether to enable sequence parallelism."""
enable_async_tp: bool = False
"""Whether to enable async TP."""
enable_flashinfer_allreduce_fusion: bool = False
This is a really long name (it needs to be specified on the CLI). What about enable_fi_allreduce_fusion? Could also do enable_allreduce_fusion, as we don't have multiple kinds of allreduce fusion.
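For comparison, the proposed shorter name would read like this on the CLI (note this is the suggested rename, not the flag the PR currently defines):
--compilation-config='{"pass_config": {"enable_fi_allreduce_fusion": true}}'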
if flashinfer_comm is not None:
_FI_WORKSPACE_TENSOR = None

MB = 1024 * 1024
Nit: MiB (we use the proper i naming elsewhere too).
    8: MB // 2,  # 512KB
}


def call_trtllm_allreduce_fusion(
I would call this trtllm_fused_allreduce_norm, fusion is the act of fusing.
_FI_WORKSPACE_TENSOR = None

MB = 1024 * 1024
_FI_MAX_SIZES = {
Maybe add a comment on what the key here represents (I assume TP size?)
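Something like the following, if the keys are indeed tensor-parallel world sizes; every entry except the 8: MB // 2 line visible in the diff is a placeholder, not the PR's actual value:
MB = 1024 * 1024
# Max input size (bytes) for which the FlashInfer fused allreduce + RMSNorm
# kernel is used, keyed by tensor-parallel world size (assumed).
_FI_MAX_SIZES = {
    2: MB,       # placeholder
    4: MB,       # placeholder
    8: MB // 2,  # 512KB, from the diff above
}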
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
                                        batch_size: int, seq_len: int,
                                        hidden_size: int, dtype: torch.dtype):
    num_processes = 2

    def run_torch_spawn(fn, nprocs):
        torch.multiprocessing.spawn(fn,
                                    args=(num_processes, test_model,
                                          batch_size, seq_len, hidden_size,
                                          dtype),
                                    nprocs=nprocs)

    run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)


def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                                         test_model_cls: torch.nn.Module,
                                         batch_size: int, seq_len: int,
                                         hidden_size: int, dtype: torch.dtype):
Could we try to convert this into a decorator:
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
                                        batch_size: int, seq_len: int,
                                        hidden_size: int, dtype: torch.dtype):
    num_processes = 2

    def run_torch_spawn(fn, nprocs):
        torch.multiprocessing.spawn(fn,
                                    args=(num_processes, test_model,
                                          batch_size, seq_len, hidden_size,
                                          dtype),
                                    nprocs=nprocs)

    run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)


def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                                         test_model_cls: torch.nn.Module,
                                         batch_size: int, seq_len: int,
                                         hidden_size: int, dtype: torch.dtype):

@with_torch_spawn(nprocs=2)  # adds local_rank and world_size params
def test_all_reduce_fusion_pass_replace(local_rank: int, world_size: int,
                                        test_model_cls: torch.nn.Module,
                                        batch_size: int, seq_len: int,
                                        hidden_size: int, dtype: torch.dtype):
With with_torch_spawn looking something like:
import functools
import torch

def with_torch_spawn(nprocs):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*test_args):
            # spawn() prepends the process index (local_rank) to fn's args;
            # nprocs also serves as world_size for the wrapped test.
            torch.multiprocessing.spawn(fn,
                                        args=(nprocs, *test_args),
                                        nprocs=nprocs)
        return wrapper
    return decorator
It's a bit problematic to inject local_rank and world_size as parameters to test func as it conflicts with pytest.mark.parametrize decorators.
current_platform.seed_everything(0)

device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
I think this line is unnecessary (it's a context manager for setting default device)
Can you also benchmark this/make a unit test on an MoE to make sure the integration works with those interfaces?
Another minor comment; will take a final look tomorrow morning
def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig, max_token_num: int):
Can we just read this directly from config instead of taking a separate param?
I haven't been able to run the unit test, have you run it recently?
It seems we should also add a guard on using this fusion based on current_platform.is_device_capability(100), since currently the kernels are only built for Blackwell.
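A hedged sketch of such a guard (the helper name is hypothetical; only the capability check itself comes from the comment above):
from vllm.platforms import current_platform


def flashinfer_allreduce_fusion_supported() -> bool:
    # The fused allreduce + RMSNorm kernels are currently only built for
    # SM100 (Blackwell), so gate the fusion pass on device capability.
    return current_platform.is_device_capability(100)
The pass constructor could then set self.disabled = True and return early when this check fails, mirroring the flashinfer_comm is None path discussed earlier.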
Signed-off-by: ilmarkov <imarkov@redhat.com>
Force-pushed 994c874 to a812541
LGTM and verified the test locally, thanks!
great work!
Good question @shixianc, currently FI jit-compiles their allreduce kernels only for sm100. However, when we add sm90 flags their tests seem to work successfully. So we will post a PR to FI to build for Hopper as well, maybe other arches too.
Also, why is |
I saw it's discussed in flashinfer-ai/flashinfer#1223
Should it be max_token_num * allreduce_in.shape[1] * allreduce_in.element_size()?
Thanks for providing the link! However, I think that issue only applies to the lamport_oneshot kernel. We could still use the twoshot kernel when num_tokens is larger, right?
@mgoin I built for sm90a and benchmarked using the same commands but found worse TPOT. Let me know if you guys have any benchmarks on Hopper, thanks!
Purpose
Integrates FlashInfer fused allreduce RMSNorm using fusion passes.
Can be enabled in the compilation config:
--compilation-config='{"pass_config": {"enable_flashinfer_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level":3}'
Baseline, no custom ops: (profiler trace screenshots omitted)
After: (profiler trace screenshot omitted)
Benchmarking End-to-End
Llama-3.1-70B-Instruct TP=4 on B200 GPUs
Client:
Server:
Baseline:
vllm serve meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --no-enable-prefix-caching -tp 4
PR:
vllm serve meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --no-enable-prefix-caching -tp 4 --compilation-config='{"pass_config": {"enable_flashinfer_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level":3}'
Results:
Baseline
PR:
TPOT gets around a 10-15% speedup.
Test Plan
Added tests/compile/test_fusion_all_reduce.py