
Conversation

yewentao256
Collaborator

@yewentao256 yewentao256 commented Jun 5, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results

Motivation

The scaled_int8_quant kernel family still relied on scalar loads/stores, leaving significant throughput untapped. Issue #18866 highlighted vectorization as the next major opportunity. Thanks @mgoin for the great issue!

Thanks also to @ztang2370 for the good start in #19062, and to @mgoin for #19109, the great incremental compilation document.

What’s in this PR

  1. Kernel vectorization (VEC_SIZE = 16)
    • static_scaled_int8_quant_kernel
    • static_scaled_int8_azp_quant_kernel
    • dynamic_scaled_int8_quant_kernel
    • dynamic_scaled_int8_azp_quant_kernel
    • The vectorized load/store logic is abstracted into csrc/quantization/vectorization_utils.cuh so it can be reused by other kernels later (see the sketch after this list).
  2. Zero-point refactor
    • Merges the min/max scan with the quantization loop to cut one global-memory pass.
  3. Benchmark tooling
    • Adds benchmarks/kernels/bench_int8_gemm.py for side-by-side BF16 vs. INT8 micro-benchmarks.
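
For context, here is a minimal sketch of the kind of helper the new vectorization_utils.cuh provides. The names follow the PR discussion (vectorize_with_alignment, vec_n_t, VEC_SIZE), but the signature and implementation details below are assumptions for illustration, not the merged code:

// Hedged sketch, not the actual csrc/quantization/vectorization_utils.cuh:
// apply a scalar op element-wise, using VEC_SIZE-wide loads/stores for the
// aligned middle of the buffer and scalar accesses for the ragged head/tail.
#include <cstdint>

template <typename T, int N>
struct alignas(sizeof(T) * N) vec_n_t {
  T val[N];
};

template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
__device__ void vectorize_with_alignment(const InT* in, OutT* out, int len,
                                         int tid, int stride, ScaOp sca_op) {
  // elements before the first VEC_SIZE-aligned address
  // (assumes `out` shares the same alignment as `in`)
  uintptr_t addr = reinterpret_cast<uintptr_t>(in);
  int misalign = (addr / sizeof(InT)) % VEC_SIZE;
  int prefix = misalign ? VEC_SIZE - misalign : 0;
  prefix = prefix < len ? prefix : len;

  // 1. scalar prefix: unsafe to vectorize
  for (int i = tid; i < prefix; i += stride) sca_op(out[i], in[i]);

  in += prefix;
  out += prefix;
  len -= prefix;

  using vin_t = vec_n_t<InT, VEC_SIZE>;
  using vout_t = vec_n_t<OutT, VEC_SIZE>;
  const vin_t* v_in = reinterpret_cast<const vin_t*>(in);
  vout_t* v_out = reinterpret_cast<vout_t*>(out);
  int num_vec = len / VEC_SIZE;

  // 2. vectorized main part: one 16-element load and store per iteration
  for (int i = tid; i < num_vec; i += stride) {
    vin_t src = v_in[i];
    vout_t dst;
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) sca_op(dst.val[j], src.val[j]);
    v_out[i] = dst;
  }

  // 3. scalar tail: leftover elements after the last full vector
  for (int i = num_vec * VEC_SIZE + tid; i < len; i += stride)
    sca_op(out[i], in[i]);
}

The static kernels would call such a helper with a quantize-one-element scalar op; per the changelog below, the dynamic kernels apply the same vectorized access pattern to their absmax / min-max scans as well.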

Test

Tested on H100:

Accuracy validation: lm_eval --model vllm --model_args pretrained=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8,max_model_len=32768 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

before:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.7733|±  |0.0115|
|     |       |strict-match    |     5|exact_match||0.7604|±  |0.0118|

after:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.7733|±  |0.0115|
|     |       |strict-match    |     5|exact_match||0.7604|±  |0.0118|

End to end throughput: vllm bench throughput --model RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8 --load-format dummy --input-len 1000 --output-len 100 --max-model-len 32768

before:

Throughput: 35.31 requests/s, 38822.03 total tokens/s, 3531.41 output tokens/s
Throughput: 34.61 requests/s, 38026.74 total tokens/s, 3461.44 output tokens/s

after:

Throughput: 37.93 requests/s, 41685.29 total tokens/s, 3792.61 output tokens/s
Throughput: 38.57 requests/s, 42403.11 total tokens/s, 3856.95 output tokens/s

Kernel TFLOP/s test with bench_int8_gemm.py (all values below are TFLOP/s):

Before:

| batch_size | torch-bf16 | int8-tensor-w-tensor-a | int8-channel-w-token-a | int8-tensor-w-tensor-a-noquant | int8-channel-w-token-a-noquant |
|-----------:|-----------:|-----------------------:|-----------------------:|-------------------------------:|-------------------------------:|
| 1 | 3.03654 | 3.92181 | 3.60902 | 5.25341 | 5.21451 |
| 16 | 57.3562 | 63.2098 | 59.7113 | 86.128 | 87.9333 |
| 64 | 234.556 | 253.902 | 238.535 | 352.491 | 356.684 |
| 128 | 285.766 | 384.991 | 367.512 | 487.873 | 496.294 |
| 256 | 462.462 | 531.309 | 500.083 | 639.285 | 648.289 |
| 512 | 702.442 | 865.719 | 790.126 | 1251.02 | 1253.05 |
| 1024 | 720.504 | 903.559 | 814.465 | 1280.2 | 1322.11 |
| 2048 | 689.227 | 936.958 | 849.424 | 1318.31 | 1303.01 |
| 4096 | 723.089 | 965.33 | 883.938 | 1527.99 | 1404.14 |
| 8192 | 742.027 | 949.993 | 872.24 | 1470.08 | 1351.67 |
| 16384 | 701.668 | 953.361 | 877.533 | 1433.51 | 1332.62 |

After:

| batch_size | torch-bf16 | int8-tensor-w-tensor-a | int8-channel-w-token-a | int8-tensor-w-tensor-a-noquant | int8-channel-w-token-a-noquant |
|-----------:|-----------:|-----------------------:|-----------------------:|-------------------------------:|-------------------------------:|
| 1.00 | 3.03 | 3.87 | 2.98 | 5.29 | 5.23 |
| 16.00 | 57.55 | 62.83 | 48.69 | 86.51 | 88.29 |
| 64.00 | 226.81 | 250.57 | 192.12 | 354.63 | 358.78 |
| 128.00 | 283.94 | 383.05 | 313.06 | 499.55 | 503.99 |
| 256.00 | 463.65 | 525.42 | 450.55 | 641.67 | 647.82 |
| 512.00 | 672.04 | 934.00 | 769.39 | 1265.80 | 1266.29 |
| 1024.00 | 684.19 | 1001.14 | 841.92 | 1321.50 | 1301.72 |
| 2048.00 | 707.74 | 1071.18 | 873.07 | 1311.56 | 1296.30 |
| 4096.00 | 728.31 | 1111.62 | 913.55 | 1530.56 | 1468.94 |
| 8192.00 | 743.42 | 1094.42 | 902.95 | 1475.45 | 1432.45 |
| 16384.00 | 720.14 | 1102.09 | 917.71 | 1481.84 | 1440.72 |

Unit test:

pytest test_int8_quant.py 
============================== test session starts ==============================
platform linux -- Python 3.12.11, pytest-8.4.0, pluggy-1.6.0
rootdir: /home/wentao/vllm-source
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 1346 items                                                            

test_int8_quant.py ...................................................... [  4%]
......................................................................... [  9%]
......................................................................... [ 14%]
......................................................................... [ 20%]
......................................................................... [ 25%]
......................................................................... [ 31%]
......................................................................... [ 36%]
......................................................................... [ 41%]
......................................................................... [ 47%]
......................................................................... [ 52%]
......................................................................... [ 58%]
......................................................................... [ 63%]
......................................................................... [ 69%]
......................................................................... [ 74%]
......................................................................... [ 79%]
......................................................................... [ 85%]
......................................................................... [ 90%]
......................................................................... [ 96%]
...................................................                       [100%]

======================= 1346 passed in 252.81s (0:04:12) ========================

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Hello @yewentao256, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello everyone, gemini-code-assist here to provide a summary of this pull request. This PR focuses on performance optimizations for the int8 quantization kernels within vLLM. The primary goal is to improve the throughput and efficiency of models using int8 quantization, addressing issue #18866. The changes involve implementing vectorization techniques in the CUDA kernels responsible for quantization and adding a new benchmark script to measure the performance gains of these kernels.

Highlights

  • Kernel Vectorization: The core optimization involves applying vectorization using vec_n_t (specifically VEC_SIZE = 16) to the CUDA kernels for static and dynamic int8 quantization (static_scaled_int8_quant_kernel, dynamic_scaled_int8_quant_kernel, dynamic_scaled_int8_azp_quant_kernel). This allows the kernels to process multiple elements simultaneously, improving data throughput on the GPU.
  • New Benchmark Script: A new benchmark script (benchmarks/kernels/bench_int8_gemm.py) has been added. This script uses Triton to compare the performance (TFLOP/s) of BF16 GEMM against various int8 GEMM configurations (per-tensor, per-channel, with/without activation quantization time) across different batch sizes.
  • Performance Improvements: The PR description includes benchmark results showing notable improvements. End-to-end throughput increased from ~35 req/s to ~38 req/s for the tested model. Kernel benchmarks also show increased TFLOP/s for the optimized int8 kernels, particularly at larger batch sizes.

Changelog

  • benchmarks/kernels/bench_int8_gemm.py
    • Added a new Triton benchmark script to compare BF16 and int8 GEMM kernel performance.
    • Includes benchmarking for different int8 quantization schemes (per-tensor, per-channel) and activation quantization timings.
    • Configured to run benchmarks for specified models and tensor parallel sizes.
  • csrc/quantization/compressed_tensors/int8_quant_kernels.cu
    • Included ../vectorization.cuh (line 6).
    • Implemented vectorization (VEC_SIZE = 16) in static_scaled_int8_quant_kernel for improved data processing (lines 111-142).
    • Implemented vectorization (VEC_SIZE = 16) in dynamic_scaled_int8_quant_kernel for the absmax scan and quantization loop (lines 194-251).
    • Added MinMax struct and MinMaxOp for efficient min/max reduction in dynamic AZP kernel (lines 254-263).
    • Implemented vectorization (VEC_SIZE = 16) in dynamic_scaled_int8_azp_quant_kernel for the min/max scan and quantization loop (lines 269-348).
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.


Kernels run fast,
Int8 math, a quick blast,
Vector lanes wide,
Performance takes stride,
Latency's shadow is cast.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces significant performance optimizations to the int8 quantization kernels, backed by comprehensive benchmark results showing notable throughput improvements without accuracy degradation. The core changes involve vectorizing CUDA kernels, which is a standard and effective approach for performance enhancement on GPUs.

The new benchmark script (bench_int8_gemm.py) is a valuable addition for quantifying these improvements across various configurations.

The CUDA kernel modifications in int8_quant_kernels.cu appear solid, incorporating vectorization, pre-calculation of inverse scales, and optimized reduction strategies (like combined Min/Max). These changes are well-aligned with the reported performance gains.

One area for potential clarification/improvement is within the benchmark script concerning the setup for "per-channel" weight quantization scenarios. This is detailed in a specific comment.

Overall, great work on these optimizations!

Summary of Findings

  • Benchmark Clarity for Per-Channel Quantization: The benchmark setup for 'per-channel' weight quantization modes (e.g., int8-channel-w-token-a) appears to use a single scale factor for quantizing weights, which is then broadcast to a per-channel format for the matrix multiplication. This might not reflect true per-channel quantization where weights are quantized with distinct scales per channel. Clarification on whether this setup is intentional or if it should be adjusted for more realistic per-channel weight quantization benchmarking would be beneficial.

Merge Readiness

The core kernel optimizations in this PR are well-implemented and show significant performance benefits. There is one medium-severity point regarding the setup of 'per-channel' quantization benchmarks that would be good to clarify or address to ensure the benchmarks accurately reflect the intended scenarios. Once this point is resolved, the PR should be in good shape for merging. I am unable to approve the pull request, so please have others review and approve this code before merging.


github-actions bot commented Jun 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Collaborator

@houseroad houseroad left a comment

Looks reasonable to me.

Would leave this to our kernel expert: @chenyang78, @mgoin, @tlrmchlsmth, and @LucasWilkinson to review.

Member

@mgoin mgoin left a comment

Nice work overall! I'd like the PR title and description to have more context on what exactly the changes are to achieve this "optimization", and what hardware it was tested on to give context to the performance numbers.

I also would like to see an expansion of the test cases in vllm/tests/kernels/quantization/test_int8_quant.py to include these vectorization test cases from test_fp8_quant.py

HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases

@@ -107,16 +108,37 @@ template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) {
int const VEC_SIZE = 16;
Member

I think it would be best to use constexpr here, in case it matters to the compiler for the #pragma unroll later. At least I did this for fp8 i.e. constexpr size_t VEC_SIZE = 16;
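
For illustration, the suggested change amounts to something like the following (a sketch only; the surrounding quantization math and rounding helper are elided, and the helper name is hypothetical):

// Sketch: constexpr (rather than `int const` in the kernel body) makes
// VEC_SIZE a compile-time constant, so the #pragma unroll below has a
// known trip count.
#include <cstdint>

constexpr int VEC_SIZE = 16;

template <typename scalar_t>
__device__ void quant_lanes(const scalar_t (&src)[VEC_SIZE],
                            int8_t (&dst)[VEC_SIZE], float inv_scale) {
#pragma unroll
  for (int j = 0; j < VEC_SIZE; ++j) {
    // rounding/clamping helper omitted in this sketch
    dst[j] = static_cast<int8_t>(static_cast<float>(src[j]) * inv_scale);
  }
}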

Collaborator Author

Fixed, thanks! Made a common vectorize_with_alignment function to avoid repeating this mistake.

Comment on lines 305 to 309
// reduce the min and max values across the block in one go
using BlockReduce = cub::BlockReduce<MinMax, 1024>;
__shared__ typename BlockReduce::TempStorage reduce_storage;
MinMax block_min_max =
BlockReduce(reduce_storage).Reduce(thread_min_max, MinMaxOp());
Member

Nice idea!

@yewentao256 yewentao256 changed the title [Perf] Optimizations for int8 quant kernels [Perf] Vectorize static / dynamic INT8 quant kernels (VEC_SIZE = 16) & zero-point refactor — +17 % e2e throughput on NVIDIA H100 Jun 6, 2025
@yewentao256 yewentao256 changed the title [Perf] Vectorize static / dynamic INT8 quant kernels (VEC_SIZE = 16) & zero-point refactor — +17 % e2e throughput on NVIDIA H100 [Perf] Vectorize static / dynamic INT8 quant kernels (VEC_SIZE = 16) & zero-point refactor — +15 % e2e throughput on NVIDIA H100 Jun 6, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 requested a review from WoosukKwon as a code owner June 6, 2025 22:45
@yewentao256 yewentao256 changed the title [Perf] Vectorize static / dynamic INT8 quant kernels (VEC_SIZE = 16) & zero-point refactor — +15 % e2e throughput on NVIDIA H100 [Perf] Vectorize static / dynamic INT8 quant kernels (VEC_SIZE = 16) & zero-point refactor — +10 % e2e throughput on NVIDIA H100 Jun 6, 2025
@yewentao256
Collaborator Author

yewentao256 commented Jun 6, 2025

Nice work overall! I'd like the PR title and description to have more context on what exactly the changes are to achieve this "optimization", and what hardware it was tested on to give context to the performance numbers.

I also would like to see an expansion of the test cases in vllm/tests/kernels/quantization/test_int8_quant.py to include these vectorization test cases from test_fp8_quant.py

HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases

@mgoin Thanks for the review, your insight is really valuable!

I went one step further and made a common vectorize_with_alignment abstraction so that it can be reused elsewhere (next PR); please take a look.

@yewentao256 yewentao256 closed this Jun 6, 2025
@yewentao256 yewentao256 reopened this Jun 6, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Collaborator

@ProExpertProg ProExpertProg left a comment

Looks good overall, just a few thoughts:

Comment on lines +67 to +71
# Dynamic per-token quant for A, static per-tensor quant for B
scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
assert scale_b_int8.numel() == 1
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
Collaborator

I know this is a benchmark script but I think this could still be refactored. Perhaps a few functions/objects and a dictionary?

Collaborator Author

#19364

Great idea! I will open another PR to optimize this, since the code here largely reuses benchmarks/kernels/bench_fp8_gemm.py and we can update them together.

Comment on lines +94 to +95
def run_quant():
return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)
Collaborator

Why doesn't run_quant include scaled_int8_quant here?

Collaborator Author

Because this is the noquant branch, where we don't measure the activation-quantization time (for comparison with the other branch).

Comment on lines 254 to 257
struct MinMax {
float min;
float max;
};
Collaborator

I like this utility but you could put a lot more of the code in here I think:

  • constructor: initialize to numeric_limits
  • operator&=/operator+= (or void reduce): combine MinMax with another MinMax, or add a value (updates both min and max members).

That way the code that uses it will be much cleaner.
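
A minimal sketch of what that suggested utility could look like (illustrative shape only, not necessarily the code that was merged):

// Hedged sketch of the suggested MinMax utility.
#include <cfloat>

struct MinMax {
  float min, max;

  // start from the identity element so the first fold always wins
  __device__ MinMax() : min(FLT_MAX), max(-FLT_MAX) {}

  // fold a single value into the running min/max
  __device__ MinMax& operator+=(float v) {
    min = fminf(min, v);
    max = fmaxf(max, v);
    return *this;
  }

  // combine with another partial result (used by the block reduction)
  __device__ MinMax& operator&=(const MinMax& other) {
    min = fminf(min, other.min);
    max = fmaxf(max, other.max);
    return *this;
  }
};

// binary functor for cub::BlockReduce
struct MinMaxOp {
  __device__ MinMax operator()(MinMax a, const MinMax& b) const {
    a &= b;
    return a;
  }
};

With this, the per-thread scan becomes `thread_min_max += x;`, and the cub::BlockReduce call keeps a small functor that simply forwards to operator&=.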

Collaborator Author

Great idea! Fixed

Collaborator

@ProExpertProg ProExpertProg left a comment

I really like the vectorization util! A few more comments

// 2. vectorize the main part
for (int i = tid; i < num_vec; i += stride) {
vout_t tmp;
vec_op(tmp, v_in[i]);
Collaborator

Could this just do the for loop here and call the vec_op? Or maybe vec_op can have a default parameter value DefaultVec<VEC_SIZE>{sca_op} that loops the sca_op?

Collaborator Author

DefaultVec<VEC_SIZE> looks better to me; fixed.
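
For readers following along, a DefaultVec-style fallback might look roughly like this (a sketch assuming the vector types expose a val[] array as in the snippets above; not the merged code):

// Default vector op: applies the scalar op lane by lane, so callers that
// don't need a hand-tuned vector path can pass only sca_op.
template <int VEC_SIZE, typename ScaOp>
struct DefaultVec {
  ScaOp sca_op;

  template <typename OutVecT, typename InVecT>
  __device__ void operator()(OutVecT& dst, const InVecT& src) const {
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      sca_op(dst.val[j], src.val[j]);
    }
  }
};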


// 1. prefill when it is unsafe to vectorize
for (int i = tid; i < prefix_elems; i += stride) {
sca_op(out[i], in[i]);
Collaborator

Could you make these return vout_t/OutT instead of passing in a parameter?

Collaborator Author

Not sure about this: is the functional style (returning a value) preferred in vLLM, or the out-param style?

Collaborator

@ProExpertProg ProExpertProg Jun 10, 2025

I prefer the return: it's more functional, reads more clearly, and I think it's better general C++/CUDA practice; the compiler is also nominally better at optimizing it (in general; here I think everything will be inlined anyway). So I always default to returning.
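
To make the trade-off concrete, the two styles under discussion look roughly like this for a scalar quantization op (illustrative names; rounding/clamping elided):

#include <cstdint>

// out-param style: the op writes into a caller-provided slot
struct QuantOutParam {
  float inv_scale;
  __device__ void operator()(int8_t& dst, float src) const {
    dst = static_cast<int8_t>(src * inv_scale);  // rounding/clamping elided
  }
};

// return style: the op produces the value; typically inlined to the same
// machine code, but composes more like ordinary C++
struct QuantReturn {
  float inv_scale;
  __device__ int8_t operator()(float src) const {
    return static_cast<int8_t>(src * inv_scale);  // rounding/clamping elided
  }
};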

Collaborator Author

#19430

I agree with you, and raised another issue for community discussion toward a clearer code standard. If the return style is preferred, we can adjust all of the out-param code in a new PR.

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Collaborator

@ProExpertProg ProExpertProg left a comment

Great work!

Member

@mgoin mgoin left a comment

Are the changes removing the int64 index calculation and replacing static_cast<float>() with float() intentional? Otherwise, I think we should keep these.

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Member

@mgoin mgoin left a comment

LGTM, great work iterating on this!

@mgoin mgoin added performance Performance-related issues quantization ready ONLY add when PR is ready to merge/full CI is needed labels Jun 10, 2025
@mgoin
Member

mgoin commented Jun 10, 2025

@yewentao256 can you please merge with latest main and push to fix some of the failing tests?

@NickLucche
Contributor

Can we keep performance gains off the title of the PR?

@tlrmchlsmth
Collaborator

Nice work

@yewentao256 yewentao256 changed the title [Perf] Vectorize static / dynamic INT8 quant kernels (VEC_SIZE = 16) & zero-point refactor — +10 % e2e throughput on NVIDIA H100 [Perf] Vectorize static / dynamic INT8 quant kernels (VEC_SIZE = 16) & zero-point refactor Jun 12, 2025
@yewentao256 yewentao256 changed the title [Perf] Vectorize static / dynamic INT8 quant kernels (VEC_SIZE = 16) & zero-point refactor [Perf] Vectorize static / dynamic INT8 quant kernels Jun 12, 2025
@yewentao256
Collaborator Author

Can we keep performance gains off the title of the PR?

Sure, updated, thanks!

@yewentao256
Collaborator Author

Could you force merge into main? @houseroad
I think this CI failure comes from issue #18954 and can be reproduced on the main branch.

@vllm-bot vllm-bot merged commit b6efafd into vllm-project:main Jun 12, 2025
95 of 97 checks passed
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jun 30, 2025
wseaton pushed a commit to wseaton/vllm that referenced this pull request Jun 30, 2025
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Labels
performance Performance-related issues quantization ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants