@yewentao256 yewentao256 commented Jul 16, 2025

Purpose

Add a CUDA kernel for per-token-group quantization, replacing the Triton implementation.

Adapted from SGLang (https://github.com/sgl-project/sglang/blob/570d33437bf0b4ac42e00ad468ddc43f9e0b376f/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu), thanks for the code!

Adds support for row-major E8M0 and float32 scale tensors, and optimizes the CUDA kernel using shared memory.
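
For reference, here is a rough pure-torch sketch of what the kernel computes (an illustration only, not the vLLM implementation; the E8M0 round-up direction and argument names are my assumptions):

```python
import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int,
                                  column_major_scales: bool = False,
                                  use_ue8m0: bool = False,
                                  eps: float = 1e-10):
    """Quantize each contiguous group of `group_size` elements along the last
    dim to FP8 (e4m3), producing one scale per group."""
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    m, n = x.shape
    xg = x.float().view(m, n // group_size, group_size)
    # Per-group absmax scale, kept as float32.
    scale = xg.abs().amax(dim=-1).clamp(min=eps) / fp8_max
    if use_ue8m0:
        # E8M0-style scales are pure powers of two; round up so nothing overflows.
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    x_q = (xg / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max).to(fp8).view(m, n)
    x_s = scale.t().contiguous() if column_major_scales else scale
    return x_q, x_s
```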

Test

B200

Accuracy

VLLM_USE_DEEP_GEMM=1 lm_eval   --model vllm   --model_args "pretrained=Qwen/Qwen3-30B-A3B-FP8,max_model_len=32768,enforce_eager=True"   --trust_remote_code   --tasks gsm8k   --num_fewshot 5   --batch_size auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.8469|±  |0.0099|
|     |       |strict-match    |     5|exact_match||0.8855|±  |0.0088|

# main
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.8476|±  |0.0099|
|     |       |strict-match    |     5|exact_match||0.8855|±  |0.0088|

Performance

VLLM_USE_DEEP_GEMM=1 vllm bench throughput --model Qwen/Qwen3-30B-A3B-FP8 --load-format dummy --input-len 1000 --output-len 100 --trust_remote_code --enable-expert-parallel
Throughput: 27.00 requests/s, 29643.76 total tokens/s, 2700.48 output tokens/s
# main
Throughput: 25.35 requests/s, 27828.75 total tokens/s, 2535.13 output tokens/s

DeepSeek-R1 (EP+DP):

export VLLM_USE_DEEP_GEMM=1 
export VLLM_ALL2ALL_BACKEND="deepep_high_throughput"
export VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1
vllm serve deepseek-ai/DeepSeek-R1        --load-format dummy        --trust-remote-code        --enforce-eager        --enable-expert-parallel        --data-parallel-size 8

vllm bench serve        --model deepseek-ai/DeepSeek-R1        --dataset-name random        --random-input-len 256        --random-output-len 100        --num-prompts 1000        --request-rate inf
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  38.77     
Total input tokens:                      254242    
Total generated tokens:                  100000    
Request throughput (req/s):              25.79     
Output token throughput (tok/s):         2579.29   
Total Token throughput (tok/s):          9136.92   
---------------Time to First Token----------------
Mean TTFT (ms):                          14384.65  
Median TTFT (ms):                        14924.79  
P99 TTFT (ms):                           22782.33  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          243.08    
Median TPOT (ms):                        237.89    
P99 TPOT (ms):                           362.43    
---------------Inter-token Latency----------------
Mean ITL (ms):                           243.09    
Median ITL (ms):                         161.80    
P99 ITL (ms):                            5376.92   
==================================================

# DeepGemm(main)
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  40.67     
Total input tokens:                      254242    
Total generated tokens:                  100000    
Request throughput (req/s):              24.59     
Output token throughput (tok/s):         2458.55   
Total Token throughput (tok/s):          8709.23   
---------------Time to First Token----------------
Mean TTFT (ms):                          14935.81  
Median TTFT (ms):                        15457.36  
P99 TTFT (ms):                           23764.34  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          256.56    
Median TPOT (ms):                        251.40    
P99 TPOT (ms):                           379.87    
---------------Inter-token Latency----------------
Mean ITL (ms):                           256.56    
Median ITL (ms):                         170.53    
P99 ITL (ms):                            5640.72   
==================================================

H100

# Now
Throughput: 23.88 requests/s, 26216.81 total tokens/s, 2388.29 output tokens/s

# Main
Throughput: 22.35 requests/s, 24537.01 total tokens/s, 2235.26 output tokens/s


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Jul 16, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new CUDA kernel for per-token-group FP8 quantization. The changes include the kernel implementation, its C++ and Python bindings, and integration into the quantization utilities. The new CUDA kernel is used when available, with a fallback to the Triton kernel.

There are a few high-severity issues that should be addressed:

  • The C++ function signature for the new op should align with existing conventions in the codebase.
  • The new CUDA kernel can be further optimized to reduce global memory bandwidth.
  • A check in the C++ code is overly restrictive and could lead to runtime errors for valid input shapes.

@yewentao256 yewentao256 marked this pull request as draft July 16, 2025 23:17
@yewentao256 yewentao256 marked this pull request as ready for review July 17, 2025 18:53
@yewentao256
Collaborator Author

CC @mgoin


@mgoin mgoin left a comment


I think we need a kernel unit test now to compare the CUDA kernel against the Triton/torch impl for the 4 cases we have to consider now (row float, row e8m0, col float, col e8m0).
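
A minimal sketch of such a test (the import path and the returned scale layout/dtype handling are my assumptions; the real unit test may instead compare directly against the Triton/torch implementations):

```python
import pytest
import torch


@pytest.mark.parametrize("column_major_scales", [False, True])  # row vs col
@pytest.mark.parametrize("use_ue8m0", [False, True])            # float32 vs e8m0 scales
def test_per_token_group_quant_fp8(column_major_scales, use_ue8m0):
    # Assumed import path for the function touched in this PR.
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        per_token_group_quant_fp8)

    group_size = 128
    x = torch.randn(32, 512, dtype=torch.bfloat16, device="cuda")
    q, s = per_token_group_quant_fp8(
        x, group_size, column_major_scales=column_major_scales,
        use_ue8m0=use_ue8m0)

    # Scales may come back transposed when column-major; normalize the layout.
    if s.shape != (x.shape[0], x.shape[1] // group_size):
        s = s.t()

    # Dequantize and check we recover the input within FP8 quantization error.
    dq = q.float().view(32, -1, group_size) * s.float().unsqueeze(-1)
    torch.testing.assert_close(dq.view_as(x), x.float(), rtol=0.2, atol=0.2)
```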

Comment on lines +123 to +124
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(output_q.is_contiguous());
Member

The contiguity requirement might be a problem for MLA, so please test a couple of DeepSeek evals/benchmarks.

Collaborator Author

You are right, so I chose to fall back to Triton when the input is not contiguous (see the sketch after the results below).
Now it works:

VLLM_USE_DEEP_GEMM=1 lm_eval   --model vllm   --model_args "pretrained=deepseek-ai/DeepSeek-R1,data_parallel_size=8,gpu_memory_utilization=0.95,max_model_len=16384,enable_expert_parallel=True"   --tasks gsm8k   --batch_size auto   --num_fewshot 5
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9560|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9545|±  |0.0057|
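
A minimal sketch of that guard (both callee names below are placeholders, not the exact vLLM identifiers):

```python
def _per_token_group_quant_fp8_dispatch(x, group_size, **kwargs):
    # The fused CUDA kernel requires a contiguous input; MLA activations may
    # not be contiguous, so those keep the existing Triton path.
    if x.is_cuda and x.is_contiguous():
        return _cuda_per_token_group_quant_fp8(x, group_size, **kwargs)   # placeholder for the new fused op
    return _triton_per_token_group_quant_fp8(x, group_size, **kwargs)     # placeholder for the Triton kernel
```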

@yewentao256 yewentao256 requested a review from WoosukKwon as a code owner July 18, 2025 20:53
@@ -366,6 +366,7 @@ def per_token_group_quant_fp8(
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
out_q: Optional[torch.Tensor] = None,
use_ue8m0: bool = is_blackwell_deep_gemm_used(),
Member

I worry about setting this as a default value, since this function could be used on Blackwell but with the CUTLASS or FlashInfer FP8 block kernels that are now on SM100.

Collaborator Author

is_blackwell_deep_gemm_used checks the env var VLLM_USE_DEEP_GEMM as well, so it won't cause trouble for now.
That said, this default isn't ideal either, since DeepGEMM now supports float32 scales on B200. I will open another PR for this, letting the user decide whether to use E8M0.

@mgoin mgoin added the ready and performance labels Jul 21, 2025
@mgoin mgoin added the deepseek label Jul 21, 2025
@mgoin mgoin enabled auto-merge (squash) July 22, 2025 02:28
@vllm-bot vllm-bot merged commit 774d0c0 into vllm-project:main Jul 22, 2025
98 of 100 checks passed
@yewentao256 yewentao256 deleted the wye/cuda-kernel-for-per-token-group-quant branch July 22, 2025 14:38
@fxmarty-amd
Contributor

fxmarty-amd commented Jul 22, 2025

I think this PR causes this build error on MI300:

FAILED: CMakeFiles/_C.dir/csrc/quantization/fp8/per_token_group_quant.hip.o
/opt/rocm/lib/llvm/bin/clang++  -DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1 -DPy_LIMITED_API=3 -DTORCH_EXTENSION_NAME=_C -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_PROF_API=1 -DUSE_RPC -DUSE_TENSORPIPE -D_C_EXPORTS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_AMD__=1 -D__HIP_ROCclr__=1 -I/shared_volume/repos/vllm/csrc -isystem /usr/include/python3.12 -isystem /usr/local/lib/python3.12/dist-packages/torch/include -isystem /usr/local/lib/python3.12/dist-packages/torch/include/torch/csrc/api/include -isystem /opt/rocm-6.4.1/include/hiprand -isystem /opt/rocm-6.4.1/include/rocrand -Wno-unused-result -O2 -g -DNDEBUG --offload-arch=gfx942 -fPIC -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -DHIP_ENABLE_WARP_SYNC_BUILTINS=1 -DUSE_ROCM -DENABLE_FP8 -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF_OPERATORS__ -Werror=unused-variable -fno-gpu-rdc -D_GLIBCXX_USE_CXX11_ABI=1 -DTORCH_HIP_VERSION=604 -Wno-shift-count-negative -Wno-shift-count-overflow -Wno-duplicate-decl-specifier -DCAFFE2_USE_MIOPEN -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP -std=c++17 -DHIPBLASLT_VEC_EXT -DHIP_ENABLE_WARP_SYNC_BUILTINS -MD -MT CMakeFiles/_C.dir/csrc/quantization/fp8/per_token_group_quant.hip.o -MF CMakeFiles/_C.dir/csrc/quantization/fp8/per_token_group_quant.hip.o.d -o CMakeFiles/_C.dir/csrc/quantization/fp8/per_token_group_quant.hip.o -x hip -c /shared_volume/repos/vllm/build/temp.linux-x86_64-cpython-312/csrc/quantization/fp8/per_token_group_quant.hip
In file included from /shared_volume/repos/vllm/build/temp.linux-x86_64-cpython-312/csrc/quantization/fp8/per_token_group_quant.hip:2:
In file included from /opt/rocm-6.4.1/lib/llvm/bin/../../../include/hip/hip_runtime.h:62:
In file included from /opt/rocm-6.4.1/lib/llvm/bin/../../../include/hip/amd_detail/amd_hip_runtime.h:115:
In file included from /opt/rocm-6.4.1/lib/llvm/bin/../../../include/hip/amd_detail/amd_hip_atomic.h:26:
In file included from /opt/rocm-6.4.1/lib/llvm/bin/../../../include/hip/amd_detail/amd_device_functions.h:340:
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/hip/amd_detail/amd_warp_sync_functions.h:274:52: error: static assertion failed due to requirement 'sizeof(unsigned int) == 8': The mask must be a 64-bit integer. Implicitly promoting a smaller integer is almost always an error.
  274 |       __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
      |                                                    ^~~~~~~~~~~~~~~~~~
/shared_volume/repos/vllm/build/temp.linux-x86_64-cpython-312/csrc/quantization/fp8/per_token_group_quant.hip:20:20: note: in instantiation of function template specialization '__shfl_xor_sync<unsigned int, float>' requested here
   20 |   val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
      |                    ^
/opt/rocm-6.4.1/lib/llvm/bin/../../../include/hip/amd_detail/amd_warp_sync_functions.h:274:66: note: expression evaluates to '4 == 8'
  274 |       __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
      |                                                    ~~~~~~~~~~~~~~^~~~
1 error generated when compiling for gfx942.

Checking out the previous commit 2c8db17 and cherry-picking 226b452, the build goes fine.

(or maybe something is wrong in my env cc @gshtras)

@gshtras
Collaborator

gshtras commented Jul 22, 2025

> I think this PR causes this build error on MI300: [full build log quoted above]

From the use of Float8_e4m3fn alone, it doesn't look like this file is supposed to be built on ROCm.
I believe the issue here is the way __shfl_xor_sync is called with a 4-byte mask value, which only works on CUDA.

@mgoin
Member

mgoin commented Jul 22, 2025

Hey @gshtras, thanks for reporting this, and apologies for the oversight. Can you check whether #21392 fixes your issue?

@j0hngou

j0hngou commented Jul 22, 2025

> [quotes the MI300 build error report and @gshtras's diagnosis above]

I was about to report the same thing: MI300A GPU, and the build worked before this PR.
