
Conversation

@djmmoss (Contributor) commented Jul 3, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Add GroupGEMM support for SM100

Test Plan

python -m pytest tests/kernels/moe/test_cutlass_moe.py

lm_eval --model vllm --model_args pretrained=/scratch/models/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=4,max_model_len=2048,gpu_memory_utilization=0.9,max_num_seqs=32 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

lm_eval --model vllm --model_args pretrained=/scratch/models/Mixtral-8x7B-Instruct-v0.1-FP8,tensor_parallel_size=4,max_model_len=2048,gpu_memory_utilization=0.9 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

Test Result

$ python -m pytest tests/kernels/moe/test_cutlass_moe.py 
================================================================================================= test session starts ==================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.5.0
rootdir: /home/scratch.dmoss_gpu_1/repos/vllm
configfile: pyproject.toml
plugins: anyio-4.8.0, hypothesis-6.127.9, flakefinder-1.1.0, rerunfailures-15.0, xdist-3.6.1, shard-0.1.2, xdoctest-1.0.2, typeguard-4.3.0
collected 426 items                                                                                                                                                                                                    
Running 426 items in this shard

tests/kernels/moe/test_cutlass_moe.py .......................................................................................................................................................................... [ 39%]
................................................................................................................................................................................................................ [ 88%]
................................................                                                                                                                                                                 [100%]

============================================================================================ 426 passed in 63.30s (0:01:03) ============================================================================================

vllm (pretrained=/scratch/models/Mixtral-8x7B-Instruct-v0.1-FP8,tensor_parallel_size=4,max_model_len=2048,gpu_memory_utilization=0.9,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6126|±  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.6126|±  |0.0134|

vllm (pretrained=/scratch/models/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=4,max_model_len=2048,gpu_memory_utilization=0.9,max_num_seqs=32,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9189|±  |0.0075|
|     |       |strict-match    |     5|exact_match|↑  |0.9212|±  |0.0074|

(Optional) Documentation Update


github-actions bot commented Jul 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @djmmoss, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive support for NVIDIA's SM100 (Blackwell) architecture for Grouped GEMM (General Matrix Multiply) operations within the CUTLASS library. This enhancement specifically targets FP8 quantized Mixture-of-Experts (MoE) models, expanding the hardware compatibility for efficient MoE inference and allowing vLLM to leverage the latest NVIDIA GPU capabilities.

Highlights

  • SM100 GroupGEMM Support: Added a new CUDA kernel (grouped_mm_c3x_sm100.cu) and associated build configurations to enable GroupGEMM operations on NVIDIA's SM100 (Blackwell) architecture, specifically for FP8 quantized Mixture-of-Experts (MoE) models.
  • Generalized CUTLASS Kernel: Refactored the core CUTLASS GroupGEMM template (grouped_mm_c3x.cuh) to be architecture-agnostic, allowing it to be used for both SM90 (Hopper) and SM100 (Blackwell) by passing the ArchTag as a template parameter.
  • Dynamic Kernel Dispatch & Build System Updates: Updated the C++ and Python logic to dynamically select the appropriate GroupGEMM kernel (SM90 or SM100) based on the detected GPU architecture and CUDA compiler version. CMakeLists.txt was modified to conditionally compile the SM100 kernel when CUDA 12.8 or newer is available and targeting SM100 architectures.
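The compile-time-plus-runtime selection described in the last highlight can be sketched in Python; the function and flag names here are illustrative stand-ins, not vLLM's actual API:

```python
# Hypothetical sketch of the dispatch scheme described above. The
# *_compiled flags stand in for the compile-time
# ENABLE_CUTLASS_MOE_SM{90,100} macros; version_num is the device
# compute capability as an integer (90 = Hopper, 100 = Blackwell).
def dispatch_grouped_gemm(version_num: int,
                          sm100_compiled: bool,
                          sm90_compiled: bool) -> str:
    """Select a grouped-GEMM backend from the runtime SM version."""
    if sm100_compiled and version_num >= 100:
        return "cutlass_moe_mm_sm100"
    if sm90_compiled and 90 <= version_num < 100:
        return "cutlass_moe_mm_sm90"
    raise RuntimeError(f"no grouped GEMM kernel compiled for SM{version_num}")
```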

@mergify mergify bot added the ci/build label Jul 3, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds support for SM100 (Blackwell) architecture for CUTLASS grouped GEMM operations, which is essential for running FP8 quantized Mixture-of-Experts (MoE) models. The changes are comprehensive, touching CMake build files, CUDA C++ kernels, and Python-level logic.

My review has identified a couple of critical issues that need to be addressed. There is an incorrect runtime dispatch logic in one of the C++ files that could lead to calling the wrong kernel, and a logical error in the Python code for detecting supported quantization schemes due to operator precedence. I've provided specific suggestions to fix these issues. Addressing these points will ensure the new functionality is robust and correct.

Comment on lines 250 to 257
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
  cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                       expert_offsets, problem_sizes, a_strides, b_strides,
                       c_strides, per_act_token, per_out_ch);
  return;
#endif
gemini-code-assist bot:

critical

The dispatch logic here is incorrect. It relies only on compile-time preprocessor directives (#if defined...) and doesn't use the runtime SM version (version_num) for dispatching. This will lead to runtime errors if the code is compiled for multiple architectures (e.g., both SM90 and SM100).

For instance, if both ENABLE_CUTLASS_MOE_SM100 and ENABLE_CUTLASS_MOE_SM90 are defined, this code will always try to call cutlass_moe_mm_sm100, even when running on an SM90 GPU, which will fail.

The dispatch logic should be based on the version_num queried at runtime, similar to how cutlass_scaled_mm is implemented in this file. Please update this block and the subsequent one for SM90 to use version_num for correct runtime dispatch.

#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
  if (version_num >= 100) {
    cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                         expert_offsets, problem_sizes, a_strides, b_strides,
                         c_strides, per_act_token, per_out_ch);
    return;
  }
#endif

Collaborator:

@djmmoss I think Gemini is actually correct here; we should also add runtime switches

Contributor Author:

fair, I've added the check in

Comment on lines 332 to 335
return (
    self._check_scheme_supported(90, error=False, match_exact=True)
    or self._check_scheme_supported(100, error=False, match_exact=True)
    and self._is_fp8_w8a8(weight_quant, input_quant))
gemini-code-assist bot:

critical

There is a logical error in this condition due to Python's operator precedence, where and is evaluated before or.

The current logic is equivalent to:
self._check_scheme_supported(...) or (self._check_scheme_supported(...) and self._is_fp8_w8a8(...))

This means if SM90 is supported, the function will return True even if the quantization scheme is not w8a8, which is incorrect.

You should add parentheses to group the or condition to ensure the architecture check is performed before checking the quantization scheme.

Suggested change:

    # before
    return (
        self._check_scheme_supported(90, error=False, match_exact=True)
        or self._check_scheme_supported(100, error=False, match_exact=True)
        and self._is_fp8_w8a8(weight_quant, input_quant))

    # after
    return (
        (self._check_scheme_supported(90, error=False, match_exact=True) or
         self._check_scheme_supported(100, error=False, match_exact=True)) and
        self._is_fp8_w8a8(weight_quant, input_quant))
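The precedence bug above can be reproduced in two lines; `sm90_ok`, `sm100_ok`, and `is_w8a8` are illustrative stand-ins for the three checks:

```python
# In Python, `and` binds tighter than `or`, so `a or b and c`
# parses as `a or (b and c)`.
sm90_ok, sm100_ok, is_w8a8 = True, True, False

# Original (buggy) grouping: True even though the scheme is not w8a8.
buggy = sm90_ok or sm100_ok and is_w8a8

# Fixed grouping: the scheme check applies after the arch check.
fixed = (sm90_ok or sm100_ok) and is_w8a8
```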

@yewentao256 (Collaborator) left a comment
Thanks for the work! Could you also add some benchmark comparison with Triton?

Comment on lines 84 to 87
  elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
      return CompressedTensorsW4A4MoeMethod()
- elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
+ elif quant_config._is_fp8_w8a8_sm90_or_sm100(weight_quant,
+                                              input_quant):
Collaborator:

Could this be divided into two branches, _is_fp8_w8a8_sm90 and _is_fp8_w8a8_sm100?

Collaborator:

In this case we don't need to have complicated logic in _is_fp8_w8a8_sm90_or_sm100, and may be better for the future architecture.

Contributor Author:

fixed 👍

@mgoin (Member) commented Jul 4, 2025

Could you also add some benchmark comparison with Triton?

Since there are no tuned configs in this PR, I expect this will give poor performance. @djmmoss do you plan on including those in this PR or a followup?

This seems reasonable to me at a kernel level

@mgoin mgoin added kernel ready ONLY add when PR is ready to merge/full CI is needed labels Jul 4, 2025
@djmmoss (Contributor Author) commented Jul 4, 2025

@mgoin we are planning to add the tuned configs to this PR, please hold off on merging until they are in place 👍

@djmmoss djmmoss requested a review from LucasWilkinson as a code owner July 7, 2025 17:26
@mgoin (Member) left a comment

we are planning to add the tuned configs to this PR, please hold off on merging until they are in place


mergify bot commented Jul 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @djmmoss.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 9, 2025
@mgoin mgoin changed the title [feat]: add SM100 support for cutlass groupGEMM [feat]: add SM100 support for cutlass FP8 groupGEMM Jul 9, 2025
@jiahanc jiahanc force-pushed the dmoss/group_gemm_sm100 branch from cf2d006 to dac867e Compare July 16, 2025 21:23
@mergify mergify bot removed the needs-rebase label Jul 16, 2025
@jiahanc (Contributor) commented Jul 17, 2025

Kernel perf benchmark

nm-testing/Mixtral-8x7B-Instruct-v0.1, num_experts=8, topk=2, per_act_token=True, per_out_ch=True:

|MKN|triton_moe|triton_moe_cuda_graphs|grouped_gemm_moe|grouped_gemm_moe_cuda_graphs|
|---|---:|---:|---:|---:|
|(1, 4096, 28672)|8.7|8.2|5.8|4.6|
|(4, 4096, 28672)|11.8|11.2|8.6|7.4|
|(8, 4096, 28672)|17.3|16.8|13.6|12.3|
|(16, 4096, 28672)|72.0|71.4|13.5|12.4|
|(32, 4096, 28672)|72.0|71.4|13.7|12.4|
|(64, 4096, 28672)|72.2|71.5|15.2|14.2|
|(128, 4096, 28672)|73.4|72.7|15.4|14.7|
|(256, 4096, 28672)|108.6|108.0|16.8|16.0|
|(512, 4096, 28672)|157.1|156.4|22.0|21.5|
|(1, 14336, 4096)|4.5|4.1|4.9|3.2|
|(4, 14336, 4096)|6.4|5.9|5.6|4.5|
|(8, 14336, 4096)|9.8|9.3|8.2|7.0|
|(16, 14336, 4096)|35.5|34.9|8.3|7.0|
|(32, 14336, 4096)|35.4|34.9|8.5|7.3|
|(64, 14336, 4096)|35.5|34.9|9.0|8.1|
|(128, 14336, 4096)|35.9|35.3|9.8|8.7|
|(256, 14336, 4096)|57.4|56.7|11.2|10.3|
|(512, 14336, 4096)|85.4|84.7|15.1|14.1|

@jiahanc (Contributor) commented Jul 17, 2025

@mgoin the PR is ready to go. The CUTLASS perf has been tuned. We implemented fallback logic to the Triton MoE path at small batch sizes because we noticed the extra kernel overhead of CUTLASS outweighs the grouped GEMM perf gain there, which led to worse e2e perf in low-latency cases.
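The small-batch fallback described above can be sketched as a simple threshold check; the threshold value and names below are illustrative assumptions, not vLLM's actual code:

```python
# Hypothetical sketch of the fallback: below a small-batch threshold,
# the Triton MoE path wins because CUTLASS's extra setup kernels
# outweigh the grouped-GEMM gain. The threshold is an assumed value.
SMALL_BATCH_THRESHOLD = 64  # illustrative, not the tuned value

def select_moe_backend(num_tokens: int, cutlass_supported: bool) -> str:
    """Pick the MoE kernel path for a batch of num_tokens tokens."""
    if not cutlass_supported or num_tokens < SMALL_BATCH_THRESHOLD:
        return "triton_moe"
    return "cutlass_grouped_gemm_moe"
```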

@shixianc (Contributor) commented Jul 17, 2025

@jiahanc I made a swap_ab change last week:

constexpr int SWAP_AB_THRESHOLD = 64;

template <bool SWAP_AB>
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
                                      int32_t* problem_sizes1,
                                      int32_t* problem_sizes2,
                                      int32_t* atomic_buffer,
                                      const int topk_length, const int n,
                                      const int k) {
  int expert_id = blockIdx.x;
  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();
  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    if constexpr (!SWAP_AB) {
      problem_sizes1[expert_id * 3] = final_occurrences;
      problem_sizes1[expert_id * 3 + 1] = 2 * n;
      problem_sizes1[expert_id * 3 + 2] = k;
      problem_sizes2[expert_id * 3] = final_occurrences;
      problem_sizes2[expert_id * 3 + 1] = k;
      problem_sizes2[expert_id * 3 + 2] = n;
    } else {
      problem_sizes1[expert_id * 3] = 2 * n;
      problem_sizes1[expert_id * 3 + 1] = final_occurrences;
      problem_sizes1[expert_id * 3 + 2] = k;
      problem_sizes2[expert_id * 3] = k;
      problem_sizes2[expert_id * 3 + 1] = final_occurrences;
      problem_sizes2[expert_id * 3 + 2] = n;
    }
  }
}

The current implementation looks exactly the same, especially in handling the problem_sizes in moe_data.cu. Do you think we can merge mine first and do a rebase on top of it?
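For readability, the per-expert problem-size logic of the kernel above can be rendered in Python; `m` stands for the number of tokens routed to the expert (the kernel's `final_occurrences`):

```python
# Python rendering of the per-expert problem-size logic in the CUDA
# kernel above; a sketch for illustration, not vLLM's actual code.
def expert_problem_sizes(m: int, n: int, k: int, swap_ab: bool):
    """Return the (M, N, K) shapes for the two grouped GEMMs of one expert.

    With swap_ab, A and B are exchanged so the (often small) token
    count m lands in the N dimension instead of M.
    """
    if not swap_ab:
        ps1 = (m, 2 * n, k)  # first grouped GEMM
        ps2 = (m, k, n)      # second grouped GEMM
    else:
        ps1 = (2 * n, m, k)
        ps2 = (k, m, n)
    return ps1, ps2
```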

@jiahanc (Contributor) commented Jul 17, 2025

> @jiahanc I made a swap_ab change last week … Do you think we can merge mine first and do a rebase on top of it?

@shixianc, sure, we can hold off this PR until yours is merged.

@jiahanc (Contributor) commented Jul 17, 2025

Holding off until #20911 merges first.

@jiahanc jiahanc force-pushed the dmoss/group_gemm_sm100 branch from 0f0697b to 89eb69c Compare July 17, 2025 20:29

mergify bot commented Jul 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @djmmoss.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 18, 2025
@mgoin (Member) commented Jul 18, 2025

Nice! Please resolve the merge conflict and I'll enable full CI

@jiahanc jiahanc force-pushed the dmoss/group_gemm_sm100 branch from 89eb69c to 2de546c Compare July 18, 2025 17:09
@mergify mergify bot removed the needs-rebase label Jul 18, 2025
jiahanc added 2 commits July 18, 2025 10:13
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
@jiahanc (Contributor) commented Jul 18, 2025

@mgoin Finished the rebase, could you help review?

@mgoin (Member) left a comment

LGTM nice work, just one request for documentation

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
@mgoin mgoin enabled auto-merge (squash) July 18, 2025 21:23
@mgoin (Member) left a comment

LGTM, thank you for the support!

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
auto-merge was automatically disabled July 19, 2025 00:19

Head branch was pushed to by a user without write access

@jiahanc (Contributor) commented Jul 20, 2025

@mgoin The CI failure seems unrelated to this change. Could you take a look and force-merge if it looks good?

@vllm-bot vllm-bot merged commit 2c8db17 into vllm-project:main Jul 22, 2025
95 of 98 checks passed
@djmmoss djmmoss deleted the dmoss/group_gemm_sm100 branch July 22, 2025 16:20
yeqcharlotte pushed a commit to yeqcharlotte/vllm that referenced this pull request Jul 23, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
zixi-qi pushed a commit to zixi-qi/vllm that referenced this pull request Jul 23, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: qizixi <qizixi@meta.com>
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
wenscarl pushed a commit to wenscarl/vllm that referenced this pull request Aug 4, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: shuw <shuw@nvidia.com>
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: x22x22 <wadeking@qq.com>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Labels
ci/build kernel ready ONLY add when PR is ready to merge/full CI is needed