
@ByronHsu (Collaborator) commented Jun 15, 2025

Motivation

Based on #7193

  1. Support decode retract
  2. Upstream decode code
  3. Add a retract test

Modifications

Checklist


@gemini-code-assist bot left a comment


Summary of Changes

Hello @ByronHsu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a key feature for the disaggregated decode server: the ability to retract requests and offload their KV cache to CPU when GPU memory is constrained, and then resume them later. This improves the system's resilience to memory pressure and allows for better GPU memory utilization. The memory allocation and scheduling logic has been updated to support this, along with a new test to verify the behavior.

Highlights

  • Request Retraction and Resumption: Implemented the ability to retract requests from GPU memory to CPU when memory is low in the disaggregated decode server, and resume them later when memory becomes available. This involves offloading and reloading the KV cache.
  • Enhanced Memory Management: Refined the memory allocation logic in the decode server (DecodePreallocQueue) to better estimate available tokens, considering running requests, potential future growth, and the space occupied by retracted requests. Added checks to prevent adding requests that exceed total capacity or available space.
  • Disaggregated Decode Scheduling: Updated the scheduling logic to prioritize resuming retracted requests before allocating new ones from the pending queue. Modified event loops to correctly account for requests in the waiting queue when determining idle state.
  • Testing for Retraction: Added a new test case (TestDisaggregationSimulatedRetract) to specifically validate the retraction and resumption functionality in a disaggregated setup using a GSM8K evaluation.
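The retract/resume flow and scheduling priority described in the highlights can be illustrated with a toy model. This is a minimal sketch under stated assumptions, not the PR's implementation: `ToyDecodeScheduler`, its fields, and the Python lists standing in for GPU and CPU KV storage are all hypothetical. It assumes that new requests are held back entirely while any retracted request is still waiting, which is one reading of "prioritize resuming retracted requests".

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List


@dataclass
class Req:
    rid: str
    kv_len: int  # KV-cache token slots this request currently occupies


@dataclass
class ToyDecodeScheduler:
    """Hypothetical model of decode-side retract/resume (not SGLang's classes)."""

    total_gpu_tokens: int
    running: Dict[str, Req] = field(default_factory=dict)
    retracted: Deque[Req] = field(default_factory=deque)
    pending: Deque[Req] = field(default_factory=deque)
    gpu_kv: Dict[str, List[int]] = field(default_factory=dict)  # stand-in for GPU KV pages
    cpu_kv: Dict[str, List[int]] = field(default_factory=dict)  # stand-in for offloaded KV

    def free_tokens(self) -> int:
        return self.total_gpu_tokens - sum(r.kv_len for r in self.running.values())

    def retract(self, rid: str) -> None:
        """Offload one running request's KV to CPU and release its GPU slots."""
        req = self.running.pop(rid)
        self.cpu_kv[rid] = self.gpu_kv.pop(rid)
        self.retracted.append(req)

    def _admit(self, req: Req) -> None:
        self.running[req.rid] = req
        # A resumed request reloads its KV from CPU; a new one gets fresh slots.
        self.gpu_kv[req.rid] = self.cpu_kv.pop(req.rid, [0] * req.kv_len)

    def schedule_step(self) -> List[str]:
        """Resume retracted requests first; admit new ones only when none are waiting."""
        admitted: List[str] = []
        while self.retracted and self.retracted[0].kv_len <= self.free_tokens():
            req = self.retracted.popleft()
            self._admit(req)
            admitted.append(req.rid)
        if self.retracted:
            return admitted  # new requests must not jump ahead of retracted ones
        while self.pending and self.pending[0].kv_len <= self.free_tokens():
            req = self.pending.popleft()
            self._admit(req)
            admitted.append(req.rid)
        return admitted

    def is_idle(self) -> bool:
        # Waiting work (retracted or pending) also means the server is not idle.
        return not (self.running or self.retracted or self.pending)
```

Holding back new requests while a retracted one cannot yet fit avoids starving retracted requests, at the cost of possibly leaving some free slots unused for a step.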



@gemini-code-assist bot left a comment


Code Review

This pull request introduces request retraction and resumption capabilities for the disaggregated decode engine, significantly enhancing memory management. Key changes include modifications to DecodePreallocQueue for handling retracted requests, updates to token allocation logic, and new methods for offloading/loading KV cache to/from CPU. The PR also adds a new test suite to simulate and verify the retraction mechanism. Overall, these changes are crucial for stability and efficiency under high load. Some minor points like a redundant assignment and a return type change need attention.

@@ -412,10 +552,9 @@ def add(self, decode_req: DecodeRequest) -> None:
def extend(self, decode_reqs: List[DecodeRequest]) -> None:
    self.queue.extend(decode_reqs)

- def pop_transferred(self) -> List[DecodeRequest]:
+ def pop_transferred(self) -> List[Req]:

high

The return type of pop_transferred has changed from List[DecodeRequest] to List[Req]. This is a significant API change for this method. Please ensure that all call sites of pop_transferred (e.g., in SchedulerDisaggregationDecodeMixin) are updated to handle List[Req] correctly and that this change aligns with the intended data flow. If the DecodeRequest wrapper is no longer needed at that stage, this is fine, but consistency is key.
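What the return-type change means for callers can be shown with a simplified sketch. `DecodeRequest` and `Req` are the names from the diff, but their fields here, the `transferred` flag, `ToyTransferQueue`, and the `waiting_queue` list are hypothetical simplifications, not the real SGLang classes. The point is that unwrapping happens inside `pop_transferred`, so call sites must stop reaching through `.req`.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Req:
    rid: str


@dataclass
class DecodeRequest:
    req: Req
    transferred: bool = False  # hypothetical flag; a real queue would poll the KV receiver


class ToyTransferQueue:
    """Hypothetical stand-in for the decode transfer queue discussed in the review."""

    def __init__(self) -> None:
        self.queue: List[DecodeRequest] = []

    def extend(self, decode_reqs: List[DecodeRequest]) -> None:
        self.queue.extend(decode_reqs)

    def pop_transferred(self) -> List[Req]:
        done = [d for d in self.queue if d.transferred]
        self.queue = [d for d in self.queue if not d.transferred]
        # Unwrap here, so callers receive plain Req objects and must not touch `.req`.
        return [d.req for d in done]


# Call-site shape after the change (hypothetical): results go straight into
# a list of Req, with no `decode_req.req` indirection.
waiting_queue: List[Req] = []
tq = ToyTransferQueue()
tq.extend([DecodeRequest(Req("a"), transferred=True), DecodeRequest(Req("b"))])
waiting_queue.extend(tq.pop_transferred())
```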

@@ -285,10 +375,23 @@ def pop_preallocated(self) -> List[DecodeRequest]:
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
    break

# Memory estimation: don't add if the projected memory cannot be met
# TODO: add new_token ratio

medium

The TODO comment mentions adding new_token_ratio. If this ratio is intended to influence required_tokens_for_request or the subsequent memory check, it might be beneficial to clarify how it would be incorporated. For example, would it adjust self.num_reserved_decode_tokens or decode_req.req.sampling_params.max_new_tokens dynamically?
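One possible reading of the TODO, sketched under explicit assumptions: `new_token_ratio` is treated as a factor in (0, 1] that scales the fixed `num_reserved_decode_tokens` headroom per request. The budget formula follows the summary's description (subtract growth headroom for running requests and the space retracted requests will need on resume), but the functions `allocatable_tokens`, `required_tokens`, and `count_preallocatable` and their exact arithmetic are hypothetical, not the code in this PR.

```python
import math
from typing import List


def allocatable_tokens(
    free_slots: int,
    num_running: int,
    retracted_kv_lens: List[int],
    num_reserved_decode_tokens: int,
) -> int:
    """Token budget left for new requests (hypothetical formula)."""
    # Headroom so every running request can keep decoding for a while.
    growth_headroom = num_reserved_decode_tokens * num_running
    # Space retracted requests will need when they are resumed.
    retracted_space = sum(n + num_reserved_decode_tokens for n in retracted_kv_lens)
    return free_slots - growth_headroom - retracted_space


def required_tokens(
    input_len: int, num_reserved_decode_tokens: int, new_token_ratio: float = 1.0
) -> int:
    """Slots reserved for one new request: its prompt KV plus scaled decode headroom."""
    return input_len + math.ceil(num_reserved_decode_tokens * new_token_ratio)


def count_preallocatable(
    pending_input_lens: List[int],
    budget: int,
    num_reserved_decode_tokens: int,
    new_token_ratio: float = 1.0,
) -> int:
    """How many pending requests, in FIFO order, fit into `budget`."""
    admitted = 0
    for input_len in pending_input_lens:
        need = required_tokens(input_len, num_reserved_decode_tokens, new_token_ratio)
        if need > budget:
            break  # stop at the first request that does not fit, preserving order
        budget -= need
        admitted += 1
    return admitted
```

For example, with a budget of 700, four pending 200-token prompts, and 50 reserved decode tokens, a ratio of 1.0 admits two requests while a ratio of 0.5 admits three. A smaller ratio packs more requests in but makes retraction more likely later.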

Comment on lines +483 to +485
assert (
    req_pool_indices is not None
), "req_pool_indices is full! There is a bug in memory estimation."

medium

The assertion message has been improved, which is good. Consider if there are specific scenarios or conditions under which req_pool_indices could be None despite the memory estimation logic. This could help in further refining the estimation or handling edge cases more gracefully, perhaps by logging more context before asserting if that's feasible in a performance-critical path.

Comment on lines +514 to +516
assert (
    kv_loc is not None
), "KV cache is full! There is a bug in memory estimation."

medium

Similar to the previous comment on req_pool_indices, the improved assertion message for kv_loc is helpful. It might be worth considering if additional diagnostic information (e.g., available size, requested size, number of active requests) could be logged if this assertion fails frequently during development or testing, to aid in debugging the memory estimation.
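Both suggestions (more context for the `req_pool_indices` and `kv_loc` assertions) can be met without slowing the success path, because the diagnostics only need to be built when allocation fails. A minimal sketch, assuming a toy allocator whose `alloc` returns `None` when it cannot satisfy a request; `ToyAllocator` and `alloc_or_fail` are hypothetical names, and only the message prefix comes from the diff.

```python
import logging
from typing import List, Optional

logger = logging.getLogger(__name__)


class ToyAllocator:
    """Hypothetical slot allocator: alloc returns None when it cannot satisfy a request."""

    def __init__(self, size: int) -> None:
        self.free_slots: List[int] = list(range(size))

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need: int) -> Optional[List[int]]:
        if need > len(self.free_slots):
            return None
        out, self.free_slots = self.free_slots[:need], self.free_slots[need:]
        return out


def alloc_or_fail(allocator: ToyAllocator, need: int, num_active: int) -> List[int]:
    kv_loc = allocator.alloc(need)
    if kv_loc is None:
        # Diagnostics are built only on the failure path, so the hot path pays nothing.
        msg = (
            "KV cache is full! There is a bug in memory estimation. "
            f"requested={need}, available={allocator.available_size()}, "
            f"active_reqs={num_active}"
        )
        logger.error(msg)
        # An explicit raise (unlike `assert`) is not stripped by `python -O`.
        raise AssertionError(msg)
    return kv_loc
```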

Comment on lines +1454 to +1458
if len(retracted_reqs) == 0:
    # Corner case: only one request left
    raise ValueError(
        "Failed to retract any request. No space left for only one request."
    )

medium

This error condition for failing to retract any request when only one is left is a good safeguard. Ensure that this scenario is adequately covered by tests, especially in conjunction with the new retraction logic in decode.py.
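To make the corner case testable in isolation, the retraction step can be modeled as a pure function. This is a minimal sketch, not the scheduler's actual retract code: `RunningReq`, `retract_until_fits`, and the "largest KV first" victim order are hypothetical choices for illustration. Only the `ValueError` message and the zero-retracted condition come from the diff. If some requests are retracted but space is still short, the sketch simply returns with one request kept, and a caller would need its own policy for that case.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class RunningReq:
    rid: str
    kv_len: int


def retract_until_fits(
    running: List[RunningReq], free_tokens: int, needed_tokens: int
) -> Tuple[List[RunningReq], List[RunningReq]]:
    """Retract requests until `needed_tokens` fit, always keeping one request running.

    Returns (kept, retracted). The victim order (largest KV first) is a
    hypothetical policy chosen for illustration.
    """
    if free_tokens >= needed_tokens:
        return list(running), []
    kept = sorted(running, key=lambda r: r.kv_len)
    retracted: List[RunningReq] = []
    while free_tokens < needed_tokens and len(kept) > 1:
        victim = kept.pop()  # the request holding the most KV slots
        free_tokens += victim.kv_len
        retracted.append(victim)
    if len(retracted) == 0:
        # Corner case: only one request left
        raise ValueError(
            "Failed to retract any request. No space left for only one request."
        )
    return kept, retracted
```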

@ByronHsu changed the title from "[PD] Update decode.py" to "[PD] Support decode retract and update decode.py" on Jun 15, 2025
@ByronHsu merged commit db0cc57 into main on Jun 15, 2025
35 of 53 checks passed
@ByronHsu deleted the byron/decode-sync branch on June 15, 2025 at 02:48
@ByronHsu mentioned this pull request on Jun 15, 2025
@zhyncs (Member) commented Jun 15, 2025

cool cc @yizhang2077

@CSEEduanyu

+ self.num_reserved_decode_tokens

@ByronHsu Hi, I would like to know whether retract is currently in effect. Why do we need to predict the output length with a fixed value instead of keeping it consistent with how it works when prefill and decode are not disaggregated?
