
Conversation

NickLucche
Contributor

@NickLucche NickLucche commented May 28, 2025

Yet another take at #18079. It builds on the same commits, so the rationale of the PR is the same.

The issue with the previous approach is that using a higher number of descriptors per read (block_size times as many, each region block_size times smaller, so the total number of bytes moved is unchanged) appears to cause significant slowdowns. Note that this does not happen for homogeneous TP, where memory regions are seemingly merged by NIXL prior to transfer.

Changing NIXL+UCX versions has a noticeable effect on performance, so we could in principle tackle the above directly at the transport layer.
Instead, here we take a different approach that factors out the transport layer altogether: we instantiate the KV cache with a memory layout of [2, num_blocks, kv_heads, block_size, head_dim]. We then permute back to the original NHD layout to provide a view that guarantees correctness in the rest of the codebase.

This enables the splitting to be carried out on dim 2, leading to much better performance, as we maintain the same number of descriptors as well as bytes per read (minus a factor of tp_ratio). The code is also somewhat easier to read, with one less nested dim to account for.
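
To make the layout trick concrete, here is a minimal PyTorch sketch; the shapes, the tp_ratio value, and the variable names are illustrative assumptions, not the actual vLLM allocation code:

import torch

num_blocks, kv_heads, block_size, head_dim = 128, 8, 16, 64
tp_ratio = 2  # D_TP // P_TP

# Physical allocation in HND order: [2, num_blocks, kv_heads, block_size, head_dim]
kv_cache_hnd = torch.zeros(2, num_blocks, kv_heads, block_size, head_dim)

# Permuted NHD view handed to the rest of the codebase (no copy, same storage):
# [2, num_blocks, block_size, kv_heads, head_dim]
kv_cache_nhd_view = kv_cache_hnd.permute(0, 1, 3, 2, 4)

# The heterogeneous-TP split happens on dim 2 of the physical layout (kv_heads),
# so each D rank reads one contiguous chunk of heads per block:
heads_per_rank = kv_heads // tp_ratio
rank_in_group = 1  # illustrative position of this D rank within its tp_ratio group
start = rank_in_group * heads_per_rank
local_heads = kv_cache_hnd[:, :, start:start + heads_per_rank]

Because each per-block chunk above is contiguous in memory, the descriptor count stays the same and only the bytes per read shrink by a factor of tp_ratio.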

This PR requires #18775 to be merged first, as we need the proper scaffolding code for enforcing a different KV cache layout.

Here are some numbers:
[benchmark results image]

This has been tested on NIXL 0.2.1 (4f37f07) and UCX 1.18.0.


In the MLA case, most of the splitting complexity above is not needed: KV caches are replicated, so they can just be copied over in their entirety, just like homogeneous TP.
Code-wise the changes are minimal, as we're using the same logic for discovery as well as rank assignment, so we can conveniently support both cases.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label May 28, 2025

mergify bot commented May 29, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 29, 2025
@NickLucche NickLucche force-pushed the heterogenous-tp-permutekv branch from 7766cc4 to 39118d2 on May 30, 2025 10:30
@mergify mergify bot removed the needs-rebase label May 30, 2025
@NickLucche
Contributor Author

Also tested on MLA with deepseek-vl2-small

@tlrmchlsmth
Collaborator

Also tested on MLA with deepseek-vl2-small

@NickLucche it looks like that's not an MLA model fortunately.

We look for kv_lora_rank to see if the model uses MLA:

return self.hf_text_config.kv_lora_rank is not None

And there is no kv_lora_rank in that model's config:
https://huggingface.co/deepseek-ai/deepseek-vl2-small/blob/main/config.json

Could you try on deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct?
Or RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 if you want a smaller model.

Comment on lines 96 to 99
@functools.lru_cache
def get_kv_connector_cache_layout():
vllm_config = get_current_vllm_config()
Collaborator


Will the @functools.lru_cache break things if someone creates two LLMEngines? (Maybe only when using the UniProcExecutor?)
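
(For illustration, a toy sketch of the concern: an lru_cache on a function that reads whatever config is "current" at the first call will keep returning that value even if a second engine later changes it. The dict below is a stand-in, not vLLM's get_current_vllm_config:)

import functools

_current_config = {"kv_cache_layout": "NHD"}  # stand-in for the "current" vLLM config

@functools.lru_cache
def get_kv_connector_cache_layout():
    # Whatever config is current at the first call gets baked into the cache.
    return _current_config["kv_cache_layout"]

print(get_kv_connector_cache_layout())      # NHD
_current_config["kv_cache_layout"] = "HND"  # e.g. a second LLMEngine configured differently
print(get_kv_connector_cache_layout())      # still NHD: the cached value is returned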

Contributor Author


Not sure; this should be a no-op for all cases except PD with NIXL, and in that case every instance must have the same KV shape and layout to transfer anyway.
@njhill do you see how I could break things here?

Member


I don't think this needs to be cached, it should only be called during initialization anyhow.

Contributor Author


I actually forgot to mention: I was anticipating a potential runtime use, as in v0: https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/flashinfer.py#L1033.
I can still remove it since this is not the case in v1 right now, but we should consider it.

Collaborator


Let's remove the @functools.lru_cache. I'm strongly suspicious of some edge cases where this could break and there's no benefit to caching here

Collaborator

@tlrmchlsmth tlrmchlsmth left a comment


It looks like this PR only works with attention backends that can use the HND layout (i.e. only FlashInfer and FlashAttention).

This is OK for this PR but please make sure we're raising an exception if the wrong attn backend is used.
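
A minimal sketch of the kind of guard meant here; the helper name and backend strings are hypothetical stand-ins, not vLLM's actual API:

SUPPORTED_HND_BACKENDS = {"FLASHINFER", "FLASH_ATTN"}  # hypothetical identifiers

def validate_attn_backend_for_hnd(backend_name: str) -> None:
    # Heterogeneous-TP KV transfer relies on the HND cache layout, which only
    # some attention backends support; fail fast if another backend is selected.
    if backend_name.upper() not in SUPPORTED_HND_BACKENDS:
        raise ValueError(
            f"Attention backend {backend_name!r} does not support the HND KV cache "
            "layout required for heterogeneous TP with the NIXL connector.")

validate_attn_backend_for_hnd("FLASHINFER")   # passes
# validate_attn_backend_for_hnd("TORCH_SDPA") # would raise ValueError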


mergify bot commented Jun 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 2, 2025
@NickLucche
Contributor Author

Thanks a lot for reviewing!

it looks like that's not an MLA model fortunately.

Hmm, I think it is, and provided we start it with hf_overrides we seem to detect it just fine:

 vllm serve deepseek-ai/deepseek-vl2-small --trust_remote_code --hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'
. . .
INFO 06-02 08:08:29 [cuda.py:153] Forcing kv cache block size to 64 for FlashMLA backend.
INFO 06-02 08:08:39 [cuda.py:192] Using FlashMLA backend on V1 engine.

Anyway, for the sake of completeness I also tested with DeepSeek-Coder-V2-Lite-Instruct, getting Measured value: 0.7619408642911296.
Basically there's no KV permutation involved with MLA, and no rank splitting due to replication.

@NickLucche NickLucche force-pushed the heterogenous-tp-permutekv branch from 378ec58 to 72a9da3 on June 2, 2025 09:16
@mergify mergify bot removed the needs-rebase label Jun 2, 2025
Member

@njhill njhill left a comment


Great work thanks @NickLucche



mergify bot commented Jun 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 3, 2025
@NickLucche NickLucche force-pushed the heterogenous-tp-permutekv branch from cc8afb3 to 51a6309 on June 4, 2025 07:32
@mergify mergify bot removed the needs-rebase label Jun 4, 2025
address race condition

optimize req state checking loop

release_xfer_handle on status DONE

send notif to agent_name

fix req abort

Signed-off-by: nicklucche <nlucches@redhat.com>
docs

Signed-off-by: nicklucche <nlucches@redhat.com>
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 4, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) June 4, 2025 14:43
@njhill njhill disabled auto-merge June 4, 2025 16:52
Member

@njhill njhill left a comment


Awesome work thanks @NickLucche.

Most of my comments are minor - I don't think any of my comments necessarily need to hold up getting this merged.

Comment on lines +364 to +365
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
Member


@NickLucche probably a stupid question, but we only support D_TP >= P_TP, specifically D_TP = N * P_TP, right? We can't have a larger P size than D size.

Contributor Author


This is a simplification I carried over from the Dynamo work.
Basically, it just makes sense given the framing of the problem: D is memory-bound, so a greater TP size will yield better performance.

In theory one could support both, but the code gets messier because you have to discern between a single D reading from N prefill workers or N Ds reading from a single P (as in this case). The sync code also gets less clean.
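
An illustrative sketch of the mapping this constraint implies; the numbers and names are just for demonstration, not the connector code:

d_tp, p_tp = 8, 2
assert d_tp % p_tp == 0, "only D_TP = N * P_TP is supported"
tp_ratio = d_tp // p_tp  # = N

for d_rank in range(d_tp):
    p_rank = d_rank // tp_ratio    # the single remote P worker this D rank pulls from
    head_slot = d_rank % tp_ratio  # which 1/tp_ratio slice of that worker's KV heads it reads
    print(f"D rank {d_rank} <- P rank {p_rank}, head slice {head_slot}")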

Comment on lines +672 to +677
remote_block_size = nixl_agent_meta.block_len / (
self.slot_size_bytes)
assert self.block_len == nixl_agent_meta.block_len
else:
remote_block_size = nixl_agent_meta.block_len / (
self.slot_size_bytes * tp_ratio)
Member


Should these be // rather than /?

Contributor Author


I am asserting equality, so this should be an exact division. It's something like A = 2ABC / 2BC.

Member

@njhill njhill Jun 4, 2025


OK sure... I was suggesting it more because this is integer division rather than float division; here remote_block_size will actually be a float:

>>> type(8 / 2)
<class 'float'>
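
(Illustrative numbers for the type difference; the value is the same whenever the division is exact:)

block_len, slot_size_bytes, tp_ratio = 8192, 512, 2
print(block_len / (slot_size_bytes * tp_ratio))   # 8.0, a float
print(block_len // (slot_size_bytes * tp_ratio))  # 8, an int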

Signed-off-by: nicklucche <nlucches@redhat.com>
Member

@njhill njhill left a comment


Thanks again @NickLucche!

@njhill njhill enabled auto-merge (squash) June 4, 2025 21:11
@njhill njhill merged commit b2fac67 into vllm-project:main Jun 4, 2025
74 checks passed
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Signed-off-by: nicklucche <nlucches@redhat.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
@lhtin
Contributor

lhtin commented Jul 3, 2025

@NickLucche Do you know whether homogeneous and heterogeneous TP are supported in multinode TP scenarios? Based on the code I've reviewed, there's currently only a single remote_host and remote_port configuration, which suggests that multinode TP in NIXL PD is not supported.

@NickLucche
Contributor Author

The remote_host/port pair is forwarded from the proxy/sidecar server, which is aware of the deployment layout. So yes, this is intended for multi-node use.

@lhtin
Contributor

lhtin commented Jul 4, 2025

@NickLucche Could you provide an example for this part? From the code snippet I see below, it appears that the Decode node only receives a single remote_host and remote_port from kv_transfer_params, not an array. However, in a multi-node scenario (e.g., TP16 with ranks 0-7 on one node and ranks 8-15 on another), wouldn’t we need at least two remote_host entries?

self.requests[request_id] = ReqMeta(
    local_block_ids=local_block_ids,
    remote_block_ids=kv_transfer_params["remote_block_ids"],
    remote_engine_id=kv_transfer_params["remote_engine_id"],
    remote_host=kv_transfer_params["remote_host"],
    remote_port=kv_transfer_params["remote_port"],
    # P workers don't need to receive tp_size from proxy here.
    tp_size=kv_transfer_params.get("tp_size", 1),
)

leoli1208 pushed a commit to leoli1208/vllm that referenced this pull request Jul 22, 2025
Signed-off-by: nicklucche <nlucches@redhat.com>
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
Signed-off-by: nicklucche <nlucches@redhat.com>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Signed-off-by: nicklucche <nlucches@redhat.com>