gshtras (Collaborator) commented Jul 1, 2025

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
gshtras added the rocm label (Related to AMD ROCm) on Jul 1, 2025

github-actions bot commented Jul 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

gemini-code-assist bot (Contributor) left a comment


Summary of Changes

Hello @gshtras, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request adapts the codebase to the upcoming ROCm 7.0 release. It changes how the WARP_SIZE constant is determined: instead of relying on the warpSize variable, which ROCm 7.0 no longer provides as a compile-time constant, it uses a compile-time macro that selects the warp size from the target GPU architecture. This keeps the kernels compiling correctly under the new ROCm environment.

Highlights

  • ROCm 7.0 Compatibility: This pull request addresses a breaking change in ROCm 7.0: the warpSize variable is no longer a constexpr, so it can no longer be used where a compile-time constant is required, such as template arguments, array bounds, or static_assert conditions.
  • Updated WARP_SIZE Definition: To keep WARP_SIZE a compile-time constant, the macro now expands to 64 when compiling for ROCm on GFX9 architectures (e.g., AMD Instinct MI100/MI250) and to 32 on all other platforms, including NVIDIA CUDA and other ROCm targets.
  • Consistent Macro Application: The revised WARP_SIZE definition has been applied consistently across multiple core files: csrc/attention/attention_kernels.cuh, csrc/attention/paged_attention_v1.cu, csrc/attention/paged_attention_v2.cu, and csrc/cuda_compat.h. This ensures uniform behavior and compatibility across the codebase.
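A minimal sketch of the compile-time selection described above, written as plain C++ so it builds without a GPU toolchain. It assumes, as in the PR, that USE_ROCM is defined by the build and __GFX9__ by the device compiler; WarpScratch and num_warps are hypothetical names used only to show positions where a constant expression is required, which is exactly where a non-constexpr warpSize fails.

```cpp
#include <array>

// Compile-time warp size selection, modeled on the logic described in this
// PR. USE_ROCM (build system) and __GFX9__ (device compiler) are assumed
// inputs; when this file is built as ordinary host C++ neither is defined,
// so WARP_SIZE falls back to 32.
#if defined(USE_ROCM) && defined(__GFX9__)
  #define WARP_SIZE 64  // GFX9, e.g. MI100/MI250: 64-wide wavefronts
#else
  #define WARP_SIZE 32  // CUDA and other ROCm targets
#endif

// WARP_SIZE expands to an integer literal, so it is usable wherever a
// constant expression is required. A non-constexpr variable such as
// ROCm 7.0's warpSize would be rejected in each of the positions below.
static_assert(WARP_SIZE == 32 || WARP_SIZE == 64, "unexpected warp size");

// Hypothetical per-warp scratch buffer: the warp size is used both as a
// template argument and as an array bound.
template <typename T, int kWarpSize>
struct WarpScratch {
  std::array<T, kWarpSize> lanes{};
  static constexpr int size() { return kWarpSize; }
};

using FloatWarpScratch = WarpScratch<float, WARP_SIZE>;

// Hypothetical helper: warps needed to cover num_threads threads,
// evaluated at compile time when the argument is a constant.
constexpr int num_warps(int num_threads) {
  return (num_threads + WARP_SIZE - 1) / WARP_SIZE;
}
```

Because each device compilation pass sees only its own architecture macros, each device code object gets the width of its own target.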

gemini-code-assist bot (Contributor) left a comment


Code Review

This PR addresses the breaking change in ROCm 7.0 related to warpSize. The changes look correct, but I recommend centralizing the WARP_SIZE definition to avoid code duplication.

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
gshtras (Collaborator, Author) commented Jul 1, 2025

/gemini review

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request addresses a breaking change in ROCm 7.0 where warpSize is no longer a constexpr. The approach is to centralize the WARP_SIZE definition in csrc/cuda_compat.h and use preprocessor directives to set a compile-time constant.

The refactoring to a central header is a good improvement. However, the implementation for determining WARP_SIZE on ROCm is not fully correct. It handles GFX9 GPUs but makes a potentially incorrect assumption for newer GFX10+ architectures, which can lead to bugs if the code is compiled with a 64-wide wavefront. I've provided a critical review comment with a suggested fix that aligns with AMD's recommendation to use the __AMDGPU_WAVEFRONT_SIZE__ macro for a more robust solution.
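The review's concern is that GFX10 and newer GPUs can also run 64-wide wavefronts (for example when built with Clang's -mwavefrontsize64 option), in which case a fixed 32 would be wrong. A minimal sketch of the suggested fallback chain follows. The macro name __AMDGPU_WAVEFRONT_SIZE__ is copied from the review comment and is not verified here; confirm it against the predefined macros of your HIP/Clang version before relying on it. lane_id and warp_id are hypothetical helpers.

```cpp
// Prefer the compiler-reported wavefront size when it is available, then
// fall back to the GFX9 check used in this PR, then to 32.
// NOTE: __AMDGPU_WAVEFRONT_SIZE__ is the name given in the review comment;
// it is an assumption here and should be checked against your compiler.
#if defined(USE_ROCM) && defined(__AMDGPU_WAVEFRONT_SIZE__)
  #define WARP_SIZE __AMDGPU_WAVEFRONT_SIZE__
#elif defined(USE_ROCM) && defined(__GFX9__)
  #define WARP_SIZE 64
#else
  #define WARP_SIZE 32
#endif

// Hypothetical helpers: position of a flat thread index within its warp,
// and the index of the warp it belongs to.
constexpr int lane_id(int thread_idx) { return thread_idx % WARP_SIZE; }
constexpr int warp_id(int thread_idx) { return thread_idx / WARP_SIZE; }
```

On a plain host build none of the ROCm macros are defined, so the chain ends at 32.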

gshtras added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Jul 8, 2025
SageMoore (Contributor) left a comment


Looks good, @gshtras

-#ifndef USE_ROCM
-#define WARP_SIZE 32
+#if defined(USE_ROCM) && defined(__GFX9__)
+#define WARP_SIZE 64
Contributor

Is warpSize from the previous implementation 64?

Collaborator Author

Yes

gshtras merged commit ed10f3c into vllm-project:main on Jul 15, 2025
106 checks passed
gshtras deleted the deprecate_warpsize branch on July 15, 2025 18:01
hj-mistral pushed a commit to hj-mistral/vllm that referenced this pull request Jul 19, 2025
…#20330)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Himanshu Jaju <hj@mistral.ai>
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
…#20330)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
gshtras added a commit to ROCm/vllm that referenced this pull request Jul 31, 2025
Commits included:

Using cuda_compat to define the WARP_SIZE once

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

On ROCm, the compile-time constant warp size cannot be used on the host side, since host code can be shared across multiple architectures with different values

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

Formatting

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

Refactor to use cuda_compat, and not the unhippified version

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

Leaving CUDA side as just a simple define

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
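The host-side restriction in the commit message above can be illustrated with a small sketch. In a HIP build, host code is compiled in a single pass shared by every offload target, so device-only macros such as __GFX9__ are not defined there and a WARP_SIZE macro would evaluate to 32 even when the kernel runs with 64-wide wavefronts. Host code therefore has to use the warp size reported at runtime for the device it launches on, for example the warpSize field of hipDeviceProp_t (or cudaDeviceProp on CUDA). The runtime query is omitted so the sketch compiles as plain C++; warps_per_block and reduction_smem_bytes are hypothetical helpers.

```cpp
#include <cstddef>

// Host-side helpers that take the warp size as a runtime argument instead
// of using the WARP_SIZE macro. runtime_warp_size stands in for the value a
// real caller would read from the device properties (assumption: queried
// once per device before launching).

// Warps needed to cover num_threads threads.
constexpr int warps_per_block(int num_threads, int runtime_warp_size) {
  return (num_threads + runtime_warp_size - 1) / runtime_warp_size;
}

// Bytes of shared memory for one float partial result per warp, as a
// block-level reduction might need. For example, 256 threads need 8
// partials on a 32-wide device but only 4 on a 64-wide one.
constexpr std::size_t reduction_smem_bytes(int num_threads,
                                           int runtime_warp_size) {
  return static_cast<std::size_t>(
             warps_per_block(num_threads, runtime_warp_size)) *
         sizeof(float);
}
```

Hard-coding the host-pass value of WARP_SIZE here would under-allocate on 32-wide devices or over-allocate on 64-wide ones depending on which way it guessed, which is why the commit keeps the compile-time constant out of host code.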
gshtras added a commit to ROCm/vllm that referenced this pull request Jul 31, 2025
Commits included:

Using cuda_compat to define the WARP_SIZE once

On ROCm, the compile-time constant warp size cannot be used on the host side, since host code can be shared across multiple architectures with different values

Formatting

Refactor to use cuda_compat, and not the unhippified version

Leaving CUDA side as just a simple define

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
…#20330)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
…#20330)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: x22x22 <wadeking@qq.com>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
…#20330)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
…#20330)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
…#20330)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
…#20330)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
…#20330)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
…#20330)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
…#20330)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Labels
ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm)
2 participants