
Conversation

@orozery (Contributor) commented Jun 12, 2025

This PR enables the worker-side KV connector to pass on arbitrary metadata to the scheduler-side connector.
This provides a standard, easy mechanism to aggregate kv-connector events from all workers.

In a nutshell, we introduce the following connector APIs:

  • build_worker_connector_meta, on the worker side, allowing the worker to build metadata to be sent back to the scheduler.
  • get_finished, now on the scheduler side (previously a worker-side API), which takes the connector metadata collected from all workers and yields the finished request transfers.

This PR makes the following changes:

  1. Change MultiprocExecutor to get ModelRunnerOutput from all workers,
    and aggregate the finished_sending and finished_recving from all.
  2. Remove the worker aggregation of those fields in NixlConnector.
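A minimal sketch of the two hooks, assuming the names from the description above; the real vLLM signatures take additional arguments (e.g. SchedulerOutput / ModelRunnerOutput), and the ToyConnector aggregation shown is illustrative only:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class KVConnectorMetadata:
    # Connector-specific payload; the fields here are illustrative only.
    finished_sending: set[str] = field(default_factory=set)
    finished_recving: set[str] = field(default_factory=set)


class KVConnectorBase(ABC):
    @abstractmethod
    def build_worker_connector_meta(self) -> KVConnectorMetadata:
        """Worker side: metadata to ship back to the scheduler each step."""

    @abstractmethod
    def get_finished(
        self, worker_meta: list[KVConnectorMetadata]
    ) -> tuple[set[str], set[str]]:
        """Scheduler side: aggregate all workers' metadata into the sets of
        requests whose sends/receives finished everywhere."""


class ToyConnector(KVConnectorBase):
    # Toy aggregation: a request counts as finished only once *every*
    # worker has reported it.
    def build_worker_connector_meta(self) -> KVConnectorMetadata:
        return KVConnectorMetadata()

    def get_finished(self, worker_meta):
        sending = set.intersection(*(m.finished_sending for m in worker_meta))
        recving = set.intersection(*(m.finished_recving for m in worker_meta))
        return sending, recving
```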


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which covers a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jun 12, 2025
@gemini-code-assist bot (Contributor) left a comment:

Summary of Changes

Hello @orozery, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the communication mechanism between workers and the scheduler for KV cache transfer status. It introduces a generic metadata channel from workers to the scheduler, allowing workers to send connector-specific information. The scheduler is now responsible for collecting and aggregating this metadata from all workers to determine the global state of KV transfers, particularly for distributed connectors like Nixl where completion requires coordination across multiple ranks.

Highlights

  • New Worker-to-Scheduler Metadata API: Introduced a new base method build_worker_connector_meta on the worker side of the KV connector interface. This method allows workers to build and return arbitrary metadata (KVConnectorMetadata) to the scheduler during each execution step.
  • Scheduler-side Aggregation of KV Transfer Status: Moved the responsibility of aggregating finished KV transfers (sending and receiving) from the worker side to the scheduler side. The scheduler-side base connector now has a get_finished method that takes the collected metadata from all workers and determines which requests have completed their transfers across the entire distributed setup.
  • ModelRunnerOutput Update: The ModelRunnerOutput dataclass, which workers return to the scheduler, has been updated to include a list of kv_connector_metadata instead of separate finished_sending and finished_recving sets. This allows workers to pass richer, connector-specific information.
  • Executor Aggregation: The multiprocess executor (MultiprocExecutor) now collects ModelRunnerOutput from all workers and aggregates the kv_connector_metadata from each worker into a single list before passing the combined output to the scheduler.
  • Nixl Connector Adaptation: The Nixl connector implementation has been updated to use the new metadata mechanism. It defines NixlWorkerConnectorMetadata to carry finished transfer IDs from the worker. The logic for tracking completion across all tensor parallel ranks has been moved from the worker-side NixlConnectorWorker to the scheduler-side NixlConnectorScheduler.
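The executor-side aggregation from the highlights above can be sketched as follows (hypothetical helper name; in the PR the logic lives inside MultiprocExecutor, and the real ModelRunnerOutput has many more fields):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ModelRunnerOutput:
    # Only the field relevant to this PR is shown; each worker returns a
    # (possibly absent) list of opaque, connector-specific metadata objects.
    kv_connector_metadata: Optional[list[Any]] = None


def aggregate_kv_connector_metadata(
        outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
    """Hypothetical helper mirroring the PR: concatenate every worker's
    metadata list into the single output the scheduler receives."""
    combined: list[Any] = []
    for output in outputs:
        combined += output.kv_connector_metadata or []
    result = outputs[0]
    result.kv_connector_metadata = combined
    return result
```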
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@orozery (Contributor Author) commented Jun 12, 2025

@sdavidbd we can change the get_finished API here to return not just the finished requests, but also invalid_block_ids

@gemini-code-assist bot (Contributor) left a comment:

Code Review

This pull request introduces a mechanism for workers to send metadata to the scheduler via the KV connector, which is a significant architectural change. The core idea is to centralize the aggregation of KV transfer events and status on the scheduler side. Key changes include introducing build_worker_connector_meta on the worker side and moving/refining get_finished to the scheduler side.

Overall, the changes seem to implement the described functionality. However, there are a few critical areas to address:

  1. A potential logic swap in vllm/v1/core/sched/scheduler.py regarding how finished_sending and finished_recving statuses from the connector are interpreted and acted upon. This could lead to incorrect behavior like freeing blocks prematurely or not marking requests as ready when they are.
  2. Type mismatches and potential runtime errors in NixlConnectorMetadata, ModelRunnerOutput, and gpu_worker.py related to handling None values and list assignments for kv_connector_metadata.

Addressing these points will be crucial for the correctness and stability of this new metadata flow.

@orozery (Contributor Author) commented Jun 12, 2025

@njhill @robertgshaw2-redhat putting this for preliminary review before weekend starts over here.
I did not yet test the changes I made to the nixl and multi connectors. I'm betting I introduced bugs. Will test it next week.

@orozery force-pushed the connector-metadata-worker-output branch 5 times, most recently from f4351d0 to 916c8e2, on June 15, 2025 16:19
@orozery orozery marked this pull request as ready for review June 15, 2025 16:31
@orozery changed the title from "[WIP] [V1] - Enable worker -> scheduler connector metadata" to "[V1] - Enable worker -> scheduler connector metadata" on Jun 15, 2025
assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None
if has_kv_transfer_group():
Contributor:

It's safer to use is_v1_kv_transfer_group until V0 is officially deprecated.

finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
# KV Cache Connector metadata.
kv_connector_metadata: Optional[list["KVConnectorMetadata"]] = None
Contributor:

KVConnectorMetadata was originally intended for scheduler-to-worker signaling. Using it in the opposite direction (worker-to-scheduler) could blur its semantics. It might be cleaner to introduce a separate class like KVConnectorOutput for this purpose.

Also, as mentioned above, I think aggregation in multi-worker setups should be handled at the MultiprocExecutor level rather than in the Scheduler. In that case kv_connector_metadata should be typed as: Optional["KVConnectorMetadata"]

@orozery (Contributor Author) commented Jun 19, 2025:

I want the scheduler connector to get access to all of the KVConnectorMetadata from all workers.
Only the connector knows what's inside the metadata; from the MultiprocExecutor's perspective it's opaque.
The Executor returns a single ModelRunnerOutput, so the way I found to let the scheduler connector access all of the metadata is to have the executor simply compose all of the metadata objects into a list.

Contributor:

I understand your approach, but I still believe the alternative I suggested - employing connector-specific aggregation logic at the executor - is cleaner. This way, the executor can return a single, aggregated KVConnectorMetadata object, and the scheduler connector continues to work with a unified metadata instance rather than a list. It keeps the interface consistent and offloads connector-specific logic to where it belongs.

Also, could we revisit the idea of separating the metadata classes for scheduler-to-worker and worker-to-scheduler communication?

return EMPTY_MODEL_RUNNER_OUTPUT
if has_kv_transfer_group():
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)

Contributor:

We're missing a call to clear_connector_metadata in this case (also before this change).

return output if self.is_driver_worker else None
if has_kv_transfer_group():
kv_connector_metadata = \
get_kv_transfer_group().build_worker_connector_meta(
Contributor:

build_worker_connector_meta should be called in gpu_model_runner.execute_model, before invoking clear_connector_metadata.

Alternatively, we could delegate the state clearing to build_worker_connector_meta itself and remove the clear_connector_metadata API. This would also make it symmetric with build_connector_meta, which is responsible for resetting the scheduler connector’s state.
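The alternative described in this comment could look roughly like the following sketch (all names hypothetical except the API names under discussion; for brevity the built metadata is just a set of finished request ids):

```python
class WorkerConnector:
    """Sketch of the suggested symmetric API: building the worker metadata
    also resets the connector's per-step state, so no separate
    clear_connector_metadata call is needed."""

    def __init__(self) -> None:
        self._connector_metadata = None   # set by bind_connector_metadata
        self._finished_this_step: set[str] = set()

    def bind_connector_metadata(self, metadata) -> None:
        # Scheduler-built metadata bound at the start of each step.
        self._connector_metadata = metadata

    def build_worker_connector_meta(self) -> set[str]:
        out = self._finished_this_step
        # Reset the per-step state here, mirroring build_connector_meta
        # on the scheduler side, instead of in clear_connector_metadata.
        self._finished_this_step = set()
        self._connector_metadata = None
        return out
```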

@@ -1028,21 +1028,27 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
self.finished_recving_kv_req_ids.remove(request.request_id)
return True

def _update_from_kv_xfer_finished(self,
model_runner_output: ModelRunnerOutput):
def _update_from_kv_connector_metadata(
Contributor:

kv_connector_metadata may carry distinct signals for different code paths. I think it would be cleaner if update_from_output used dedicated connector APIs to extract the relevant information from kv_connector_metadata, e.g.:

finished_sending, finished_recving = self.connector.get_finished(kv_connector_metadata)
_update_from_kv_xfer_finished(finished_sending, finished_recving)

Contributor Author:

I think it's better for the scheduler connector to aggregate the metadata only once, and output everything the scheduler needs (finished reqs, invalid blocks, etc.).
So obviously get_finished is not a good name. Maybe something like process_worker_output(..) -> ConnectorOutput, where ConnectorOutput is a new struct that will contain all relevant fields (which were previously laid out flat on ModelRunnerOutput).

Contributor:

I agree that metadata aggregation should happen only once (ideally at the executor level). After that, the scheduler would hold a single aggregated instance of the worker-side KVConnectorMetadata (which is still opaque and connector-specific).

From there, the scheduler can use dedicated connector APIs (e.g., get_finished) to extract only the information it needs. This keeps the design more flexible and scalable, rather than relying on a single API to unpack all possible data upfront.


def get_finished(
    self,
    model_runner_output: ModelRunnerOutput,
Contributor:

Why ModelRunnerOutput and not KVConnectorMetadata?

Contributor Author:

To give the connector full awareness of the model output (maybe someone will want sampled_token_ids).
The same way the connector gets full access to the SchedulerOutput.

Contributor:

The scheduler connector only gets access to SchedulerOutput for the purpose of creating metadata for the worker connector. Similarly, the worker connector should only access ModelRunnerOutput to generate metadata for the scheduler connector.

In any case, I don’t think the scheduler connector should have access to ModelRunnerOutput. That separation helps keep responsibilities clear and avoids unnecessary coupling.

self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
def build_worker_connector_meta(
self, scheduler_output: SchedulerOutput,
Contributor:

Why do we need to pass in SchedulerOutput? I think we should make it symmetric with the scheduler-side build_connector_meta and pass just ModelRunnerOutput.

Contributor Author:

Your suggestion will work assuming the connector already got the SchedulerOutput (via bind_connector_metadata).
But there's also clear_connector_metadata in the way, so trying to match up with the correct scheduler output seems more fragile to me. I would prefer to directly pass in the scheduler output here, to make it easier and more explicit for the worker-side connector to build its metadata.

Contributor:

I believe clear_connector_metadata should always be the last worker-side API invoked in each step, after which the worker connector’s state is reset. As I mentioned earlier, a cleaner alternative is to make build_worker_connector_meta responsible for both building the worker-side metadata and resetting the state - mirroring the behavior of build_connector_meta on the scheduler side - thereby removing the need for a separate clear_connector_metadata API.

In any case, the worker connector shouldn't need access to the SchedulerOutput. As you said, it should already receive everything it needs via bind_connector_metadata.

This keeps the design symmetric:

  • The scheduler connector builds the metadata for the worker connector from SchedulerOutput and resets its state.
  • The worker connector builds the metadata for the scheduler connector from ModelRunnerOutput and resets its state.

Comment on lines 231 to 233
kv_connector_metadata = []
for i, output in enumerate(outputs):
    kv_connector_metadata += output.kv_connector_metadata or []
Contributor:

I think it would be cleaner to use connector-specific aggregation logic here - for example, by introducing a new worker-side KVConnector API dedicated to aggregating the metadata.

Contributor Author:

I thought about this, but I did not want to introduce the connector inside the executor. Currently it's only in scheduler.py.

Contributor:

You're already introducing KVConnectorMetadata into the executor - so it seems reasonable to also give the executor its own KVConnector instance. This could be a new EXECUTOR role with a single API dedicated to aggregating worker-side metadata.

Personally, I find this cleaner than having each worker return a list with a single metadata object and adding ad-hoc logic in the executor to manually merge those lists. Delegating aggregation to the connector keeps the logic encapsulated and consistent.

Contributor Author:

@njhill What are your thoughts on introducing the connector to the executor to allow aggregation of workers output there?

@sdavidbd (Contributor):

@sdavidbd we can change the get_finished API here to return not just the finished requests, but also invalid_block_ids

As suggested in my review, I think we should introduce a new connector API to extract the invalid block IDs from the worker-side connector metadata.


mergify bot commented Jun 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 19, 2025
@orozery orozery mentioned this pull request Jun 19, 2025
1 task
@njhill (Member) commented Jun 20, 2025

Thanks @orozery. I like this but am just trying to think through implications/alternatives, in the interests of keeping the interface as simple as possible and minimizing concerns on the connector impl side.

One idea is whether it would make sense to abstract this return flow of information in the form of events, i.e. a generalization of what we already have with the lists of finished request ids.

It may be that we can then still encapsulate the TP aggregation of these within the framework, since we would require a positive response from all workers. One or more negative results (failures) would translate to a negative result when aggregated.

@orozery (Contributor Author) commented Jun 22, 2025

Thanks @orozery. I like this but am just trying to think through implications/alternatives, in the interests of keeping the interface as simple as possible and minimizing concerns on the connector impl side.

One idea is whether it would make sense to abstract this return flow of information in the form of events, i.e. a generalization of what we already have with the lists of finished request ids.

It may be that we can then still encapsulate the TP aggregation of these within the framework, since we would require a positive response from all workers. One or more negative results (failures) would translate to a negative result when aggregated.

@njhill IIUC (please correct me) your suggestion is as follows:
Do not allow arbitrary metadata flow back from workers to scheduler (similarly to the unconstrained KVConnectorMetadata flowing from scheduler to workers).
Instead, allow each worker to report only a list of completed events.
Also, you do not want these "Events" to be opaque, but explicit. Perhaps something like:
ModelRunnerOutput.finished_connector_events: list[tuple[int, bool]] #[(event_id, is_success)]
This means it will be the scheduler connector role to define the events (including their IDs) and to pass them on to workers via KVConnectorMetadata.
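A sketch of this events-based alternative (all names hypothetical): within a single aggregation call, the framework counts per-worker (event_id, is_success) pairs, reports an event only once every worker has responded, and reports it as failed if any worker failed:

```python
from collections import defaultdict


def aggregate_events(
    per_worker_events: list[list[tuple[int, bool]]],
    num_workers: int,
) -> list[tuple[int, bool]]:
    """Hypothetical framework-side TP aggregation: an event completes when
    all workers reported it; a single failure makes the aggregate a failure."""
    counts: dict[int, int] = defaultdict(int)
    ok: dict[int, bool] = defaultdict(lambda: True)
    for events in per_worker_events:
        for event_id, is_success in events:
            counts[event_id] += 1
            ok[event_id] = ok[event_id] and is_success
    # Only events seen by every worker are reported back to the scheduler.
    return [(eid, ok[eid]) for eid, n in counts.items() if n == num_workers]
```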

@sdavidbd your thoughts on this?

@sdavidbd (Contributor):

Thanks @orozery. I like this but am just trying to think through implications/alternatives, in the interests of keeping the interface as simple as possible and minimizing concerns on the connector impl side.
One idea is whether it would make sense to abstract this return flow of information in the form of events, i.e. a generalization of what we already have with the lists of finished request ids.
It may be that we can then still encapsulate the TP aggregation of these within the framework, since we would require a positive response from all workers. One or more negative results (failures) would translate to a negative result when aggregated.

@njhill IIUC (please correct me) your suggestion is as follows: Do not allow arbitrary metadata flow back from workers to scheduler (similarly to the unconstrained KVConnectorMetadata flowing from scheduler to workers). Instead, allow each worker to report only a list of completed events. Also, you do not want these "Events" to be opaque, but explicit. Perhaps something like: ModelRunnerOutput.finished_connector_events: list[tuple[int, bool]] #[(event_id, is_success)] This means it will be the scheduler connector role to define the events (including their IDs) and to pass them on to workers via KVConnectorMetadata.

@sdavidbd your thoughts on this?

I really like the idea of making the worker-side connector metadata explicit rather than opaque - especially since it's ultimately consumed by the framework. Given the choice between:

  1. An opaque metadata object that requires a connector-specific class, aggregation logic, and accessor methods for the framework to extract relevant information,
  2. A concrete metadata class with well-defined fields, known aggregation semantics, and direct framework access,

- I’d strongly prefer the latter.

Regarding aggregation, I think we can keep it simple and sufficient by following two principles:

  • Metadata fields should be typed as collections.
  • Aggregation should be either a union or intersection, depending on semantics.

For example:

  • Finished requests → intersection
  • Invalid blocks → union
  • KV events → union
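The aggregation semantics listed above can be sketched as (hypothetical field names; real metadata would be a dataclass rather than a dict):

```python
def aggregate(worker_metas: list[dict[str, set[str]]]) -> dict[str, set[str]]:
    """Aggregate per-worker metadata under the semantics above:
    finished requests by intersection (all workers must agree),
    invalid blocks by union (any worker can invalidate a block)."""
    return {
        "finished_reqs": set.intersection(
            *(m["finished_reqs"] for m in worker_metas)),
        "invalid_blocks": set.union(
            *(m["invalid_blocks"] for m in worker_metas)),
    }
```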

@orozery (Contributor Author) commented Jun 26, 2025

@sdavidbd I started implementing with ModelRunnerOutput.finished_connector_events: list[tuple[int, bool]] #[(event_id, is_success)], but it got complicated when trying to make NixlConnector adopt it.
I found myself somehow encoding stuff into the event_id, which seemed hacky to me.

So I did some rethinking and came up with a suggestion which is somewhere in the middle between opaque and explicit:
ModelRunnerOutput.kv_connector_worker_events: Optional[list["KVConnectorWorkerEvent"]]

This justifies why each worker actually returns a list (in the previous solution, it was not clear why each worker returns a singleton [worker_metadata]).
BTW this also answers your other suggestion on using a separate metadata struct (and not KVConnectorMetadata) for worker -> scheduler.

@orozery force-pushed the connector-metadata-worker-output branch 2 times, most recently from d94ea04 to ce105e7, on July 1, 2025 14:51
@njhill njhill merged commit cc876d0 into vllm-project:main Jul 10, 2025
79 checks passed
Comment on lines +253 to +274
finished_sending = set[str]()
finished_recving = set[str]()
for output in outputs:
    # update finished_sending
    for req_id in output.finished_sending or []:
        new_count = self._send_remaining_count[req_id] - 1
        if new_count == 0:
            # got response from all workers, report back to scheduler
            finished_sending.add(req_id)
            del self._send_remaining_count[req_id]
        else:
            self._send_remaining_count[req_id] = new_count

    # update finished_recving
    for req_id in output.finished_recving or []:
        new_count = self._recv_remaining_count[req_id] - 1
        if new_count == 0:
            # got response from all workers, report back to scheduler
            finished_recving.add(req_id)
            del self._recv_remaining_count[req_id]
        else:
            self._recv_remaining_count[req_id] = new_count
Contributor:

Suggested change (factor the duplicated count-down loops into a helper):

def update_finished_set(req_ids: list[str],
                        remaining_count_dict: dict[str, int],
                        finished_set: set[str]) -> None:
    for req_id in req_ids or []:
        new_count = remaining_count_dict[req_id] - 1
        if new_count == 0:
            finished_set.add(req_id)
            del remaining_count_dict[req_id]
        else:
            remaining_count_dict[req_id] = new_count

finished_sending = set[str]()
finished_recving = set[str]()
for output in outputs:
    update_finished_set(output.finished_sending,
                        self._send_remaining_count, finished_sending)
    update_finished_set(output.finished_recving,
                        self._recv_remaining_count, finished_recving)

@sdavidbd (Contributor):

Thanks @orozery, looks great; just one minor comment.

I like how much cleaner it is to do the finished handling in the worker rather than the model runner but it has the downside of the connector interaction no longer being encapsulated in the model runner. Not suggesting to revert this, just a thought/observation.

@njhill @orozery Also, I’m uncomfortable with this being done after the call to clear_connector_metadata in the model runner. In stateless connector implementations, the necessary information might already be gone by that point. In future work, if we introduce a unified interface for building worker-connector metadata to propagate backward, we could consider removing clear_connector_metadata entirely. Instead, we would delegate the responsibility for clearing the worker-connector state to the metadata build API itself, maintaining symmetry with build_connector_meta on the scheduler-connector side (see my earlier comments here and here).

@njhill (Member) commented Jul 10, 2025

Thanks @sdavidbd I agree about the clear_connector_metadata call happening last, and that there was a missing call to clear_connector_metadata even before this PR. Perhaps we could also move that call out into the worker, after the call to get_finished()?

I'm also unsure why we set it to an empty KVConnectorMetadata in base.py, probably better to change that to an abstract class and have _connector_metadata be Optional.

Would you be interested in making a follow-on PR to address these things?

@njhill (Member) commented Jul 10, 2025

@sdavidbd I opened a PR #20756 for this and included you as coauthor. Just want to get it merged quickly since there's another PR waiting to be rebased on this.

Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jul 15, 2025
@kouroshHakha (Collaborator) commented Jul 16, 2025

@orozery @njhill @sdavidbd why are we moving the logic out of the nixl connector? Since the logic is not replicated across other distributed backends, it breaks Ray's support for the kv-connector. Essentially, the logic has to be repeated inside the workers and executors, which doesn't make sense. It seemed to belong to the right component, which was the connector implementation itself.

@orozery (Contributor Author) commented Jul 16, 2025

@orozery @njhill @sdavidbd why are we moving the logic out of the nixl connector? Since the logic is not replicated across other distributed backends, it breaks Ray's support for the kv-connector. Essentially, the logic has to be repeated inside the workers and executors, which doesn't make sense. It seemed to belong to the right component, which was the connector implementation itself.

My motivation was to allow connectors to reuse this logic without having to re-implement it.
Also, to support more general logic that works for PP as well, not just TP, by utilizing the existing scheduler<->worker communication channels (IPC using shared memory in the case of MultiprocExecutor).
I was not aware of the ray + nixl setup. Perhaps we can add a test that will prevent a future break (the current nixl tests use MultiprocExecutor).
Regarding support for different workers: I don't see how this PR affects it. Today you need to integrate the connector API into each worker implementation (I believe the GPU worker is the only one that has it at the moment).

BTW I don't think it will be hard to fix the ray executor.
MultiprocExecutor._aggregate_workers_output can be moved to a util file that will be used by all executors.

@njhill (Member) commented Jul 16, 2025

Apologies, I also hadn't considered the ray executor implications. I agree with @orozery though that it should be straightforward to make a corresponding update to the ray executor, and also that we should cover that in the CI. I am out on vacation this week but can help with that next week if needed.

@kouroshHakha (Collaborator) commented Jul 16, 2025

My motivation was to allow connectors to re-use this logic without having to re-implement it.

I see; the reasoning for letting connectors reuse this logic makes sense. But does it make sense for it to be implemented in a base class / utility that connectors use, or should it belong to the executor? I'd rather have the executor logic agnostic to whether a connector is used. It's some sort of conceptual leakage.

I was not aware of the ray + nixl setup. Perhaps we can add a test that will prevent a future break (The current nixl tests use MultiprocExecutor).

Yeah, the test coverage on the nixl path is unfortunately still low. I just realized this now. We should add more tests (more importantly, the nixl dependency is not yet added to CI).

Regarding support for different workers: I don't see how this PR affects it. Today you need to integrate the connector API into each worker implementation (I believe the GPU worker is the only one that has it at the moment).

I think you are right. This might not need to change since ray creates a wrapper worker around these

BTW I don't think it will be hard to fix the ray executor.
MultiprocExecutor._aggregate_workers_output can be moved to a util file that will be used by all executors.

True, but I think this logic still belongs to connectors conceptually.

@kouroshHakha (Collaborator):

@njhill @orozery @robertgshaw2-redhat What do you think about the reusability of this logic for done-receiving and done-transferring? I think it should be completely abstracted from the executor and delegated completely to the connector implementations (reused across them).

@njhill (Member) commented Jul 17, 2025

My motivation was to allow connectors to re-use this logic without having to re-implement it.

I see; the reasoning for letting connectors reuse this logic makes sense. But does it make sense for it to be implemented in a base class / utility that connectors use, or should it belong to the executor? I'd rather have the executor logic agnostic to whether a connector is used. It's some sort of conceptual leakage.

I think that it makes sense to live in a utility class that executors use, since they are responsible for orchestrating the workers and for the multi-worker "abstraction". This could change later though if we have a more significant rethink of the connector interface, for example decoupling it from the workers altogether.

juncgu added a commit to juncgu/vllm that referenced this pull request Jul 21, 2025
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025