@Abatom (Contributor) commented Jun 30, 2025

Fix occasional garbled output in P2pNcclConnector: the temporary empty tensor used as the receive buffer was allocated on a different CUDA stream from the one ncclRecv runs on.

Reproduction steps

  1. Launch a 1P1D deployment (one prefill instance, one decode instance), following [V1][P/D] An native implementation of xPyD based on P2P NCCL #18242.
  2. Stress-test it with a load-testing tool, here benchmark_serving.py:
```bash
python3 benchmark_serving.py \
    --backend vllm \
    --model base_model \
    --tokenizer meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name "random" \
    --host 10.0.1.1 \
    --port 10001 \
    --random-input-len 1024 \
    --random-output-len 1024 \
    --ignore-eos \
    --burstiness 100 \
    --percentile-metrics "ttft,tpot,itl,e2el" \
    --metric-percentiles "90,95,99" \
    --seed $(date +%s) \
    --trust-remote-code \
    --request-rate 1 \
    --num-prompts 300
```
  3. At the same time, send occasional single requests with curl. Intermittently, the response contains invalid output such as "!!!!!!!!!!".
```bash
curl -X POST -s http://10.0.1.1:10001/v1/completions \
-H "Content-Type: application/json" \
-d '{
    "model": "base_model",
    "prompt": "San Francisco is a",
    "max_tokens": 10,
    "temperature": 0
}'
```
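Because the failure is intermittent, the single-request curl check is easier to repeat from a script than by hand. The following is a minimal sketch, assuming the OpenAI-compatible `/v1/completions` response shape (generated text in `choices[0].text`) and reusing the host, port and model name from the reproduction steps. The helpers `probe`, `extract_text` and `looks_garbled`, and the "five or more `!` in a row" threshold, are hypothetical choices for illustration, not part of vLLM.

```python
import json
import re
import urllib.request

# Endpoint and payload copied from the curl command in the reproduction steps;
# adjust them to your own deployment.
URL = "http://10.0.1.1:10001/v1/completions"
PAYLOAD = {"model": "base_model", "prompt": "San Francisco is a",
           "max_tokens": 10, "temperature": 0}

# Assumed heuristic: treat a run of five or more "!" as the garbled pattern.
GARBLED = re.compile(r"!{5,}")


def extract_text(response_body: str) -> str:
    """Return the generated text from an OpenAI-style completions response."""
    return json.loads(response_body)["choices"][0]["text"]


def looks_garbled(text: str) -> bool:
    """True if the text contains the garbled "!!!!!" pattern."""
    return bool(GARBLED.search(text))


def probe(n_requests: int = 50) -> int:
    """Send identical greedy requests and count the ones that look garbled."""
    bad = 0
    for i in range(n_requests):
        req = urllib.request.Request(
            URL,
            data=json.dumps(PAYLOAD).encode(),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=60) as resp:
            text = extract_text(resp.read().decode())
        if looks_garbled(text):
            bad += 1
            print(f"request {i}: garbled output {text!r}")
    return bad
```

Calling `probe()` while `benchmark_serving.py` generates background load repeats the curl check many times. Since `temperature` is 0, every response should also be identical, so comparing each text against the first response is a stricter alternative to the pattern heuristic.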

Signed-off-by: Abatom <abzhonghua@gmail.com>

@gemini-code-assist bot left a comment

Summary of Changes

Hello @Abatom, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a concurrency issue in the P2pNcclConnector that leads to garbled output during stress testing. By ensuring that tensor creation occurs within the correct CUDA stream, it prevents potential race conditions and improves the reliability of distributed tensor transfers.

Highlights

  • Bug Fix: Addresses an issue where garbled output occurs in the P2pNcclConnector when receiving a temporarily created empty tensor, especially under stress testing conditions.
  • Concurrency: Ensures that the tensor creation happens within the context of the recv_stream to avoid potential synchronization issues with NCCL operations.

@gemini-code-assist bot left a comment

Code Review

This pull request addresses a critical race condition in the P2pNcclConnector that could lead to garbled output under stress. The fix is well-targeted and correct. By ensuring that the tensor allocation (torch.empty) and the subsequent ncclRecv operation are both executed on the same CUDA stream, the change prevents data corruption by properly ordering the asynchronous GPU operations. This is an essential fix for the stability and correctness of the P2P KV transfer mechanism.

Comment on lines +313 to +317
```python
with torch.cuda.stream(self.recv_stream):
    tensor = torch.empty(data["shape"],
                         dtype=getattr(torch, data["dtype"]),
                         device=self.device)
```

critical

This change is a crucial fix that correctly addresses a potential race condition.

By placing the torch.empty call within the recv_stream context, you ensure that the tensor memory allocation is properly ordered with respect to the subsequent ncclRecv call, which also uses self.recv_stream.

Without this, the buffer is allocated from the default stream's pool. PyTorch's caching allocator only tracks the stream a block was allocated on, so a block just freed on the default stream can be handed to this tensor while earlier default-stream work that uses that memory is still pending. Because recv_stream is not synchronized with the default stream, ncclRecv can then write to that memory concurrently with the stale work, leading to the data corruption and garbled output described in the pull request. Allocating on recv_stream keeps the allocation and the receive on the same stream, which is essential for correctness and stability under load.
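One plausible mechanism for the corruption can be shown with a toy model. It is consistent with the documented behavior of PyTorch's CUDA caching allocator: a block is associated with the stream it was allocated on, and once freed it is reused for that stream without regard to work pending on other streams. Everything below is hypothetical, pure Python: `ToyCachingAllocator`, `run_scenario`, the stream names and the worst-case interleaving are illustrative assumptions, not vLLM or PyTorch code.

```python
from collections import defaultdict


class ToyCachingAllocator:
    """Hypothetical, heavily simplified stream-aware caching allocator.

    Each block remembers the stream it was allocated on, and a freed block goes
    straight back to that stream's free list. Reuse on the same stream is safe
    because work on one stream runs in order; nothing here waits on other streams.
    """

    def __init__(self):
        self.free_blocks = defaultdict(list)  # stream name -> reusable block ids
        self.owner = {}                       # block id -> allocating stream
        self.next_id = 0

    def malloc(self, stream):
        if self.free_blocks[stream]:
            return self.free_blocks[stream].pop()  # reuse a cached block
        block = self.next_id
        self.next_id += 1
        self.owner[block] = stream
        return block

    def free(self, block):
        # Returned to the allocating stream's pool immediately.
        self.free_blocks[self.owner[block]].append(block)


def run_scenario(alloc_stream):
    """Return what the receiver reads from its buffer after one simulated recv."""
    memory = {}                           # block id -> contents
    queues = {"default": [], "recv": []}  # pending GPU work per stream, in order
    alloc = ToyCachingAllocator()

    # An older tensor lives on the default stream, and a kernel that writes to it
    # is still queued (the GPU lags behind the host).
    old = alloc.malloc("default")
    queues["default"].append(lambda: memory.__setitem__(old, "stale activations"))
    alloc.free(old)  # the host drops its reference before that kernel has run

    # Allocate the receive buffer, then queue the recv on the recv stream.
    buf = alloc.malloc(alloc_stream)
    queues["recv"].append(lambda: memory.__setitem__(buf, "KV cache from prefill"))

    # Worst-case interleaving: the streams are not synchronized, so the recv
    # stream may finish before older default-stream work executes.
    for stream in ("recv", "default"):
        for op in queues[stream]:
            op()
    return memory[buf]


print(run_scenario("default"))  # -> stale activations
print(run_scenario("recv"))     # -> KV cache from prefill
```

With the buffer allocated on the default stream, it reuses the block that a still-pending default-stream kernel will overwrite, and the receiver reads stale data; allocated on the recv stream, it gets a separate block. In real PyTorch, an alternative would be to make recv_stream wait on the default stream (`Stream.wait_stream`) before the receive; this PR instead allocates under `torch.cuda.stream(self.recv_stream)`.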

@simon-mo enabled auto-merge (squash) June 30, 2025 16:34
@github-actions bot added the ready label ("ONLY add when PR is ready to merge/full CI is needed") Jun 30, 2025
@simon-mo disabled auto-merge June 30, 2025 23:44
@simon-mo merged commit ded1fb6 into vllm-project:main Jun 30, 2025
85 of 87 checks passed
@Abatom deleted the xpyd-stream branch July 2, 2025 03:34
CSWYF3634076 pushed a commit to CSWYF3634076/vllm that referenced this pull request Jul 2, 2025
…clConnector (vllm-project#20263)

Signed-off-by: Abatom <abzhonghua@gmail.com>
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
…clConnector (vllm-project#20263)

Signed-off-by: Abatom <abzhonghua@gmail.com>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
…clConnector (vllm-project#20263)

Signed-off-by: Abatom <abzhonghua@gmail.com>