[v1][sampler] Inplace logprobs comparison to get the token rank #21283
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger a full CI run by default. Instead, it would only run a small and essential subset of CI tests to quickly catch errors. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
This pull request was exported from Phabricator. Differential Revision: D78606604
Code Review
This pull request introduces a custom Triton kernel to optimize the calculation of token ranks from log probabilities, avoiding the creation of a large intermediate tensor. The change is well-motivated for performance.
My review identified a critical correctness issue in the Triton kernel where a strict 'greater than' comparison is used instead of 'greater than or equal', which will lead to incorrect rank calculations. I've also pointed out a type inconsistency between the Triton path and the PyTorch fallback path that should be addressed. Both issues include code suggestions for the fix.
vllm/v1/sample/ops/logprobs.py
Outdated
# 6. Perform the comparison and sum the result within this block.
# `x_block > value` creates a boolean vector, and tl.sum treats True as 1.
local_count = tl.sum(x_block > value)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Triton kernel uses a strict greater-than comparison (`>`), but the original implementation being replaced (`(logprobs >= token_logprobs).sum(-1)`) and the fallback implementation on line 107 use greater-than-or-equal (`>=`). This will lead to incorrect token rank calculations, especially when there are ties in log probability values. The rank of a token should be the number of tokens with a log probability greater than or equal to its own, so `>=` is the correct operator.
Suggested change:
-    local_count = tl.sum(x_block > value)
+    local_count = tl.sum(x_block >= value)
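To illustrate the difference (a standalone example for this review, not code from the PR), consider a row with ties: `>=` counts the sampled token itself plus any exact ties, while `>` undercounts.

```python
import torch

# Hypothetical values: the sampled token's logprob is -1.0 and it ties with
# one other vocab entry.
logprobs = torch.tensor([[0.0, -1.0, -1.0, -2.0]])
token_logprobs = torch.tensor([[-1.0]])

rank_ge = (logprobs >= token_logprobs).sum(-1)  # tensor([3]): 1-based rank, ties included
rank_gt = (logprobs > token_logprobs).sum(-1)   # tensor([1]): misses the token itself and its ties
```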
vllm/v1/sample/ops/logprobs.py
Outdated
if HAS_TRITON:
    return batched_count_greater_than_triton(x, values)
else:
    return (x >= values).sum(-1)
The Triton kernel path returns a tensor of `dtype=torch.int32` (as defined on line 69). For consistency, the PyTorch fallback path should also return a tensor of the same dtype. The `.sum()` operation on a boolean tensor defaults to `int64`, which could lead to unexpected type mismatches or performance issues downstream.
Suggested change:
-    return (x >= values).sum(-1)
+    return (x >= values).sum(-1, dtype=torch.int32)
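For reference (an illustration, not code from the PR), summing a boolean mask without an explicit dtype yields int64, so the explicit `dtype=torch.int32` is what keeps the fallback consistent with the Triton path:

```python
import torch

x = torch.randn(4, 8)
values = x.gather(-1, torch.randint(0, 8, (4, 1)))

print((x >= values).sum(-1).dtype)                     # torch.int64 (default for bool sums)
print((x >= values).sum(-1, dtype=torch.int32).dtype)  # torch.int32 (matches the Triton kernel)
```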
vllm/v1/sample/ops/logprobs.py
Outdated
# 6. Perform the comparison and sum the result within this block.
# `x_block > value` creates a boolean vector, and tl.sum treats True as 1.
local_count = tl.sum(x_block > value)
IIUC, this PR is for HBM usage reduction?
- Originally, `x >= values` would create a new bool tensor before counting True against it, so the HBM overhead would be `x.nbytes`.
- In the new implementation, `x_block > value` would still create a boolean vector before accumulating with `tl.sum`. However, blocks are not all executed at the exact same time, so the HBM usage would become `max_concurrent_block * x_block.nbytes`.
Is my understanding accurate?
I kinda feel we have quite a lot of usage similar to `(x >= y).sum()`; should we introduce a new PyTorch API to optimize this in general?
Maybe add a unit test to confirm the correctness?
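A minimal correctness test along the lines suggested above might look like the sketch below; `batched_count_greater_than` is a stand-in for whichever optimized implementation the PR ends up with (Triton kernel or compiled op), assumed to take `(x, values)` and return per-row counts.

```python
import torch

def batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # Placeholder for the optimized implementation under test.
    return (x >= values).sum(-1, dtype=torch.int32)

def test_batched_count_matches_reference():
    torch.manual_seed(0)
    x = torch.randn(16, 1024)
    token_ids = torch.randint(0, 1024, (16, 1))
    # Gather the per-row reference value from x itself so exact ties are exercised.
    values = x.gather(-1, token_ids)

    expected = (x >= values).sum(-1)
    actual = batched_count_greater_than(x, values)
    assert torch.equal(actual.to(expected.dtype), expected)
```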
@@ -174,7 +176,7 @@ def gather_logprobs(
    token_logprobs = logprobs.gather(-1, token_ids)

    # Compute the ranks of the actual token.
    token_ranks = (logprobs >= token_logprobs).sum(-1)
In general I think the logit processor is quite inefficient, with a bunch of intermediates. I think adopting some compilation techniques, e.g. https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/logits_processor/compiler.py, would be good. I also think this is actually a good use case for torch.compile.
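For what it's worth, a minimal torch.compile version of the rank computation (a sketch under the assumption that Inductor fuses the comparison with the reduction, not the exact change merged here) could look like:

```python
import torch

@torch.compile(dynamic=True)
def compute_token_ranks(logprobs: torch.Tensor, token_logprobs: torch.Tensor) -> torch.Tensor:
    # When fused into a single kernel, the full boolean mask is never
    # materialized as a separate HBM tensor.
    return (logprobs >= token_logprobs).sum(-1)
```

Compilation happens on the first call and is reused afterwards, subject to recompiles when guards change.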
Summary: Original implementation is inefficient, since we will create a copy of the original tensor. So create a triton kernel to get this done in place.

Test Plan:
```
VLLM_GPU_MEMORY_UTILIZATION=0.8 \
VLLM_USE_FLASHINFER_SAMPLER=0 \
CUDA_VISIBLE_DEVICES=0 \
VLLM_USE_V1=1 \
SAFETENSORS_FAST_GPU=1 \
buck2 run mode/opt mode/inplace \
  -c fbcode.platform010_cuda_version=12.8 \
  -c fbcode.enable_vllm=true \
  -c fbcode.enable_gpu_sections=true \
  -c fbcode.nvcc_arch=h100a //smart/inference_platform_sp/llm_predictor_gpu:service -- \
  --max_seq_len=8192 \
  --max_batch_size=64 \
  --model_mf_bucket=fair_llms_prod \
  --model_mf_path=tree/894459467/0 \
  --allow_custom_stop_tokens \
  --vllm_engine --local_cache_dir=/tmp/llama4-planner \
  --try_local_cache \
  --thrift_server_port 12345 \
  --model_parallel_size=1 \
  --thrift_queue_timeout_ms=500 \
  --guided_decode \
  --token_logprobs
```
No more OOM.

Rollback Plan:

Reviewed By: JialinOuyang-Meta

Differential Revision: D78606604
Force-pushed from f0856b4 to 93cef4d.
This pull request was exported from Phabricator. Differential Revision: D78606604
Actually, switched to torch.compile; no more OOM. Thanks to @yinghai for the suggestion.
This pull request was exported from Phabricator. Differential Revision: D78606604
Force-pushed from 93cef4d to e2942a3.
Force-pushed from e2942a3 to e509cd6.
Force-pushed from e509cd6 to 8003dc7.
@houseroad has imported this pull request. If you are a Meta employee, you can view this in D78606604.
@houseroad @zou3519 will this be JIT-compiled and trigger compilation on the fly?
The overhead should be fine, since it's only a few ops; this is usually what torch.compile is really great at.
This will trigger compilation on the fly. It should only compile once. It is possible to enforce this via torch.compiler.set_stance("fail_on_recompiles") and/or vLLM's existing support_torch_compile mechanism. vLLM's existing mechanism might be too heavy for this, but in the long term we may want something that interacts better with the vLLM-compile cache.
I see. Essentially we may need some proper warm-up to avoid weird latency spikes?
I think you can do static-shape compilation for this, and I think vLLM does more than enough warmup runs lol. It specifically does a warmup with the max_request sampling path being triggered.
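A sketch of that warmup idea (the names `compute_token_ranks` and `warmup_token_ranks` and the shapes are illustrative, not vLLM's actual warmup hook): compile once and exercise the op at the serving shape during startup so live requests never pay the compile cost.

```python
import torch

@torch.compile(dynamic=False)
def compute_token_ranks(logprobs: torch.Tensor, token_logprobs: torch.Tensor) -> torch.Tensor:
    return (logprobs >= token_logprobs).sum(-1)

def warmup_token_ranks(max_num_reqs: int, vocab_size: int, device: torch.device) -> None:
    # One call at the serving shape triggers compilation ahead of time; with
    # dynamic=False, a different shape at runtime would recompile, so this
    # assumes the sampler always sees the padded max-batch shape.
    logprobs = torch.randn(max_num_reqs, vocab_size, device=device)
    token_logprobs = logprobs[:, :1].clone()
    compute_token_ranks(logprobs, token_logprobs)
```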
Summary: The original implementation is inefficient, since it creates a copy of the original tensor. So create a Triton kernel to get this done in place.
Test Plan
Launch the server with warm-up requests.
Before: OOM
After: No more OOM
Differential Revision: D78606604