
Conversation

@ilmarkov (Contributor) commented Jul 9, 2025

Purpose

Integrates FlashInfer fused allreduce RMSNorm using fusion passes.
It can be enabled via the compilation config: --compilation-config='{"pass_config": {"enable_flashinfer_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3}'
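For context, the pattern being fused is a tensor-parallel all-reduce followed by RMSNorm. Below is a plain-PyTorch sketch of those semantics (an unfused reference only, not the FlashInfer kernel or the pass's actual pattern code; the function name is made up):

import torch
import torch.distributed as dist

def allreduce_rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor,
                                eps: float = 1e-6) -> torch.Tensor:
    # Sum the partial results from all tensor-parallel ranks in place ...
    dist.all_reduce(x)
    # ... then apply RMSNorm to the reduced tensor.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

The fusion pass replaces this two-op sequence (and its fused-add variant) with a single FlashInfer kernel call.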

Baseline, no custom ops:
[image]

After:
[image]

Benchmarking End-to-End

Llama-3.1-70B-Instruct TP=4 on B200 GPUs

Client:

DURATION_SECONDS=60; \
vllm bench serve \
  --model meta-llama/Llama-3.1-70B-Instruct \
  --dataset-name sonnet \
  --dataset-path benchmarks/sonnet.txt \
  --request-rate "$qps" \
  --num-prompts $((DURATION_SECONDS * qps))

Server:

Baseline:
vllm serve meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --no-enable-prefix-caching -tp 4
PR: vllm serve meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --no-enable-prefix-caching -tp 4 --compilation-config='{"pass_config": {"enable_flashinfer_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level":3}'

Results:

Baseline:

QPS   Mean TTFT (ms)   Median TTFT (ms)   Mean TPOT (ms)   Median TPOT (ms)
1     42.557           41.910             13.210           13.216
5     46.372           44.402             14.442           14.382
10    54.465           49.271             17.834           17.953

PR:

QPS   Mean TTFT (ms)   Median TTFT (ms)   Mean TPOT (ms)   Median TPOT (ms)
1     43.9             42.54              11.875           11.864
5     48.234           46.041             13.261           13.185
10    52.725           46.457             15.352           14.292

TPOT improves by roughly 10-15%.

Test Plan

Added tests/compile/test_fusion_all_reduce.py

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @ilmarkov, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance optimization by integrating FlashInfer's fused allreduce RMSNorm kernel into vLLM's compilation pipeline. This new AllReduceFusionPass automatically detects and replaces common computation patterns involving tensor parallel all-reduce and RMSNorm operations with a highly optimized fused kernel, leading to notable speedups in token generation, particularly for large models in distributed environments.

Highlights

  • Performance Optimization: This PR integrates FlashInfer's fused allreduce RMSNorm kernel to accelerate tensor parallel operations within vLLM. This is a significant optimization for distributed inference.
  • New Compilation Pass: A new AllReduceFusionPass has been introduced. This pass identifies and replaces patterns of AllReduce followed by RMSNorm or FusedAddRMSNorm with the optimized FlashInfer kernel during compilation.
  • Configurability: The new fusion pass can be enabled or disabled via the compilation-config by setting enable_flashinfer_allreduce_fusion to true within the pass_config.
  • Benchmarking Results: Initial benchmarks on Llama-3.1-70B-Instruct (TP=4 on B200 GPUs) demonstrate a tangible performance improvement, showing a 10-15% speedup in Time Per Output Token (TPOT).
  • Test Coverage: A new comprehensive test suite (tests/compile/test_fusion_all_reduce.py) has been added to validate the correctness and functionality of the fusion pass in a multi-GPU setup.


mergify bot commented Jul 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 9, 2025
@gemini-code-assist (bot) left a comment

Code Review

This pull request integrates FlashInfer's fused all-reduce RMSNorm, which is a great performance optimization. The implementation includes a new fusion pass, updates to the pass manager and configuration, and a comprehensive new test file.

The overall approach is solid, but I've found several critical issues in the pattern matching logic that need to be addressed. Additionally, there are some areas for improvement regarding resource management, test hygiene, and removal of unused code. Addressing these points will improve the robustness and maintainability of this new feature.

Comment on lines 78 to 80
def test_all_reduce_fusion_pass_replace(test_model: str, batch_size: int,
seq_len: int, hidden_size: int,
dtype: torch.dtype):
Contributor

medium

The type hint for test_model is str, but it's being passed a class object (TestAllReduceRMSNormModel or TestAllReduceFusedAddRMSNormModel). This should be corrected to type for better code clarity and correctness.

Suggested change
def test_all_reduce_fusion_pass_replace(test_model: str, batch_size: int,
seq_len: int, hidden_size: int,
dtype: torch.dtype):
def test_all_reduce_fusion_pass_replace(test_model: type, batch_size: int,
seq_len: int, hidden_size: int,
dtype: torch.dtype):

compile_sizes=[2, 4, 8]))
vllm_config.compilation_config.pass_config = PassConfig(
enable_flashinfer_allreduce_fusion=True,
dump_graph_dir=Path("dump_graph"),
Contributor

medium

The test hardcodes dump_graph_dir=Path("dump_graph"), which creates artifacts in the project's working directory. This can interfere with other tests or pollute the source tree. It's best practice to use a temporary directory for test artifacts.

You can use pytest's tmp_path fixture (see the sketch after this list) by:

  1. Adding tmp_path: Path to the test_all_reduce_fusion_pass_replace function signature.
  2. Passing tmp_path through the args of torch.multiprocessing.spawn.
  3. Adding tmp_path: Path to the all_reduce_fusion_pass_on_test_model function signature.
  4. Using tmp_path for dump_graph_dir.
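A rough sketch of those four steps, reusing the test and worker names from this PR but dropping the other parametrize arguments for brevity:

from pathlib import Path

import torch
from vllm.config import PassConfig

def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                                         tmp_path: Path):
    # Steps 3 and 4: the worker receives tmp_path and points dump_graph_dir at it.
    pass_config = PassConfig(enable_flashinfer_allreduce_fusion=True,
                             dump_graph_dir=tmp_path)
    ...

def test_all_reduce_fusion_pass_replace(tmp_path: Path):
    # Steps 1 and 2: pytest injects tmp_path, which is threaded through spawn's args.
    torch.multiprocessing.spawn(all_reduce_fusion_pass_on_test_model,
                                args=(2, tmp_path),
                                nprocs=2)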



if flashinfer_comm is not None:
_FI_WORKSPACE_TENSOR = None
Contributor

medium

The use of the global variable _FI_WORKSPACE_TENSOR to pass the workspace to the custom op makes the code less modular and harder to reason about. While this is a common pattern for torch.compile custom ops, if possible, consider passing the workspace tensor as an argument to the custom op to avoid global state.
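For illustration, one way to make the workspace an explicit argument is to register the custom op with torch.library.custom_op; the op name, signature, and placeholder body below are hypothetical, not the PR's actual registration:

import torch

@torch.library.custom_op("illustration::fused_allreduce_rmsnorm",
                         mutates_args=())
def fused_allreduce_rmsnorm(x: torch.Tensor, weight: torch.Tensor,
                            workspace: torch.Tensor,
                            eps: float) -> torch.Tensor:
    # A real implementation would hand `workspace` to the FlashInfer kernel;
    # the placeholder returns a copy so the op is well-formed.
    return x.clone()

@fused_allreduce_rmsnorm.register_fake
def _(x, weight, workspace, eps):
    # Shape and dtype propagation for torch.compile tracing.
    return torch.empty_like(x)

Whether this is practical depends on how the workspace is created at pass-initialization time, which is likely why the global is used here.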

vllm/envs.py Outdated
@@ -138,6 +138,7 @@
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
VLLM_USE_FLASHINFER_ALLREDUCE: bool = False
Contributor

medium

The new environment variable VLLM_USE_FLASHINFER_ALLREDUCE is defined but doesn't seem to be used anywhere in this pull request. If it's intended for future use, it might be better to add it when it's actually used. Otherwise, it's dead code and should be removed to avoid confusion.

Member

Ditto it seems you just use the pass config now?


github-actions bot commented Jul 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀


@ilmarkov ilmarkov force-pushed the imarkov/flashinfer_fused_allreduce branch from 6693e83 to d1068d8 Compare July 9, 2025 18:59
@mergify mergify bot removed the needs-rebase label Jul 9, 2025
@ProExpertProg (Collaborator) left a comment

Looks good overall, a few minor comments!

) -> None:
pass

try:
Collaborator

What is this try-except? Seems like a noop



class AllReduceFusionPass(VllmInductorPass):
MAX_TOKEN_NUM_INIT = 1024
Collaborator

Should we read this from config?

logger.warning(
"Flashinfer is not installed, skipping allreduce fusion pass")
return

Collaborator

It feels like we should set self.disabled = False here to clarify that if we reach here, fusion is happening. If the workspace allocation can fail, wrap it into a try-except and handle the exception as well

Contributor (Author)

We do it at the end of the constructor.

Collaborator

I know; I'm saying we should do it here.
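In other words, a loose, self-contained sketch of the suggested flow (class and helper names are illustrative, not the PR's actual code):

import logging

logger = logging.getLogger(__name__)

class AllReduceFusionPassSketch:
    def __init__(self, flashinfer_comm, allocate_workspace):
        self.disabled = True
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed, skipping allreduce fusion pass")
            return
        try:
            # Workspace allocation is the step that can plausibly fail.
            self.workspace = allocate_workspace()
        except Exception:
            logger.exception("Workspace allocation failed; "
                             "allreduce fusion pass disabled")
            return
        # Only mark the pass enabled once everything above has succeeded.
        self.disabled = False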

vllm/config.py Outdated
@@ -3945,6 +3945,8 @@ class PassConfig:
"""Whether to enable sequence parallelism."""
enable_async_tp: bool = False
"""Whether to enable async TP."""
enable_flashinfer_allreduce_fusion: bool = False
Collaborator

This is a really long name (needs to be specified on the CLI). What about enable_fi_allreduce_fusion? Could also do enable_allreduce_fusion as we don't have multiple kinds of allreduce fusion

if flashinfer_comm is not None:
_FI_WORKSPACE_TENSOR = None

MB = 1024 * 1024
Collaborator

Nit: MiB (we use the proper i naming elsewhere too)

8: MB // 2, # 512KB
}

def call_trtllm_allreduce_fusion(
Collaborator

I would call this trtllm_fused_allreduce_norm, fusion is the act of fusing

_FI_WORKSPACE_TENSOR = None

MB = 1024 * 1024
_FI_MAX_SIZES = {
Collaborator

Maybe add a comment explaining what the key here represents (I assume TP size?)
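For example, a comment along these lines (assuming, as above, that the keys are tensor-parallel world sizes; only the 8-rank entry is visible in this hunk, so the mapping is illustrative):

MB = 1024 * 1024

_FI_MAX_SIZES = {
    # tensor-parallel world size -> maximum allreduce input size (in bytes)
    # for which the FlashInfer fused path is used
    8: MB // 2,  # 512KB
}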

Comment on lines +77 to +97
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
num_processes = 2

def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(fn,
args=(num_processes, test_model,
batch_size, seq_len, hidden_size,
dtype),
nprocs=nprocs)

run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)


def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
test_model_cls: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
Collaborator

Could we try to convert this into a decorator:

Suggested change (replacing the spawn wrapper quoted above):

@with_torch_spawn(nprocs=2)  # adds local_rank and world_size params
def test_all_reduce_fusion_pass_replace(local_rank: int, world_size: int,
                                        test_model_cls: torch.nn.Module,
                                        batch_size: int, seq_len: int,
                                        hidden_size: int, dtype: torch.dtype):

With with_torch_spawn looking something like

def with_torch_spawn(nprocs):
    def decorator(fn):
        def wrapper(*args):
            # spawn passes local_rank as the first argument to fn,
            # followed by world_size and the pytest parameters.
            torch.multiprocessing.spawn(fn,
                                        args=(nprocs, *args),
                                        nprocs=nprocs)
        return wrapper
    return decorator

Contributor (Author)

It's a bit problematic to inject local_rank and world_size as parameters into the test function, as it conflicts with the pytest.mark.parametrize decorators.

current_platform.seed_everything(0)

device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
Collaborator

I think this line is unnecessary (it's a context manager for setting default device)

@mgoin (Member) commented Jul 10, 2025

Can you also benchmark this/make a unit test on an MoE to make sure the integration works with those interfaces?

@ProExpertProg (Collaborator) left a comment

Another minor comment; will take a final look tomorrow morning


def __init__(self, config: VllmConfig):
def __init__(self, config: VllmConfig, max_token_num: int):
Collaborator

Can we just read this directly from config instead of taking a separate param?

@mgoin (Member) commented Jul 10, 2025

I haven't been able to run the unit test, have you run it recently?

pytest -s -v tests/compile/test_fusion_all_reduce.py
INFO 07-10 16:15:19 [__init__.py:253] Automatically detected platform cuda.
==================================================================================================================================================================== test session starts ====================================================================================================================================================================
platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.6.0 -- /home/mgoin/venvs/vllm/bin/python3
cachedir: .pytest_cache
rootdir: /home/mgoin/code/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0
collecting ... WARNING 07-10 16:15:22 [interface.py:519] Current platform cuda does not have '_pytestfixturefunction' attribute.
WARNING 07-10 16:15:22 [interface.py:519] Current platform cuda does not have '__test__' attribute.
WARNING 07-10 16:15:22 [interface.py:519] Current platform cuda does not have '__bases__' attribute.
WARNING 07-10 16:15:22 [interface.py:519] Current platform cuda does not have '__test__' attribute.
collected 4 items                                                                                                                                                                                                                                                                                                                                           

tests/compile/test_fusion_all_reduce.py::test_all_reduce_fusion_pass_replace[dtype0-4096-8-8-TestAllReduceRMSNormModel] Fork a new process to run a test 435978
Fork a new process to run a test 0
INFO 07-10 16:15:26 [__init__.py:253] Automatically detected platform cuda.
INFO 07-10 16:15:26 [__init__.py:253] Automatically detected platform cuda.
WARNING 07-10 16:15:29 [config.py:4895] Current vLLM config is not set.
WARNING 07-10 16:15:29 [config.py:4895] Current vLLM config is not set.
WARNING 07-10 16:15:30 [config.py:4895] Current vLLM config is not set.
WARNING 07-10 16:15:30 [config.py:4895] Current vLLM config is not set.
WARNING 07-10 16:15:30 [config.py:4895] Current vLLM config is not set.
WARNING 07-10 16:15:30 [config.py:4895] Current vLLM config is not set.
INFO 07-10 16:15:30 [__init__.py:1344] Found nccl from library libnccl.so.2
INFO 07-10 16:15:30 [__init__.py:1344] Found nccl from library libnccl.so.2
INFO 07-10 16:15:30 [pynccl.py:70] vLLM is using nccl==2.26.2
INFO 07-10 16:15:30 [pynccl.py:70] vLLM is using nccl==2.26.2
libibverbs: Warning: couldn't load driver 'libvmw_pvrdma-rdmav34.so': libvmw_pvrdma-rdmav34.so: cannot open shared object file: No such file or directory
libibverbs: Warning: couldn't load driver 'libvmw_pvrdma-rdmav34.so': libvmw_pvrdma-rdmav34.so: cannot open shared object file: No such file or directory
INFO 07-10 16:16:07 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
INFO 07-10 16:16:07 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /home/mgoin/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
INFO 07-10 16:16:07 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_5cfd027d'), local_subscribe_addr='ipc:///tmp/ce6f21f1-1f98-47a9-9944-aaac6f32b99c', remote_subscribe_addr=None, remote_addr_ipv6=False)
WARNING 07-10 16:16:07 [config.py:4895] Current vLLM config is not set.
WARNING 07-10 16:16:07 [config.py:4895] Current vLLM config is not set.
INFO 07-10 16:16:07 [parallel_state.py:1078] rank 1 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 1, EP rank 1
INFO 07-10 16:16:07 [parallel_state.py:1078] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO 07-10 16:16:12 [config.py:852] This model supports multiple tasks: {'reward', 'classify', 'embed', 'generate'}. Defaulting to 'generate'.
WARNING 07-10 16:16:12 [config.py:3401] Casting torch.bfloat16 to torch.float16.
INFO 07-10 16:16:12 [config.py:1489] Using max model len 2048
INFO 07-10 16:16:12 [config.py:852] This model supports multiple tasks: {'reward', 'classify', 'generate', 'embed'}. Defaulting to 'generate'.
WARNING 07-10 16:16:12 [config.py:3401] Casting torch.bfloat16 to torch.float16.
INFO 07-10 16:16:12 [config.py:1489] Using max model len 2048
[rank0]:[W710 16:16:13.332953121 ProcessGroupNCCL.cpp:1476] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W0710 16:16:15.972000 435978 torch/multiprocessing/spawn.py:169] Terminating process 435983 via signal SIGTERM
Traceback (most recent call last):
  File "/home/mgoin/code/vllm/tests/utils.py", line 741, in wrapper
    f(*args, **kwargs)
  File "/home/mgoin/code/vllm/tests/compile/test_fusion_all_reduce.py", line 89, in test_all_reduce_fusion_pass_replace
    run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
  File "/home/mgoin/code/vllm/tests/compile/test_fusion_all_reduce.py", line 83, in run_torch_spawn
    torch.multiprocessing.spawn(fn,
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 340, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 296, in start_processes
    while not context.join():
              ^^^^^^^^^^^^^^
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 215, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 90, in _wrap
    fn(i, *args)
  File "/home/mgoin/code/vllm/tests/compile/test_fusion_all_reduce.py", line 132, in all_reduce_fusion_pass_on_test_model
    all_reduce_fusion_pass = AllReduceFusionPass(
                             ^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/code/vllm/vllm/compilation/collective_fusion.py", line 430, in __init__
    flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/flashinfer/comm/trtllm_ar.py", line 567, in trtllm_create_ipc_workspace_for_all_reduce_fusion
    ipc_handles.append(create_shared_buffer(aligned_size, group))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/flashinfer/comm/cuda_ipc.py", line 213, in create_shared_buffer
    dist.all_gather_object(handles, handle, group=group)
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 3039, in all_gather_object
    all_gather(object_size_list, local_size, group=group)
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 3706, in all_gather
    return handle_torch_function(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/overrides.py", line 1721, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/utils/_device.py", line 104, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 3728, in all_gather
    work = group.allgather([tensor_list], [tensor])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/NCCLUtils.cpp:77, invalid usage (run with NCCL_DEBUG=WARN for details), NCCL version 2.26.2
ncclInvalidUsage: This usually reflects invalid usage of NCCL library.
Last error:
Duplicate GPU detected : rank 0 and rank 1 both on CUDA device 1b000

/usr/lib/python3.12/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
FAILED

@mgoin (Member) left a comment

It seems we should also add a guard on using this fusion based on current_platform.is_device_capability(100) since currently the kernels are only built for Blackwell
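A minimal sketch of such a guard (where it lives in the pass is up to the implementation; the helper function here is just for illustration):

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

def flashinfer_allreduce_fusion_supported() -> bool:
    # The FlashInfer fused allreduce kernels are currently only built for
    # SM100 (Blackwell), so skip the fusion pass on other architectures.
    if not current_platform.is_device_capability(100):
        logger.warning("FlashInfer fused allreduce requires compute "
                       "capability 10.0; disabling allreduce fusion pass")
        return False
    return True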

ilmarkov and others added 5 commits July 11, 2025 08:20
Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: ilmarkov <imarkov@redhat.com>
@ilmarkov ilmarkov force-pushed the imarkov/flashinfer_fused_allreduce branch from 994c874 to a812541 Compare July 11, 2025 12:21
@mgoin mgoin added performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed labels Jul 11, 2025
@mgoin (Member) left a comment

LGTM and verified the test locally, thanks!

@mgoin mgoin enabled auto-merge (squash) July 11, 2025 22:44
@mgoin mgoin changed the title Integration FlashInfer fused allreduce RMSNorm Integration SM100 FlashInfer fused allreduce RMSNorm Jul 12, 2025
@simon-mo simon-mo disabled auto-merge July 12, 2025 01:58
@simon-mo simon-mo merged commit fc0f41d into vllm-project:main Jul 12, 2025
65 of 71 checks passed
@shixianc (Contributor)

Great work!
By any chance, is there a Hopper version of the kernel?

@mgoin (Member) commented Jul 12, 2025

Good question @shixianc. Currently FI JIT-compiles its allreduce kernels only for sm100; however, when we add sm90 flags, their tests seem to pass. So we will post a PR to FI to build for Hopper as well, and maybe other arches too.

Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
@nvpohanh (Contributor)

@simon-mo @mgoin Quick question: is there any reason why we do not want to enable it by default? I think that will benefit the users who are not aware of this config flag.

Thanks!

@nvpohanh (Contributor)

Also, why is fi_allreduce_fusion_max_token_num set to 1024 by default? Do we see performance issue in this fusion when num_tokens is too large?

@shixianc (Contributor)

Also, why is fi_allreduce_fusion_max_token_num set to 1024 by default? Do we see performance issue in this fusion when num_tokens is too large?

I saw it's discussed in flashinfer flashinfer-ai/flashinfer#1223

@shixianc (Contributor)

        use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
            1] * allreduce_in.element_size() <= min(
                _FI_MAX_SIZES[world_size],
                max_token_num * allreduce_in.shape[0] *
                allreduce_in.element_size(),
            )

Should it be max_token_num * allreduce_in.shape[1] * allreduce_in.element_size()?
Otherwise the bound seems very small, and use_flashinfer can rarely be activated.
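To illustrate the concern with made-up numbers (not taken from the PR):

# Hypothetical shapes: 512 tokens, hidden size 8192, fp16 (2 bytes/element),
# with max_token_num = 1024.
num_tokens, hidden_size, elem_size = 512, 8192, 2
max_token_num = 1024

input_bytes = num_tokens * hidden_size * elem_size       # 8,388,608
bound_shape0 = max_token_num * num_tokens * elem_size    # 1,048,576  -> fusion skipped
bound_shape1 = max_token_num * hidden_size * elem_size   # 16,777,216 -> fusion would be allowed
print(input_bytes, bound_shape0, bound_shape1)

With shape[0] in the bound, the token count cancels on both sides and the condition reduces to hidden_size <= max_token_num, which is false for typical hidden sizes; with shape[1] it reduces to num_tokens <= max_token_num, which appears to be the intent.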

@nvpohanh (Contributor)

Also, why is fi_allreduce_fusion_max_token_num set to 1024 by default? Do we see performance issue in this fusion when num_tokens is too large?

I saw it's discussed in flashinfer flashinfer-ai/flashinfer#1223

Thanks for providing the link! However, I think that issue only applies to the lamport_oneshot kernel. We could still use the twoshot kernel when num_tokens is larger, right?

@shixianc (Contributor)

Good question @shixianc. Currently FI JIT-compiles its allreduce kernels only for sm100; however, when we add sm90 flags, their tests seem to pass. So we will post a PR to FI to build for Hopper as well, and maybe other arches too.

@mgoin I built for sm90a and benchmarked using the same commands, but found worse TPOT. Let me know if you have any benchmarks on Hopper, thanks!

patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jul 15, 2025
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025