
Conversation

vadiklyutiy
Contributor

@vadiklyutiy vadiklyutiy commented May 3, 2025

Speed up of _prepare_inputs for models with M-RoPE by caching constant tensors in the M-RoPE implementation.

A huge part of _prepare_inputs for Qwen2.5-VL-3B is M-RoPE work: 1/3 of the time is spent creating small constant CPU tensors in MRotaryEmbedding.get_next_input_positions_tensor(). This PR adds a cache for these tensors to speed this up.
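A minimal sketch of the idea (hedged: this mirrors the cached helper and the `maxsize=1024` from the diff further down, not the exact vLLM code; the integer arguments are hashable, so `functools.lru_cache` applies directly):

```python
import functools

import torch

@functools.lru_cache(maxsize=1024)
def get_next_input_positions_tensor(mrope_position_delta: int,
                                    context_len: int,
                                    seq_len: int) -> torch.Tensor:
    # During decode this is typically a tiny 3 x 1 constant tensor,
    # e.g. [[32], [32], [32]]; callers must treat the cached result
    # as read-only, since repeated calls return the same object.
    return torch.arange(mrope_position_delta + context_len,
                        mrope_position_delta + seq_len).expand(3, -1)

a = get_next_input_positions_tensor(12, 20, 21)
b = get_next_input_positions_tensor(12, 20, 21)
print(a is b)  # True: the second call is a cache hit
```

The cache hit avoids allocating a fresh CPU tensor on every decode step, which is exactly the cost the PR description measures.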

Performance results.
Tested with Qwen2.5-VL-3B

  • _prepare_inputs itself sped up by 35%

  • E2E

vllm serve Qwen/Qwen2.5-VL-3B-Instruct --disable-log-requests --max-num-seqs 1024  --block-size 16 --max-num-batched-tokens 2048

For the workload I used https://github.com/CentML/flexible-inference-bench
The command below sends 1000 requests at 50 requests per second, with one 512x512 image per request.

fib benchmark -rps 50 --input-token-distribution uniform 250 300     --output-token-distribution uniform 150 250 --num-of-imgs-per-req 1     --img-ratios-per-req 512x512 -n 1000 --base-url http://localhost:8000     --endpoint v1/chat/completions --backend openai-chat

Before: 31.93 reqs/s
After: 32.54 reqs/s
Speedup: 1.9%

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>

github-actions bot commented May 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Member

@ywang96 ywang96 left a comment


Thanks for the PR! The change makes sense to me and I left a comment - cc @imkero

@@ -1404,6 +1405,7 @@ def get_next_input_positions(
]

@staticmethod
@functools.lru_cache(maxsize=1024)
Member


Is there a particular reason why we set this to 1024?

Contributor

@imkero imkero May 5, 2025


Is there a particular reason why we set this to 1024?

I guess this is expected to be large enough to cover as many output mrope positions as possible.

In fact we are caching a lot of 3x1 tensors like [[32], [32], [32]] and [[33], [33], [33]].

I will take a closer look at whether we can do this in a better way.

@imkero
Contributor

imkero commented May 5, 2025

Thanks for the PR! The change makes sense to me and I left a comment - cc @imkero

@ywang96 LGTM for those changes.

And could you maybe take a look at #16881? It optimizes the remaining part (new requests) of M-RoPE position preparation, and I will take get_next_input_positions_tensor into consideration there as well, since it is reported to be time-costing.

Thanks @vadiklyutiy for reporting this with a benchmark.

@imkero
Contributor

imkero commented May 5, 2025

I feel that these changes should not have much effect on the E2E time cost, as they cost <0.1 ms per op... even x1000 that would bring only a 0.1 s difference.

To make the benchmark reproducible across different runs, I suggest setting --input-token-distribution uniform 200 200 and --output-token-distribution uniform 200 200, and also passing image data URLs instead of having the server download images via http URLs.

(Those settings require some modifications to fib.)
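For the data-URL suggestion above, a hedged sketch (the helper name `to_data_url` is hypothetical, not part of fib; this is just the standard way to inline image bytes so the server does not have to fetch an http URL):

```python
import base64

def to_data_url(path: str, mime: str = "image/png") -> str:
    # Encode the image bytes as a base64 data URL, suitable for the
    # image_url field of an OpenAI-style chat completion request.
    with open(path, "rb") as f:
        payload = base64.b64encode(f.read()).decode("ascii")
    return f"data:{mime};base64,{payload}"
```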

@vadiklyutiy
Contributor Author

vadiklyutiy commented May 5, 2025

I feel that these changes should not have much effect on the E2E time cost, as they cost <0.1 ms per op... even x1000 that would bring only a 0.1 s difference.

ms or us? 0.1 ms * 1000 = 100s

@imkero
Contributor

imkero commented May 5, 2025

I feel that these changes should not have much effect on the E2E time cost, as they cost <0.1 ms per op... even x1000 that would bring only a 0.1 s difference.

ms or us? 0.1 ms * 1000 = 100s

Sorry, I meant 0.1 ms * 1000 = 0.1 s in total; maybe I'm still missing something...

I am working on a more precise & reproducible benchmark with those suggestions applied.

@vadiklyutiy
Contributor Author

and also pass image data URLs instead of downloading image via http URLs

Behaviour with the same image for 1000 requests differs a bit from behaviour with different images for 1000 requests, so different images were implemented intentionally in the diff.
For the performance results above I made 21 runs each before and after and took the median. I also checked with nsys: _prepare_inputs takes around 5% before and around 3% after, which corresponds well to the E2E results.

@imkero
Contributor

imkero commented May 5, 2025

Alright, I forgot to multiply by the number of output tokens.

Just benchmarking get_next_input_positions_tensor piecewise, 200,000 calls take nearly 1 second, so it does make sense, I think.

@vadiklyutiy
Contributor Author

I suggest setting --input-token-distribution uniform 200 200 and --output-token-distribution uniform 200 200

Agreed, this makes the results more stable.

@vadiklyutiy
Contributor Author

vadiklyutiy commented May 5, 2025

And maybe please take a look on #16881 ? It optimize the rest part (new request) of MRoPE position preparation, and i will take get_next_input_positions_tensor into consideration because they are reported to be time-costing as well.

Thanks @vadiklyutiy for reporting this with benchmark

If numpy+numba speeds up get_next_input_positions_tensor by >10x, we can definitely waive the caching. But my intuition doesn't give any estimate of what speedup we will get :)
If you can give some numbers for the speedup of your implementation of get_next_input_positions_tensor, that would be great.

@imkero
Contributor

imkero commented May 5, 2025

And maybe please take a look on #16881 ? It optimize the rest part (new request) of MRoPE position preparation, and i will take get_next_input_positions_tensor into consideration because they are reported to be time-costing as well.
Thanks @vadiklyutiy for reporting this with benchmark

If numpy+numba speeds up get_next_input_positions_tensor by >10x, we can definitely waive the caching. But my intuition doesn't give any estimate of what speedup we will get :) If you can give some numbers for the speedup of your implementation of get_next_input_positions_tensor, that would be great.

A simple piecewise benchmark script:

import timeit
import torch
import numpy as np
import numba
from functools import lru_cache

def mrope_get_next_input_positions_tensor(
    mrope_position_delta: int,
    context_len: int,
    seq_len: int,
) -> torch.Tensor:
    return torch.arange(
        mrope_position_delta + context_len,
        mrope_position_delta + seq_len,
    ).expand(3, -1)

@lru_cache(1024)
def mrope_get_next_input_positions_tensor_lru(
    mrope_position_delta: int,
    context_len: int,
    seq_len: int,
) -> torch.Tensor:
    return torch.arange(
        mrope_position_delta + context_len,
        mrope_position_delta + seq_len,
    ).expand(3, -1)

@numba.jit(nopython=True)
def mrope_assign_next_input_positions(
    out: np.ndarray,
    out_offset: int,
    mrope_position_delta: int,
    context_len: int,
    num_new_tokens: int,
):
    for dim in range(3):
        for idx in range(num_new_tokens):
            out[dim,
                out_offset + idx] = mrope_position_delta + context_len + idx


out = torch.empty(3, 1000, dtype=torch.int64)
out_np = out.numpy() # they share the underlying data

def run_torch():
    out_offset = 5
    mrope_position_delta = 100
    context_len = 20
    seq_len = 21

    positions = mrope_get_next_input_positions_tensor(mrope_position_delta, context_len, seq_len)
    out[:, out_offset:out_offset + (seq_len - context_len)] = positions

def run_torch_lru():
    out_offset = 5
    mrope_position_delta = 100
    context_len = 20
    seq_len = 21

    positions = mrope_get_next_input_positions_tensor_lru(mrope_position_delta, context_len, seq_len)
    out[:, out_offset:out_offset + (seq_len - context_len)] = positions

def run_np():
    out_offset = 5
    mrope_position_delta = 100
    context_len = 20
    seq_len = 21
    
    mrope_assign_next_input_positions(out_np, out_offset, mrope_position_delta, context_len, seq_len - context_len)

run_torch()
run_torch_lru()
run_np()

r1 = timeit.timeit(run_torch, number=200000)
print(f"run_torch: {r1:.3f}s")

r2 = timeit.timeit(run_np, number=200000)
print(f"run_numba: {r2:.3f}s")

r3 = timeit.timeit(run_torch_lru, number=200000)
print(f"run_torch_lru: {r3:.3f}s")

And the result is here:

run_torch: 1.990s
run_numba: 0.066s
run_torch_lru: 0.872s
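As a hedged aside on the maxsize=1024 question: functools.lru_cache exposes hit/miss counters via cache_info(), which is a quick way to check whether repeated decode-style calls actually reuse cached entries. A stand-in sketch (plain tuples instead of tensors; names illustrative):

```python
from functools import lru_cache

@lru_cache(maxsize=1024)
def positions(delta: int, context_len: int, seq_len: int) -> tuple:
    # Plain-tuple stand-in for the cached 3 x 1 position tensor.
    return tuple(range(delta + context_len, delta + seq_len))

for step in range(5):
    positions(100, 20 + step, 21 + step)  # distinct args: all misses
for step in range(5):
    positions(100, 20 + step, 21 + step)  # repeated args: all hits

info = positions.cache_info()
print(info.hits, info.misses)  # 5 5
```

A low hit rate here would suggest the cache (and its maxsize) is not paying for itself.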

@imkero
Contributor

imkero commented May 5, 2025

I have confirmed some speedup with this PR's changes

And it seems consistent with the piecewise benchmark result above (torch vs torch lru cached)

Tested on NVIDIA A10

vllm serve Qwen/Qwen2.5-VL-3B-Instruct --disable-log-requests --max-num-seqs 128 --block-size 16 --max-num-batched-tokens 8192 --max-model-len 8192 --no-enable-prefix-caching
fib benchmark -rps 50 --input-token-distribution uniform 200 200 --output-token-distribution uniform 200 200 --num-of-imgs-per-req 1 --img-ratios-per-req 512x512 -n 1024 --base-url http://localhost:8000 --endpoint v1/chat/completions --backend openai-chat
| code base | duration (s) | req/s | input tokens/s | output tokens/s |
| --- | --- | --- | --- | --- |
| main branch | 172.86 | 5.92 | 3231.29 | 1184.74 |
| PR #17617 | 171.74 | 5.96 | 3252.25 | 1192.53 |

@vadiklyutiy
Contributor Author

I have confirmed some speedup with this PR's changes

Tested on NVIDIA A10

vllm serve Qwen/Qwen2.5-VL-3B-Instruct --disable-log-requests --max-num-seqs 128 --block-size 16 --max-num-batched-tokens 8192 --max-model-len 8192 --no-enable-prefix-caching
fib benchmark -rps 50 --input-token-distribution uniform 200 200 --output-token-distribution uniform 200 200 --num-of-imgs-per-req 1 --img-ratios-per-req 512x512 -n 1024 --base-url http://localhost:8000 --endpoint v1/chat/completions --backend openai-chat

| code base | duration (s) | req/s | input tokens/s | output tokens/s |
| --- | --- | --- | --- | --- |
| main branch | 172.86 | 5.92 | 3231.29 | 1184.74 |
| PR #17617 | 171.74 | 5.96 | 3252.25 | 1192.53 |

I tested with an H100; reqs/s on the H100 is around 30, so on the A10 the model itself takes longer. _prepare_inputs, including get_next_input_positions_tensor, is a CPU part, and its time doesn't depend on which GPU we have. So the smaller improvement (in %) on the A10 fits the theory.

@imkero
Contributor

imkero commented May 5, 2025

I have confirmed some speedup with this PR's changes

Tested on NVIDIA A10

vllm serve Qwen/Qwen2.5-VL-3B-Instruct --disable-log-requests --max-num-seqs 128 --block-size 16 --max-num-batched-tokens 8192 --max-model-len 8192 --no-enable-prefix-caching
fib benchmark -rps 50 --input-token-distribution uniform 200 200 --output-token-distribution uniform 200 200 --num-of-imgs-per-req 1 --img-ratios-per-req 512x512 -n 1024 --base-url http://localhost:8000 --endpoint v1/chat/completions --backend openai-chat

| code base | duration (s) | req/s | input tokens/s | output tokens/s |
| --- | --- | --- | --- | --- |
| main branch | 172.86 | 5.92 | 3231.29 | 1184.74 |
| PR #17617 | 171.74 | 5.96 | 3252.25 | 1192.53 |

I tested with an H100; reqs/s on the H100 is around 30, so on the A10 the model itself takes longer. _prepare_inputs, including get_next_input_positions_tensor, is a CPU part, and its time doesn't depend on which GPU we have. So the smaller improvement (in %) on the A10 fits the theory.

Yes, I think for CPU overhead focusing on the duration column makes sense (-1.1 seconds).

@vadiklyutiy
Contributor Author

And the result is here:

run_torch: 1.990s
run_numba: 0.066s
run_torch_lru: 0.872s

Results look awesome! @imkero thank you for adding get_next_input_positions_tensor to #16881. Waiting for #16881 to be merged.

@vadiklyutiy vadiklyutiy marked this pull request as draft May 5, 2025 11:29
@imkero
Contributor

imkero commented May 5, 2025

Results look awesome! @imkero thank you for adding get_next_input_positions_tensor to #16881. Waiting for #16881 to be merged.

Many thanks for pointing out the time consumption of get_next_input_positions_tensor.

@mergify mergify bot added the qwen Related to Qwen models label Jun 19, 2025

mergify bot commented Jun 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vadiklyutiy.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 25, 2025