
Conversation

@NickLucche (Contributor) commented Jun 26, 2025

With #19223, we're addressing most of the cases where P request blocks may be left starving.
However, there are still cases where the remote producer can be left with blocks that will never be cleared: the router may fail to communicate a request abortion for whatever reason (e.g. in-flight request lost, router down, ...) before the request has reached D, or D may fail to communicate the abortion to P.

This PR addresses these final edge cases by attaching a simple TTL to every request whose blocks need to be read from local (D) <- remote (P).

cc @njhill
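In short: P stamps each request whose blocks await a remote read with an expiry time, and a periodic sweep frees the blocks if D never shows up. Below is a minimal, self-contained sketch of that idea; the class and method names are illustrative, not the actual NixlConnectorWorker code, and the real timeout comes from VLLM_NIXL_ABORT_REQUEST_TIMEOUT.

```python
import time

ABORT_TIMEOUT_S = 120.0  # mirrors the default of VLLM_NIXL_ABORT_REQUEST_TIMEOUT

class ProducerBlockTracker:
    """Toy model of P-side tracking: blocks stay pinned until D reads
    them or the request's expiry passes, whichever comes first."""

    def __init__(self) -> None:
        # req_id -> absolute expiry timestamp on the local clock
        self._reqs_to_send: dict[str, float] = {}

    def on_prefill_done(self, req_id: str) -> None:
        # Request finished prefill; D is expected to read its blocks.
        self._reqs_to_send[req_id] = time.perf_counter() + ABORT_TIMEOUT_S

    def on_read_complete(self, req_id: str) -> None:
        # D notified us it pulled the blocks; stop tracking.
        self._reqs_to_send.pop(req_id, None)

    def sweep_expired(self) -> set[str]:
        # Called periodically: requests D never read; free their blocks.
        now = time.perf_counter()
        expired = {r for r, exp in self._reqs_to_send.items() if now > exp}
        for r in expired:
            del self._reqs_to_send[r]
        return expired
```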


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, a small and essential subset of tests that quickly catches errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @NickLucche, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've addressed a critical edge case in the disaggregated decode-prefill setup where KV cache blocks on the remote producer (P) could become permanently stranded if the consumer (D) failed to communicate request abortion or completion. This pull request introduces a robust timeout mechanism that automatically frees these blocks after a configurable period, significantly improving resource utilization and system stability in failure scenarios.

Highlights

  • KV Cache Block Timeout: I've implemented a timeout mechanism to automatically clear KV cache blocks on the remote producer (P) if the consumer (D) fails to read them within a specified duration. This prevents resource starvation in disaggregated decode-prefill setups.
  • Configurable Timeout: A new environment variable, VLLM_NIXL_ABORT_REQUEST_TIMEOUT, has been introduced, allowing users to configure the timeout duration for remote consumer reads (defaulting to 120 seconds).
  • NixlConnector Enhancements: I've modified the NixlConnector and NixlConnectorWorker to track requests that have completed prefill on the producer and are awaiting consumption by the decoder, enabling the new timeout logic to be applied.
  • Unit Test Coverage: A dedicated unit test (test_abort_timeout_on_prefiller) has been added to validate the end-to-end functionality of the remote consumer read timeout, simulating a scenario where communication fails and blocks are eventually cleared.
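As a quick usage sketch for the VLLM_NIXL_ABORT_REQUEST_TIMEOUT variable mentioned in the bullets above (assuming standard environment-variable semantics, i.e. it must be set before the engine process reads it):

```python
import os

# Hypothetical override: shorten the producer-side abort timeout from the
# 120 s default to 30 s. Set this before the vLLM engine (and its
# NixlConnector) is created so the connector picks it up.
os.environ["VLLM_NIXL_ABORT_REQUEST_TIMEOUT"] = "30"
```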

@gemini-code-assist (bot) left a comment

Code Review

This PR introduces a timeout mechanism to clear request blocks in the remote producer when the router fails to communicate request abortion, addressing an edge case in the NixlConnector. The changes include adding a TTL to requests, updating metadata, and implementing timeout handling in the worker. The code also includes a new unit test to verify the timeout functionality.


mergify bot commented Jun 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@njhill (Member) left a comment

Thanks @NickLucche

Comment on lines 971 to 974
# Track the requests that are waiting to be read and abort on timeout.
# Set to -1 so that the timeout does not depend on model latency.
for req_id in metadata.reqs_to_send:
    self._reqs_to_send[req_id] = -1

I don't think this would be needed per my other comments

@NickLucche (author) commented

Hey @njhill, I've addressed all the comments except the part where D does a check in get_num_new_matched_tokens.
Big thanks for the review, code looks much cleaner with your additions.

@njhill (Member) left a comment

Thanks @NickLucche! WDYT about adding the time check on the decode side too?

@NickLucche (author) commented

Thanks for reviewing @njhill.

WDYT about adding the time check on the decode side too?

Right, I see its purpose in covering the edge case in which P has already freed the request blocks while D attempts to read them.
Do you have another scenario in mind?

@njhill (Member) commented Jul 1, 2025

No, that's the main thing. A safeguard so that we don't pull blocks which have already been freed and could contain random other cache. I think the changes required to add that should be minimal.

@NickLucche force-pushed the pd-remote-timeout branch from 45499ab to 0d09b8b (July 1, 2025 17:05)
@njhill (Member) left a comment

Thanks @NickLucche. I realized a problem with the decode-side expiry check (see inline comments). I think it might be best to leave that out for now (sorry!), and add it once we have incorporated request-level error handling into the general connector logic.

Or we could keep it similar to what you have now but make sure to update it once that error handling is ready (since it should then only be a trivial change needed).

@@ -442,6 +459,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
        # finish reading before safely freeing the blocks.
        self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)

        # Map of remote agent name -> time offset to keep clocks synced.
        self._remote_agent_time_offsets: dict[str, float] = {}

Could we instead change self._remote_agents to hold a tuple (name, time_offset)?

Comment on lines 524 to 525
self._remote_agent_time_offsets[metadata.engine_id] = (
    metadata.remote_node_time + rtt / 2 - got_metadata_time)

Just compute the offset here and add it as an arg to add_remote_agent()?
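For context, the offset in the snippet above is a single-sample, NTP-style estimate. A hedged sketch, with variable names taken from the snippet and a hypothetical function name:

```python
import time

def estimate_remote_offset(remote_node_time: float, rtt: float,
                           got_metadata_time: float) -> float:
    """Offset such that time.perf_counter() + offset approximates the
    remote clock. Assumes the remote stamped remote_node_time when it
    sent the handshake reply and that the network path is symmetric,
    so the reply spent roughly rtt / 2 in flight."""
    return remote_node_time + rtt / 2 - got_metadata_time

# Usage sketch: measure rtt around the handshake round trip.
start = time.perf_counter()
# ... send handshake request, receive metadata from the remote agent ...
got_metadata_time = time.perf_counter()
rtt = got_metadata_time - start
```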

Comment on lines 870 to 873
while self._reqs_expired_ttl:
    req_id = next(iter(self._reqs_expired_ttl))
    done_recving.add(req_id)
    self._reqs_expired_ttl.remove(req_id)

Suggested change:

-while self._reqs_expired_ttl:
-    req_id = next(iter(self._reqs_expired_ttl))
-    done_recving.add(req_id)
-    self._reqs_expired_ttl.remove(req_id)
+done_recving.update(self._reqs_expired_ttl)
+self._reqs_expired_ttl.clear()

Comment on lines 337 to 340
if delay_free_blocks:
    # Prefill request on remote. It will be read from D upon completion
    self._reqs_need_send[request.request_id] = (
        time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)

I think it would be clearer to have a variable, and then just set this directly in the dict below

Suggested change:

-if delay_free_blocks:
-    # Prefill request on remote. It will be read from D upon completion
-    self._reqs_need_send[request.request_id] = (
-        time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
+expiry_time: Optional[float] = None
+if delay_free_blocks:
+    # Prefill request on remote. It will be read from D upon completion
+    expiry_time = time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
+self._reqs_need_send[request.request_id] = expiry_time

 return delay_free_blocks, dict(
     do_remote_prefill=True,
     do_remote_decode=False,
     remote_block_ids=computed_block_ids,
     remote_engine_id=self.engine_id,
     remote_host=self.side_channel_host,
     remote_port=self.side_channel_port,
-    tp_size=self.vllm_config.parallel_config.tensor_parallel_size)
+    tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
+    request_ttl=self._reqs_need_send.get(request.request_id, -1))

We should call this something like expires since TTL implies a relative time (duration).

Comment on lines 993 to 999
assert self._remote_agent_time_offsets[
    meta.remote_engine_id] is not None
remote_offset = self._remote_agent_time_offsets[meta.remote_engine_id]
if time.perf_counter() + remote_offset > meta.request_ttl:
    logger.warning("Request remote TTL expired for request %s", req_id)
    self._reqs_expired_ttl.add(req_id)
    return

I don't think this is the best place for this. It would be better to add it on the scheduler side of the connector in get_num_new_matched_tokens. If expired, we simply return 0, False from that method. And then there's no need to include in the internal metadata or worker-side finished method.

Right now we don't read the blocks but we also don't notify that there's any problem, so it will be pretty much just as bad. The decode will continue with random data in the kvcache and produce nonsense output.

However ... I just realized that this won't work if we're using the time offset exchanged during the handshake because the scheduler side won't know that, and the handshake may not have even happened yet at that point 🤔

If we keep things on the worker side, we'll need some of the other connector changes that have been proposed for error handling, so that we can effectively fail the transfer for this individual request.
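For illustration, a rough sketch of the scheduler-side variant described above, ignoring the clock-offset problem raised in the previous paragraph. The function name and shape here are hypothetical (the real get_num_new_matched_tokens takes the request and the locally computed token count), and this check was ultimately left out of the PR:

```python
import time
from typing import Optional

def check_remote_expiry(expires: float) -> Optional[tuple[int, bool]]:
    """Sketch of an early-exit guard for get_num_new_matched_tokens:
    if the producer-side expiry has passed, report no matched tokens
    (0, False) so the decode recomputes the prefill locally rather
    than pulling blocks that may now hold unrelated cache."""
    if time.perf_counter() > expires:
        return 0, False
    return None  # not expired: proceed with the normal matching logic
```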

assert self._remote_agent_time_offsets[
    meta.remote_engine_id] is not None
remote_offset = self._remote_agent_time_offsets[meta.remote_engine_id]
if time.perf_counter() + remote_offset > meta.request_ttl:

We should subtract additional "buffer" time here to allow for the transfer time. Maybe 30 seconds or something like that.
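Concretely, the buffer suggestion amounts to something like the following; TRANSFER_BUFFER_S is a hypothetical constant, not code from this PR:

```python
import time

TRANSFER_BUFFER_S = 30.0  # headroom for the transfer itself, per the suggestion

def is_effectively_expired(remote_offset: float, expires: float) -> bool:
    # Treat the request as expired once less than TRANSFER_BUFFER_S of
    # remote-clock time remains before expiry, so an in-flight read
    # cannot race against the producer freeing the blocks.
    return time.perf_counter() + remote_offset > expires - TRANSFER_BUFFER_S
```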

@NickLucche (author) commented

@njhill I've reverted the D-side changes; let's take another crack at it once request-level error handling lands.

# consumer. This is only applicable when using NixlConnector in a
# disaggregated decode-prefill setup.
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120"))
@NickLucche (author) commented

Ideally we'd set this to a runtime value based on queue metrics. Still, we may have to resort to heuristics even in that case (e.g. > max queue time), so I thought this was OK for an initial implementation.

# Track the requests that are waiting to be read and abort on timeout.
# Set to -1 so that the timeout does not depend on model latency.
for req_id in metadata.reqs_to_send:
    self._reqs_to_send[req_id] = -1
@NickLucche (author) commented

If the timeout is taken from queue-residency metrics, then we could even start the clock here.

@@ -325,6 +331,11 @@ def request_finished(
    # If prompt < block_size, no xfer so free blocks immediately.
    delay_free_blocks = len(computed_block_ids) > 0

    if delay_free_blocks:
        # Prefill request on remote. It will be read from D upon completion
        self._reqs_need_send[request.request_id] = time.perf_counter(
@NickLucche (author) commented

Swapped back from time.monotonic() so that we can send it to D.

@NickLucche force-pushed the pd-remote-timeout branch from e253ac6 to b12f394 (July 3, 2025 14:21)
@NickLucche requested a review from njhill July 4, 2025 08:19
@njhill (Member) left a comment

Thanks @NickLucche

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 5, 2025
@aarnphm aarnphm enabled auto-merge (squash) July 5, 2025 15:50
@njhill (Member) commented Jul 7, 2025

@NickLucche the new test currently requires NIXL; could it be changed to work with the mock NIXL using

@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)

like the other tests in that file?
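For reference, a sketch of what the patched test could look like. The decorator is quoted from the comment above; the stand-in FakeNixlWrapper class, the test body, and the pytest monkeypatch usage are assumptions, not the merged test:

```python
from unittest.mock import patch

class FakeNixlWrapper:
    """Stand-in; the real test double lives in the test file."""

@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)
def test_abort_timeout_on_prefiller(monkeypatch):
    # Shrink the timeout so the test does not wait the 120 s default.
    monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "1")
    # ... run a prefill, never perform the decode-side read, then assert
    # that the prefiller frees the request's blocks after the timeout ...
```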

@NickLucche (author) commented

Good point, an oversight on my side. Thanks for the help!

A series of commits followed (each signed off by NickLucche <nlucches@redhat.com>), including reverts of commits 0d09b8b and b7d5c64.
Auto-merge was automatically disabled July 7, 2025 10:07 (head branch was pushed to by a user without write access).

@NickLucche force-pushed the pd-remote-timeout branch from b12f394 to e249b4c (July 7, 2025 10:07)
@njhill (Member) commented Jul 7, 2025

Thanks @NickLucche ... that test is still failing:

[2025-07-07T11:43:22Z] >       assert len(kv_connector_metadata.requests) == 1
[2025-07-07T11:43:22Z] E       AttributeError: 'NixlConnectorMetadata' object has no attribute 'requests'

@NickLucche (author) commented

Yep, updated old tests

@njhill merged commit 71d1d75 into vllm-project:main on Jul 8, 2025 (71 checks passed).
huydhn pushed a commit to huydhn/vllm that referenced this pull request Jul 8, 2025
Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jul 15, 2025
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
ganyi1996ppo pushed a commit to vllm-project/vllm-ascend that referenced this pull request Aug 5, 2025
…tion (#2085)

### What this PR does / why we need it?

This PR addresses a critical issue where Node D (decode) failures cause
Node P (prefill) to hang because it is unable to release its KV cache.

**Trigger Scenarios:**
1. Node D fails mid-inference (e.g., network disconnection)
2. Node D rejects requests at a certain stage (e.g., via the API server)
3. Load-test script termination causes Node P or D to abort queued requests

**Root Cause Analysis:**
1. Currently, Node D sends a "KV cache pull complete, release approved" message to Node P
2. This message is transmitted via the worker connector. If the PD connection breaks or requests are rejected upstream, Node D cannot send the message
3. Node P never releases the KV cache without receiving this message

**Solution:**
Following the vLLM community's approach (the NIXL connector timeout mechanism), we're implementing:
- A timeout mechanism with comprehensive warnings
- Updated README documentation
- Reference: vLLM's optimization PR [#20139](vllm-project/vllm#20139)

**Note:** The full disaster recovery solution is still in design. This PR will be merged into the v091-dev branch as-is, but will evolve in main ([PR #2174](#2174)).


Signed-off-by: underfituu <hzhucong@163.com>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Labels: ready (ONLY add when PR is ready to merge/full CI is needed), v1
3 participants