
Conversation

@houseroad (Collaborator) commented Jul 21, 2025

Summary: The original implementation is inefficient, since it creates a full copy of the original tensor. This change adds a Triton kernel to perform the comparison in place.

Test Plan

Launch the server with warm-up requests.

Before:
OOM

After:
No more OOM

Differential Revision: D78606604
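
For illustration, a minimal sketch of this kind of block-wise counting kernel (an assumption of the general approach, not the exact kernel in this PR; the kernel name, grid layout, block size, and atomic accumulation are all illustrative):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _count_ge_kernel(x_ptr, values_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    # One program per (row, block): compare a block of the row against the
    # row's threshold and atomically accumulate the count, so no full-size
    # boolean tensor is ever materialized in HBM.
    row = tl.program_id(0)
    blk = tl.program_id(1)
    offs = blk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols
    value = tl.load(values_ptr + row)
    x_block = tl.load(x_ptr + row * n_cols + offs, mask=mask, other=float("-inf"))
    local_count = tl.sum(((x_block >= value) & mask).to(tl.int32))
    tl.atomic_add(out_ptr + row, local_count)


def batched_count_ge(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, vocab_size) logprobs, values: (num_tokens, 1) token logprobs.
    # Assumes both tensors are contiguous.
    n_rows, n_cols = x.shape
    out = torch.zeros(n_rows, dtype=torch.int32, device=x.device)
    BLOCK_SIZE = 1024
    grid = (n_rows, triton.cdiv(n_cols, BLOCK_SIZE))
    _count_ge_kernel[grid](x, values, out, n_cols, BLOCK_SIZE=BLOCK_SIZE)
    return out
```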


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D78606604

@mergify mergify bot added the v1 label Jul 21, 2025
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a custom Triton kernel to optimize the calculation of token ranks from log probabilities, avoiding the creation of a large intermediate tensor. The change is well-motivated for performance.

My review identified a critical correctness issue in the Triton kernel where a strict 'greater than' comparison is used instead of 'greater than or equal', which will lead to incorrect rank calculations. I've also pointed out a type inconsistency between the Triton path and the PyTorch fallback path that should be addressed. Both issues include code suggestions for the fix.


# 6. Perform the comparison and sum the result within this block.
# `x_block > value` creates a boolean vector, and tl.sum treats True as 1.
local_count = tl.sum(x_block > value)

critical

The Triton kernel uses a strict greater than comparison (>), but the original implementation being replaced ((logprobs >= token_logprobs).sum(-1)) and the fallback implementation on line 107 use greater than or equal (>=). This will lead to incorrect token rank calculations, especially when there are ties in log probability values. The rank of a token should be the number of tokens with a log probability greater than or equal to its own, so >= is the correct operator.

Suggested change
local_count = tl.sum(x_block > value)
local_count = tl.sum(x_block >= value)

if HAS_TRITON:
return batched_count_greater_than_triton(x, values)
else:
return (x >= values).sum(-1)
high

The Triton kernel path returns a tensor of dtype=torch.int32 (as defined on line 69). For consistency, the PyTorch fallback path should also return a tensor of the same dtype. The .sum() operation on a boolean tensor defaults to int64, which could lead to unexpected type mismatches or performance issues downstream.

Suggested change
return (x >= values).sum(-1)
return (x >= values).sum(-1, dtype=torch.int32)


# 6. Perform the comparison and sum the result within this block.
# `x_block > value` creates a boolean vector, and tl.sum treats True as 1.
local_count = tl.sum(x_block > value)
IIUC, this PR is for HBM usage reduction?

  • Originally, x >= values would create a new bool tensor before counting the True values against it, so the HBM overhead would be x.nbytes.
  • In the new implementation, x_block > value still creates a boolean vector before accumulating with tl.sum. However, the blocks are not all executed at the exact same time, so the HBM usage becomes max_concurrent_block * x_block.nbytes.

Is my understanding accurate?

I kinda feel we have quite a lot of usage similar to (x >= y).sum(); should we introduce a new PyTorch API to optimize this in general?
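
To make the HBM argument concrete, here is a rough, hypothetical measurement one could run on a CUDA device (using torch.compile as the fused stand-in; not code from this PR):

```python
import torch


def peak_bytes(fn, *args):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()


num_tokens, vocab_size = 64, 128 * 1024
logprobs = torch.randn(num_tokens, vocab_size, device="cuda")
token_logprobs = logprobs[:, :1].clone()

# Eager path: materializes a (num_tokens, vocab_size) bool intermediate in HBM.
eager = lambda x, v: (x >= v).sum(-1)
# Fused path: compare + reduce in one kernel, so only the small output is allocated.
fused = torch.compile(eager)
fused(logprobs, token_logprobs)  # compile once up front

# The delta between the two peaks approximates the intermediate's HBM cost.
print("eager peak bytes:", peak_bytes(eager, logprobs, token_logprobs))
print("fused peak bytes:", peak_bytes(fused, logprobs, token_logprobs))
```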

@Jialin (Contributor) commented Jul 21, 2025

Maybe add a unit test to confirm correctness?
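
A minimal sketch of what such a test could look like, comparing a fused implementation against the eager reference and deliberately injecting ties to catch a > vs >= bug (the fused stand-in here is torch.compile of the reference; swap in the actual kernel under test):

```python
import torch


def reference_token_ranks(logprobs: torch.Tensor, token_logprobs: torch.Tensor) -> torch.Tensor:
    # Eager reference: rank = number of entries >= the chosen token's logprob.
    return (logprobs >= token_logprobs).sum(-1, dtype=torch.int32)


# Stand-in for the fused implementation under test; swap in the real
# Triton / compiled function from this PR here.
fused_token_ranks = torch.compile(reference_token_ranks)


def test_token_ranks_match_reference():
    torch.manual_seed(0)
    for num_tokens, vocab_size in [(1, 7), (8, 1000), (64, 4096 + 3)]:
        logprobs = torch.randn(num_tokens, vocab_size, device="cuda")
        token_ids = torch.randint(0, vocab_size, (num_tokens, 1), device="cuda")
        token_logprobs = logprobs.gather(-1, token_ids)
        # Inject ties so a '>' vs '>=' bug changes the result.
        logprobs[:, 0] = token_logprobs.squeeze(-1)
        expected = reference_token_ranks(logprobs, token_logprobs)
        actual = fused_token_ranks(logprobs, token_logprobs)
        torch.testing.assert_close(actual, expected)
```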

@@ -174,7 +176,7 @@ def gather_logprobs(
token_logprobs = logprobs.gather(-1, token_ids)

# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)

In general, I think the logits processor is quite inefficient, with a bunch of intermediates. I think adopting some compilation techniques, e.g. https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/logits_processor/compiler.py, would be good. I also think this is actually a good use case for torch.compile.

houseroad added a commit to houseroad/vllm that referenced this pull request Jul 21, 2025
Summary:

Original implementation is inefficient, since we will create a copy of the original tensor. So create a triton kernel to get this done in place.

Test Plan:
```
VLLM_GPU_MEMORY_UTILIZATION=0.8 \
VLLM_USE_FLASHINFER_SAMPLER=0 \
CUDA_VISIBLE_DEVICES=0 \
VLLM_USE_V1=1 \
SAFETENSORS_FAST_GPU=1 \
buck2 run mode/opt mode/inplace \
  -c fbcode.platform010_cuda_version=12.8 \
  -c fbcode.enable_vllm=true \
  -c fbcode.enable_gpu_sections=true \
  -c fbcode.nvcc_arch=h100a //smart/inference_platform_sp/llm_predictor_gpu:service -- \
  --max_seq_len=8192 \
  --max_batch_size=64 \
  --model_mf_bucket=fair_llms_prod \
  --model_mf_path=tree/894459467/0 \
  --allow_custom_stop_tokens \
  --vllm_engine --local_cache_dir=/tmp/llama4-planner \
  --try_local_cache \
  --thrift_server_port 12345 \
  --model_parallel_size=1 \
  --thrift_queue_timeout_ms=500 \
  --guided_decode \
  --token_logprobs
```

No more OOM

Rollback Plan:

Reviewed By: JialinOuyang-Meta

Differential Revision: D78606604
houseroad added a commit to houseroad/vllm that referenced this pull request Jul 21, 2025
@houseroad houseroad force-pushed the export-D78606604 branch 2 times, most recently from f0856b4 to 93cef4d Compare July 21, 2025 07:25
houseroad added a commit to houseroad/vllm that referenced this pull request Jul 21, 2025
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D78606604

@houseroad (Collaborator, Author)
Actually, I switched to torch.compile and there is no more OOM. Thanks to @yinghai for the suggestion.
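
For context, a minimal sketch of the torch.compile approach (illustrative only; the helper names are assumptions, not necessarily the code that landed):

```python
import torch


@torch.compile
def _count_token_ranks(logprobs: torch.Tensor, token_logprobs: torch.Tensor) -> torch.Tensor:
    # Inductor can fuse the comparison with the reduction, so the
    # (num_tokens, vocab_size) boolean intermediate never lands in HBM.
    return (logprobs >= token_logprobs).sum(dim=-1, dtype=torch.int32)


def gather_token_ranks(logprobs: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    token_logprobs = logprobs.gather(-1, token_ids)
    return _count_token_ranks(logprobs, token_logprobs)
```

Compilation happens lazily on the first call, which is what the JIT/warm-up discussion below is about.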

@houseroad houseroad added the ready label Jul 21, 2025
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D78606604

houseroad added a commit to houseroad/vllm that referenced this pull request Jul 21, 2025
@houseroad houseroad changed the title Inplace logprobs comparison to get the token rank [v1][sampler] Inplace logprobs comparison to get the token rank Jul 21, 2025
houseroad added a commit to houseroad/vllm that referenced this pull request Jul 21, 2025
@facebook-github-bot

@houseroad has imported this pull request. If you are a Meta employee, you can view this in D78606604.

@houseroad houseroad requested a review from zou3519 July 21, 2025 17:10
@WoosukKwon WoosukKwon merged commit 8d0a01a into vllm-project:main Jul 21, 2025
62 of 65 checks passed
@yeqcharlotte (Collaborator)

@houseroad @zou3519 will this be JIT-compiled and trigger compilation on the fly?

@houseroad (Collaborator, Author)

The overhead should be fine, since it's only a few ops; this is usually what torch.compile is really great at.

@zou3519 (Collaborator) commented Jul 22, 2025

This will trigger compilation on the fly. It should only compile once.

It is possible to enforce this via torch.compiler.set_stance("fail_on_recompiles") and/or vLLM's existing support_torch_compile mechanism. vLLM's existing mechanism might be too heavy for this, but in the long term we may want something that interacts better with the vLLM-compile cache.
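
A small sketch of how that guard could be applied (assuming a recent PyTorch where torch.compiler.set_stance is available; the stance string and helpers below are illustrative, not vLLM code):

```python
import torch


@torch.compile
def token_ranks(logprobs: torch.Tensor, token_logprobs: torch.Tensor) -> torch.Tensor:
    return (logprobs >= token_logprobs).sum(dim=-1, dtype=torch.int32)


def warm_up(logprobs: torch.Tensor, token_logprobs: torch.Tensor) -> None:
    # Trigger the one expected compilation during server warm-up.
    token_ranks(logprobs, token_logprobs)


def serving_step(logprobs: torch.Tensor, token_logprobs: torch.Tensor) -> torch.Tensor:
    # After warm-up, anything that would recompile raises loudly instead of
    # silently adding a latency spike mid-serving.
    with torch.compiler.set_stance("fail_on_recompile"):
        return token_ranks(logprobs, token_logprobs)
```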

@yeqcharlotte (Collaborator)
> This will trigger compilation on the fly. It should only compile once.
>
> It is possible to enforce this via torch.compiler.set_stance("fail_on_recompiles") and/or vLLM's existing support_torch_compile mechanism. vLLM's existing mechanism might be too heavy for this, but in the long term we may want something that interacts better with the vLLM-compile cache.

I see. Essentially we may need some proper warm-up to avoid weird latency spikes?

@yinghai (Contributor) commented Jul 22, 2025

I think you can do static-shape compilation for this, and I think vLLM does more than enough warmup runs lol. It specifically does a warmup with the max_request sampling path being triggered.
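
A minimal sketch of the static-shape idea (illustrative; whether a single shape suffices depends on how the caller pads or buckets batches, and max_num_reqs / vocab_size below are placeholders):

```python
import torch

# dynamic=False asks torch.compile to specialize on the exact shapes it sees,
# trading flexibility for simpler static-shape kernels.
count_ge = torch.compile(
    lambda x, v: (x >= v).sum(dim=-1, dtype=torch.int32),
    dynamic=False,
)

# Warm up on the shape(s) the sampler will actually see, mirroring the
# warm-up run described above, so compilation happens before serving traffic.
max_num_reqs, vocab_size = 64, 128 * 1024  # placeholder sizes
warmup_x = torch.zeros(max_num_reqs, vocab_size, device="cuda")
warmup_v = torch.zeros(max_num_reqs, 1, device="cuda")
count_ge(warmup_x, warmup_v)
```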

zixi-qi pushed a commit to zixi-qi/vllm that referenced this pull request Jul 23, 2025
…-project#21283)

Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: qizixi <qizixi@meta.com>
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
…-project#21283)

Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
wenscarl pushed a commit to wenscarl/vllm that referenced this pull request Aug 4, 2025
…-project#21283)

Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: shuw <shuw@nvidia.com>
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
…-project#21283)

Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: x22x22 <wadeking@qq.com>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
…-project#21283)

Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
…-project#21283)

Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
…-project#21283)

Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Labels: ready, v1

7 participants