ProExpertProg (Collaborator) commented Jun 19, 2025

Purpose

This PR refactors FP8 quantization kernels to use the CustomOp abstraction, allowing Inductor to generate fast(er) Triton kernels and automatically perform fusion with RMSNorm and SiluMul (already implemented as CustomOps). This gives significant speedups, demonstrated below.

All forward-pass code for dense linear layers now instantiates QuantFP8 in the layer's __init__ method and calls that instance instead of calling custom_ops.scaled_fp8_quant directly. Non-forward code (e.g. weight quantization during process_weights_after_loading) still uses the CUDA kernel, as compilation is not worth it for a single execution. This could be changed in the future if necessary.
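The instantiate-once, call-in-forward pattern described above can be sketched in plain Python. This is only an illustration: SketchQuantOp, SketchLinear, and the use_custom_kernel flag are hypothetical stand-ins, not vLLM's CustomOp or QuantFP8 API, and the op works on a list of floats with a fixed scale instead of on tensors.

```python
from typing import List

FP8_E4M3_MAX = 448.0  # largest finite value of the float8 e4m3fn format


class SketchQuantOp:
    """Hypothetical stand-in for a CustomOp-style quantization op.

    The implementation is chosen once, at construction time, so the
    caller's forward code never has to branch on it.
    """

    def __init__(self, scale: float, use_custom_kernel: bool = False) -> None:
        self.scale = scale
        self._impl = self.forward_custom if use_custom_kernel else self.forward_native

    def forward_native(self, x: List[float]) -> List[float]:
        # Simple elementwise math that a compiler could trace and fuse.
        # Rounding to the FP8 grid is not modeled, only scaling and clamping.
        return [min(max(v / self.scale, -FP8_E4M3_MAX), FP8_E4M3_MAX) for v in x]

    def forward_custom(self, x: List[float]) -> List[float]:
        # Stand-in for an opaque hand-written kernel; same math in this sketch.
        return self.forward_native(x)

    def __call__(self, x: List[float]) -> List[float]:
        return self._impl(x)


class SketchLinear:
    """Hypothetical layer that owns its quantization op."""

    def __init__(self, scale: float, use_custom_kernel: bool = False) -> None:
        # The op is created once here, not on every forward call.
        self.quant = SketchQuantOp(scale, use_custom_kernel)

    def forward(self, x: List[float]) -> List[float]:
        return self.quant(x)


layer = SketchLinear(scale=0.5)
print(layer.forward([1.0, -2.0, 1000.0]))
```

Because the layer holds an op object rather than calling a free function, switching between the two paths is a construction-time decision and the forward code stays identical.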

I also moved the GroupShape utility from fusion.py to quant_utils.py as it's now more widely used, and fixed up the manual fusion tests (which might now have to enable the fp8 custom op).

MoE, attention layers, and INT8 quantization can be done in the future. MoE specifically is difficult because the call to scaled_fp8_quant is nested in many levels of free methods, so all of those would need to become objects. Attention will require a custom compilation utility as the call to the op is hidden from torch.compile inside the unified_attention custom op.

Test Plan

Running lm_eval on static per-tensor and dynamic per-token manually, and CI.

For performance, I ran a serving sweep for dynamic per-token and static per-tensor, and a detailed latency sweep for dynamic per-token.

Test Result

For dynamic per-token quantization, I ran a full latency sweep with all combinations of custom ops enabled/disabled. When QuantFP8 is disabled and (at least) one of RMSNorm/SiluMul is disabled, custom fusion passes can run, so I tried those configs with and without fusion as well.

Speedups of the various configurations relative to all custom ops enabled (custom-all) are reported below. Note that custom-fp8 is the default on main (fp8 custom op enabled, others disabled). This is all run on a B200 machine.

Serving sweep for redhatai/meta-llama-3.1-8B-Instruct-FP8

📊 TTFT Median (ms)

| Source | 1 | 5 | 10 | 15 | 20 |
|--------|------:|------:|------:|------:|------:|
| custom-fp8 | 22.32 | 23.91 | 21.21 | 26.86 | 29.01 |
| torch | 21.71 | 22.53 | 20.40 | 25.36 | 27.63 |

📊 ITL Median (ms)

| Source | 1 | 5 | 10 | 15 | 20 |
|--------|-----:|-----:|-----:|-----:|-----:|
| custom-fp8 | 4.59 | 4.91 | 5.15 | 5.36 | 5.85 |
| torch | 4.35 | 4.51 | 4.72 | 4.91 | 5.34 |

📊 TPOT Median (ms)

| Source | 1 | 5 | 10 | 15 | 20 |
|--------|-----:|-----:|-----:|-----:|-----:|
| custom-fp8 | 4.60 | 5.04 | 5.29 | 6.14 | 7.12 |
| torch | 4.35 | 4.65 | 4.85 | 5.55 | 6.40 |
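To read the TPOT table above as relative differences, one can divide the custom-fp8 row by the torch row. The sketch below does this with the values copied from the table; the column headers are treated as opaque labels because the source does not state their unit.

```python
columns = [1, 5, 10, 15, 20]  # column headers as given in the table
tpot_custom_fp8 = [4.60, 5.04, 5.29, 6.14, 7.12]  # median TPOT in ms
tpot_torch = [4.35, 4.65, 4.85, 5.55, 6.40]       # median TPOT in ms

# Ratio > 1 means the custom-fp8 configuration takes longer per output token.
ratios = [c / t for c, t in zip(tpot_custom_fp8, tpot_torch)]

for col, r in zip(columns, ratios):
    print(f"{col:>2}: custom-fp8 TPOT is {100 * (r - 1):.1f}% higher than torch")
```

For this table the ratio is above 1 in every column, and it grows from the first column to the last.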

Serving sweep for redhatai/meta-llama-3.1-8B-Instruct-FP8-dynamic

📊 TTFT Median (ms)

| Source | 1 | 5 | 10 | 15 | 20 |
|--------|------:|------:|------:|------:|------:|
| fp8-custom | 19.33 | 21.66 | 18.18 | 25.38 | 26.91 |
| fp8-torch | 18.56 | 20.99 | 18.21 | 24.33 | 26.13 |

📊 ITL Median (ms)

| Source | 1 | 5 | 10 | 15 | 20 |
|--------|-----:|-----:|-----:|-----:|-----:|
| fp8-custom | 4.53 | 5.04 | 5.23 | 5.52 | 5.90 |
| fp8-torch | 4.54 | 5.00 | 5.17 | 5.39 | 5.80 |

📊 TPOT Median (ms)

| Source | 1 | 5 | 10 | 15 | 20 |
|--------|-----:|-----:|-----:|-----:|-----:|
| fp8-custom | 4.62 | 5.17 | 5.40 | 6.35 | 7.18 |
| fp8-torch | 4.59 | 5.13 | 5.30 | 6.13 | 6.97 |

Serving sweeps on H100 for dynamic/static models

📊 TTFT Median (ms)

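| Source | 1 | 5 | 10 | 15 | 20 |
|--------|------:|------:|------:|------:|------:|
| dynamic-custom-fp8 | 39.59 | 39.20 | 44.82 | 48.96 | 56.20 |
| dynamic-torch | 39.64 | 41.92 | 44.61 | 47.36 | 53.46 |
| static-custom-fp8 | 39.58 | 39.35 | 42.20 | 47.31 | 55.70 |
| static-torch | 35.99 | 39.87 | 43.37 | 46.13 | 52.46 |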

📊 TPOT Median (ms)

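| Source | 1 | 5 | 10 | 15 | 20 |
|--------|-----:|-----:|-----:|-----:|------:|
| dynamic-custom-fp8 | 5.57 | 6.09 | 7.75 | 9.35 | 12.08 |
| dynamic-torch | 5.35 | 5.89 | 7.35 | 8.77 | 11.19 |
| static-custom-fp8 | 5.36 | 5.88 | 7.34 | 8.87 | 11.41 |
| static-torch | 4.98 | 5.49 | 6.81 | 8.12 | 10.30 |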

📊 ITL Median (ms)

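| Source | 1 | 5 | 10 | 15 | 20 |
|--------|-----:|-----:|-----:|-----:|-----:|
| dynamic-custom-fp8 | 5.43 | 5.69 | 6.27 | 6.90 | 7.80 |
| dynamic-torch | 5.23 | 5.45 | 6.01 | 6.60 | 7.32 |
| static-custom-fp8 | 5.23 | 5.48 | 5.99 | 6.63 | 7.39 |
| static-torch | 4.89 | 5.13 | 5.61 | 6.16 | 6.83 |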

[Figure: latency-h100-fp8-custom]

Latency sweep for redhatai/meta-llama-3.1-8B-Instruct-FP8-dynamic

We can see that torch outperforms all other implementations, including custom-fp8 (current default on main).

Speedup Comparison (vs. custom-all): decode (1 input token, 64 output tokens, batch-size 16-256)

| config | 16 | 64 | 128 | 256 |
|--------|---------:|---------:|---------:|--------:|
| custom-all-fused | 1.07713 | 0.966885 | 0.974104 | 1.01833 |
| custom-fp8 | 1.04612 | 1.0109 | 1.09161 | 1.14704 |
| custom-fp8-rms | 1.04737 | 1.01149 | 1.04728 | 1.06727 |
| custom-fp8-rms-fused | 1.08624 | 1.04375 | 1.11073 | 1.13944 |
| custom-fp8-silu | 1.00758 | 0.981558 | 1.02474 | 1.04292 |
| custom-fp8-silu-fused | 0.996788 | 0.990235 | 1.01944 | 1.06877 |
| custom-rms | 0.89273 | 0.848042 | 0.981708 | 1.02122 |
| custom-rms-silu | 1.09098 | 0.993617 | 1.00775 | 1.02594 |
| custom-silu | 1.08374 | 1.01851 | 1.07541 | 1.07973 |
| torch | 1.10491 | 1.06615 | 1.12593 | 1.16359 |
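The statement that torch outperforms the other configurations can be checked directly against the decode table: the sketch below picks the configuration with the highest speedup in each batch-size column. The numbers are copied from the table above; the helper name best_per_column is introduced only for illustration.

```python
from typing import Dict, List

batch_sizes = [16, 64, 128, 256]

# Speedup vs. custom-all for the decode workload, copied from the table.
decode_speedup: Dict[str, List[float]] = {
    "custom-all-fused": [1.07713, 0.966885, 0.974104, 1.01833],
    "custom-fp8": [1.04612, 1.0109, 1.09161, 1.14704],
    "custom-fp8-rms": [1.04737, 1.01149, 1.04728, 1.06727],
    "custom-fp8-rms-fused": [1.08624, 1.04375, 1.11073, 1.13944],
    "custom-fp8-silu": [1.00758, 0.981558, 1.02474, 1.04292],
    "custom-fp8-silu-fused": [0.996788, 0.990235, 1.01944, 1.06877],
    "custom-rms": [0.89273, 0.848042, 0.981708, 1.02122],
    "custom-rms-silu": [1.09098, 0.993617, 1.00775, 1.02594],
    "custom-silu": [1.08374, 1.01851, 1.07541, 1.07973],
    "torch": [1.10491, 1.06615, 1.12593, 1.16359],
}


def best_per_column(table: Dict[str, List[float]], columns: List[int]) -> Dict[int, str]:
    """Return the config name with the largest speedup for each column."""
    best = {}
    for i, col in enumerate(columns):
        best[col] = max(table, key=lambda name: table[name][i])
    return best


winners = best_per_column(decode_speedup, batch_sizes)
for col in batch_sizes:
    name = winners[col]
    print(f"batch size {col}: {name} ({decode_speedup[name][batch_sizes.index(col)]})")
```

The same helper can be pointed at the mixed and prefill tables by swapping in their rows and column labels.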

Speedup Comparison (vs. custom-all): mixed (1024 input tokens, 64 output tokens, batch-size 4-128)

| config | 4 | 10 | 16 | 64 | 128 |
|--------|---------:|---------:|---------:|--------:|---------:|
| custom-all-fused | 1.01229 | 1.05216 | 0.975955 | 0.99673 | 0.992197 |
| custom-fp8 | 1.09964 | 1.10017 | 1.0466 | 1.07014 | 1.05651 |
| custom-fp8-rms | 1.07246 | 1.08611 | 1.04403 | 1.06538 | 1.05176 |
| custom-fp8-rms-fused | 1.10617 | 1.08933 | 1.04709 | 1.0521 | 1.03928 |
| custom-fp8-silu | 1.07454 | 1.06007 | 1.00081 | 1.02444 | 1.00908 |
| custom-fp8-silu-fused | 1.05526 | 1.05346 | 0.998359 | 1.02058 | 1.00661 |
| custom-rms | 0.968318 | 0.998547 | 0.987231 | 1.04925 | 1.05135 |
| custom-rms-silu | 1.03398 | 1.04538 | 0.980946 | 1.01276 | 1.00726 |
| custom-silu | 1.10988 | 1.09457 | 1.03808 | 1.04239 | 1.02976 |
| torch | 1.14776 | 1.14311 | 1.07934 | 1.09905 | 1.08432 |

Speedup Comparison (vs. custom-all): prefill (512-2048 input tokens, 1 output token, batch-size 1)

| config | 512 | 2048 |
|--------|---------:|---------:|
| custom-all-fused | 1.04529 | 0.998425 |
| custom-fp8 | 1.03066 | 1.04794 |
| custom-fp8-rms | 0.995576 | 1.06126 |
| custom-fp8-rms-fused | 1.06914 | 1.02935 |
| custom-fp8-silu | 0.934619 | 0.999869 |
| custom-fp8-silu-fused | 1.05147 | 0.996881 |
| custom-rms | 0.963685 | 1.05201 |
| custom-rms-silu | 1.083 | 1.0228 |
| custom-silu | 1.04159 | 1.02152 |
| torch | 1.07182 | 1.09127 |

lm_eval

Dynamic per-token (CUDA kernel)

# Serve command
$ vllm serve redhatai/meta-llama-3.1-8B-Instruct-FP8-dynamic -O.custom_ops+="+quant_fp8"

# Results
local-completions (pretrained=redhatai/meta-llama-3.1-8B-Instruct-FP8-dynamic,base_url=http://0.0.0.0:6969/v1/completions,num_concurrent=50,max_retries=3), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|| 0.76|±  |0.0429|
|     |       |strict-match    |     5|exact_match|| 0.73|±  |0.0446|

Dynamic per-token (torch implementation)

# Serve command
$ vllm serve redhatai/meta-llama-3.1-8B-Instruct-FP8-dynamic

# Results
local-completions (pretrained=redhatai/meta-llama-3.1-8B-Instruct-FP8-dynamic,base_url=http://0.0.0.0:6969/v1/completions,num_concurrent=50,max_retries=3), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|| 0.74|±  |0.0441|
|     |       |strict-match    |     5|exact_match|| 0.72|±  |0.0451|

Static per-tensor (CUDA kernel)

# Serve command
$ vllm serve redhatai/meta-llama-3.1-8B-Instruct-FP8 -O.custom_ops+="+quant_fp8"

# Results
local-completions (pretrained=redhatai/meta-llama-3.1-8B-Instruct-FP8,base_url=http://0.0.0.0:6969/v1/completions,num_concurrent=50,max_retries=3), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|| 0.74|±  |0.0441|
|     |       |strict-match    |     5|exact_match|| 0.73|±  |0.0446|

Static per-tensor (torch implementation)

# Serve command
$ vllm serve redhatai/meta-llama-3.1-8B-Instruct-FP8

# Results
local-completions (pretrained=redhatai/meta-llama-3.1-8B-Instruct-FP8,base_url=http://0.0.0.0:6969/v1/completions,num_concurrent=50,max_retries=3), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|| 0.75|±  |0.0435|
|     |       |strict-match    |     5|exact_match|| 0.73|±  |0.0446|
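One rough way to read the four gsm8k results above is to ask whether the CUDA-kernel and torch scores differ by more than their combined standard error. The source does not state an acceptance criterion, so the check below is an assumption made for illustration; the values are the flexible-extract scores and stderrs copied from the tables, and within_noise is a hypothetical helper.

```python
import math
from typing import Tuple

Score = Tuple[float, float]  # (exact_match value, stderr)

# flexible-extract results copied from the lm_eval tables above
results = {
    "dynamic per-token": {"cuda": (0.76, 0.0429), "torch": (0.74, 0.0441)},
    "static per-tensor": {"cuda": (0.74, 0.0441), "torch": (0.75, 0.0435)},
}


def within_noise(a: Score, b: Score) -> bool:
    """True if the two scores differ by no more than their combined stderr."""
    (value_a, stderr_a), (value_b, stderr_b) = a, b
    return abs(value_a - value_b) <= math.hypot(stderr_a, stderr_b)


for scheme, pair in results.items():
    verdict = "within" if within_noise(pair["cuda"], pair["torch"]) else "outside"
    print(f"{scheme}: CUDA vs. torch difference is {verdict} the combined stderr")
```

With limit set to 100 samples the standard errors are large, so this check can only flag gross accuracy regressions.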

@mergify mergify bot added the v1 label Jun 19, 2025
gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @ProExpertProg, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors and centralizes the FP8 quantization logic within the codebase. By introducing a new QuantFP8 abstraction, it provides a more unified, configurable, and maintainable approach to handling 8-bit floating-point operations, which are crucial for efficient model execution. The changes also include a new benchmark to validate the performance and correctness of these FP8 operations.

Highlights

  • New FP8 Quantization Abstraction: Introduced a new QuantFP8 class in vllm/model_executor/layers/fp8_quantization.py. This class provides a unified and configurable interface for FP8 quantization, supporting both per-token and per-tensor dynamic/static quantization. It encapsulates the logic for both native (PyTorch) and CUDA implementations.
  • Refactoring and Integration: Refactored existing FP8 quantization calls across various modules (e.g., Fp8LinearOp, attention backends, and different quantization schemes like FBGEMM, ModelOpt, PTPC, Quark) to utilize the new QuantFP8 class. This replaces direct calls to ops.scaled_fp8_quant with calls to instances of QuantFP8, making the quantization logic more modular and consistent.
  • Improved Configurability: Enhanced the configurability of FP8 quantization strategies by introducing parameters like act_quant_static (for static vs. dynamic quantization) and act_quant_group_shape (for per-token vs. per-tensor grouping) to the Fp8LinearOp and QuantFP8 classes. This allows for more flexible and optimized FP8 usage based on hardware capabilities and model requirements.
  • New FP8 Quantization Benchmark: Added a new benchmark script (benchmarks/kernels/bench_per_token_quant_fp8.py) to compare and measure the performance of per-token FP8 quantization using both PyTorch's native implementation and CUDA operations. This helps in verifying correctness and assessing performance gains.
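
The per-token and per-tensor granularities named in the highlights above differ only in how many scales are computed, and static versus dynamic differs only in where the scale comes from. A minimal sketch in plain Python, assuming the FP8 E4M3 format with a maximum finite value of 448; the function names are hypothetical and this is not the QuantFP8 implementation.

```python
from typing import List

FP8_E4M3_MAX = 448.0  # largest finite value of the float8 e4m3fn format
MIN_AMAX = 1e-12      # guard against division by zero for all-zero input

Matrix = List[List[float]]  # rows are tokens, columns are hidden features


def per_token_scales(x: Matrix) -> List[float]:
    """Dynamic per-token: one scale per row, computed from the data."""
    return [max(max(abs(v) for v in row), MIN_AMAX) / FP8_E4M3_MAX for row in x]


def per_tensor_scale(x: Matrix) -> float:
    """Dynamic per-tensor: a single scale shared by every row."""
    amax = max(abs(v) for row in x for v in row)
    return max(amax, MIN_AMAX) / FP8_E4M3_MAX


def apply_scales(x: Matrix, scales: List[float]) -> Matrix:
    """Divide each row by its scale and clamp to the FP8 range.

    Rounding to the FP8 grid is not modeled in this sketch.
    """
    return [
        [min(max(v / s, -FP8_E4M3_MAX), FP8_E4M3_MAX) for v in row]
        for row, s in zip(x, scales)
    ]


x = [[1.0, -2.0], [8.0, 4.0]]

dynamic_token = apply_scales(x, per_token_scales(x))
dynamic_tensor = apply_scales(x, [per_tensor_scale(x)] * len(x))

# Static per-tensor: the scale is fixed ahead of time instead of being
# derived from the activations (the value here is purely illustrative).
static_scale = 0.05
static_tensor = apply_scales(x, [static_scale] * len(x))
```

With per-token scales every row uses the full FP8 range, while a shared scale leaves rows with small values using only part of it.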

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a new QuantFP8 class to encapsulate FP8 quantization logic and refactors existing code to use this class. The changes improve code organization and prepare for more flexible quantization schemes. However, there are several high and critical severity issues related to the correct initialization and usage of the new QuantFP8 class in various parts of the codebase, particularly in the attention backends and one of the compressed tensors schemes. Additionally, the QuantFP8 class itself has limitations in its native implementation and an assertion that might be too restrictive depending on supported quantization types. Addressing these issues is crucial for correctness and stability.

mergify bot commented Jun 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 20, 2025
@mergify mergify bot added the performance label and removed the needs-rebase label Jun 20, 2025
@ProExpertProg ProExpertProg force-pushed the luka/fp8-custom-ops branch 2 times, most recently from 2442db7 to b284931 July 8, 2025 18:22
@ProExpertProg ProExpertProg marked this pull request as ready for review July 8, 2025 21:48
@ProExpertProg ProExpertProg changed the title from "FP8 custom ops" to "[Perf][fp8] Use CustomOp abstraction for fp8 quant for better perf" Jul 8, 2025
ProExpertProg (Collaborator, Author) commented

@gemini-code-assist review

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a CustomOp abstraction for FP8 quantization, which is a significant improvement for performance and code structure. The refactoring is extensive, touching many files, but it is applied consistently and effectively. The new QuantFP8 op and the updated Fp8LinearOp make the quantization logic cleaner and more maintainable. The performance benchmarks in the PR description are very detailed and clearly demonstrate the benefits of this change.

I have a couple of minor suggestions in the new benchmark file to improve code readability by avoiding shadowing Python built-in functions. Overall, this is a high-quality contribution.

@LucasWilkinson LucasWilkinson self-requested a review July 9, 2025 01:44
LucasWilkinson (Collaborator) left a comment

Love to see GroupShape used more widely and just the general clean-up! And a nice perf boost. Thanks for doing this!

Signed-off-by: Luka Govedic <lgovedic@redhat.com>
@mgoin mgoin enabled auto-merge (squash) July 10, 2025 21:19
@mgoin mgoin merged commit 31d5c17 into vllm-project:main Jul 11, 2025
76 checks passed
Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
…llm-project#19830)

Signed-off-by: Luka Govedic <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
rojagtap added a commit to rojagtap/vllm that referenced this pull request Aug 24, 2025
Purpose: vllm-project#19830 added QuantFp8, which uses the CustomOp abstraction to implement fp8 quantization in both CUDA and torch, allowing Inductor to achieve superior performance over the CUDA ops (which are unoptimized and also do not fuse by default). However, the class has to be instantiated during init, and MoE uses are currently in util free functions many levels deep. Those need to be mildly rearchitected to take advantage of the new abstraction.

The main changes include:
- Add a MoEInputQuantizer class with init/forward (and callable) interface for FP8/INT8/NVFP4/MXFP4 activation quantization.
	- Make the class abstract out details on decision of the quant fp8 op to be used

Signed-off-by: Rohan Jagtap <rohanj30@icloud.com>
rojagtap added 4 more commits to rojagtap/vllm that referenced this pull request Aug 24, 2025, each with the same Purpose paragraph as above. The changes in these commits include:
- Adapt the `MoEPrepareAndFinalizeNoEP` class to the custom op wrapper changes
- Refactor docstrings to use backticks for code
- Refactor formatting for certain methods to match other calls
- Minor refactoring to keep code consistent
- Fix code smell for unhandled None object for quantizer

Signed-off-by: Rohan Jagtap <rohanj30@icloud.com>
Labels: performance, ready, rocm, v1

4 participants