Conversation

rasmith
Contributor

@rasmith rasmith commented Apr 28, 2025

This adds a flag to override the attention dtype in VllmConfig instead of keying off the kv_cache_dtype flag, so any FP8 model will work rather than only those with an FP8 KV cache.

Signed-off-by: Randall Smith <Randall.Smith@amd.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
@ProExpertProg
Collaborator

I think generally we prefer CLI flags and config over environment variables. Am I missing any context for why this PR is needed?

@ProExpertProg
Collaborator

Is this just for output scaling? So it decouples that from the kvcache dtype? Either way, could you add more details to the description of the PR?

@rasmith
Contributor Author

rasmith commented May 1, 2025

> Is this just for output scaling? So it decouples that from the kvcache dtype? Either way, could you add more details to the description of the PR?

I think it's so any FP8 model will work instead of just those with FP8 KV cache.


mergify bot commented May 8, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @rasmith.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 8, 2025
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
@mergify mergify bot removed the needs-rebase label May 8, 2025
@rasmith rasmith changed the title [AMD] [Quantization] Add VLLM_ROCM_USE_FP8_SCALES flag [AMD] [Quantization] Add flag for using fp8 scales instead of using kv_cache_dtype trigger May 8, 2025
@rasmith
Contributor Author

rasmith commented May 8, 2025

> Is this just for output scaling? So it decouples that from the kvcache dtype? Either way, could you add more details to the description of the PR?

@ProExpertProg Please take another look. I also had to fix set_current_vllm_config; please see the description.

vllm/config.py Outdated
@@ -4363,7 +4367,8 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
                 " if you want it to be supported.",
                 vllm_config.model_config.model)
     finally:
-        _current_vllm_config = old_vllm_config
+        if was_raised:
+            _current_vllm_config = old_vllm_config
Collaborator

Why is this necessary? Shouldn't we always restore the current config?

Contributor Author

@rasmith rasmith May 9, 2025

Because _current_vllm_config is always getting overwritten with old_vllm_config when set_current_vllm_config is called, whether there was an exception or not.

Collaborator

Yeah I understand that. But why would we not want to restore old config if there was no exception?

Collaborator

Ohhh I see now, I think you might be using this incorrectly. set_current_vllm_config is meant to be used as a context manager:

   with set_current_vllm_config(...):
      ...
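
For reference, the restore-in-finally pattern being discussed looks roughly like the sketch below. This is a simplified illustration of a context manager that swaps in a module-level config and restores the previous one on exit; it is not the exact vLLM implementation.

    from contextlib import contextmanager

    _current_vllm_config = None

    @contextmanager
    def set_current_vllm_config(vllm_config):
        # Simplified sketch: install vllm_config for the duration of the
        # `with` block and always restore the previous value afterwards,
        # so nested blocks and the error path both unwind correctly.
        global _current_vllm_config
        old_vllm_config = _current_vllm_config
        try:
            _current_vllm_config = vllm_config
            yield vllm_config
        finally:
            # Unconditional restore: guarding this behind `if was_raised`
            # would leave a stale config installed after a normal exit.
            _current_vllm_config = old_vllm_config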

@@ -766,9 +767,15 @@ def forward(
                     query.dtype,
                     seq_lens,
                     make_attn_mask=causal_mask) # type: ignore
 
+        vllm_config = get_current_vllm_config()
Collaborator

We shouldn't be reading the config in the forward method. Instead, it should be read during __init__.

Contributor Author

rocm_flash_attn doesn't seem to have any other access to the VllmConfig object. Is there another way for it to get access to the value it needs?

Collaborator

Sorry I meant the backend/impl's __init__ (mentioned in the meeting)

vllm/config.py Outdated
@@ -397,6 +397,8 @@ class ModelConfig:
     available.\n
     - "vllm" will use the vLLM model implementation.\n
     - "transformers" will use the Transformers model implementation."""
+    use_fp8_scales: bool = True
Collaborator

I think this needs a better name. One idea is override-attention-dtype, which would then be specified as fp8 on the CLI/in the config.

Contributor Author

What do you mean by "and then it's specified as fp8 on the CLI/in the config"?

Collaborator

It's a string property that specifies the datatype (so not limited to fp8) - explained in the meeting
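
As a rough sketch of that shape (the field name and None default follow the later revisions in this thread; the kebab-case CLI spelling is an assumption based on how vLLM usually maps config fields to engine arguments):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ModelConfig:  # trimmed stand-in for the real vLLM ModelConfig
        # A string dtype rather than a boolean, so values other than "fp8"
        # could be supported later; None means "do not override".
        override_attention_dtype: Optional[str] = None

On the command line this would then be spelled as --override-attention-dtype fp8.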

rasmith added 5 commits May 20, 2025 15:15
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
@rasmith
Contributor Author

rasmith commented May 22, 2025

@ProExpertProg Please take another look. I was able to remove the changes to set_current_vllm_config after adding the call to get_current_vllm_config in __init__, and I replaced the use_fp8_scales flag with the override_attention_dtype flag.

@rasmith rasmith changed the title [AMD] [Quantization] Add flag for using fp8 scales instead of using kv_cache_dtype trigger [AMD] [Quantization] Add override flag for attention dtype instead of using kv_cache_dtype trigger May 22, 2025
Collaborator

@ProExpertProg ProExpertProg left a comment

This looks great and is much cleaner. My only remaining concern is that we should really warn the user if the flag is ignored. If somebody specifies --override-attention-dtype=fp8 on NVIDIA, or when not using the ROCmFlash backend, we should print a warning saying the flag is not actually doing anything.
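
A minimal sketch of such a warning; the helper name and call site are hypothetical, and only the intent (warn when the flag cannot take effect) comes from this review comment:

    import logging

    logger = logging.getLogger(__name__)

    def warn_if_override_ignored(override_attention_dtype, is_rocm_flash_backend):
        # Hypothetical helper: if the user set the override but the selected
        # attention backend will not honor it, say so instead of silently
        # ignoring the flag.
        if override_attention_dtype is not None and not is_rocm_flash_backend:
            logger.warning(
                "override_attention_dtype=%s is set, but the ROCm flash "
                "attention backend is not in use; the flag has no effect.",
                override_attention_dtype)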

vllm/config.py Outdated
@@ -407,6 +407,8 @@ class ModelConfig:
     available.\n
     - "vllm" will use the vLLM model implementation.\n
     - "transformers" will use the Transformers model implementation."""
+    override_attention_dtype: str = "fp8"
Collaborator

This should be None by default:

Suggested change
-    override_attention_dtype: str = "fp8"
+    override_attention_dtype: Optional[str] = None

@@ -580,6 +581,7 @@ def __init__(
             logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
         self.aiter_kv_scales_initialized = False
+        self.vllm_config = get_current_vllm_config()
Collaborator

No need to save the whole config, just do self.force_fp8_attention = ...
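
A minimal sketch of that suggestion, assuming the class below as a trimmed stand-in for the ROCm flash-attention impl and the force_fp8_attention name used later in this thread:

    from vllm.config import get_current_vllm_config

    class ROCmFlashAttentionImpl:  # trimmed stand-in for the real backend class
        def __init__(self):
            # Read the override once at construction time and keep only the
            # derived boolean; the forward path then checks a plain attribute
            # instead of fetching the config on every call.
            vllm_config = get_current_vllm_config()
            self.force_fp8_attention = (
                vllm_config.model_config.override_attention_dtype == "fp8")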

rasmith added 3 commits May 22, 2025 21:52
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
 use_fp8_scales = (layer._q_scale and layer._k_scale
                   and layer._v_scale and layer._prob_scale
-                  and self.kv_cache_dtype == "fp8")
+                  and self.force_fp8_attention)
Collaborator

Should we check here if the KV cache is in fp8 already?
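
One way to read that question is to keep the original fp8-KV-cache trigger and OR it with the new flag. A standalone sketch of that combination (extracted into a helper for illustration; not necessarily the exact condition that was merged):

    def should_use_fp8_scales(q_scale, k_scale, v_scale, prob_scale,
                              kv_cache_dtype, force_fp8_attention):
        # Use fp8 scales only if all per-layer scales are present and either
        # the KV cache is already fp8 or the override forces fp8 attention.
        return bool(q_scale and k_scale and v_scale and prob_scale
                    and (kv_cache_dtype == "fp8" or force_fp8_attention))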

rasmith added 3 commits June 3, 2025 10:24
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Collaborator

@ProExpertProg ProExpertProg left a comment

LGTM! Sorry for the delay

@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 10, 2025
@mergify mergify bot added the rocm Related to AMD ROCm label Jun 11, 2025
@gshtras gshtras merged commit c7ea0b5 into vllm-project:main Jun 11, 2025
78 checks passed
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
… using kv_cache_dtype trigger (vllm-project#17331)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jun 30, 2025
… using kv_cache_dtype trigger (vllm-project#17331)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
wseaton pushed a commit to wseaton/vllm that referenced this pull request Jun 30, 2025
… using kv_cache_dtype trigger (vllm-project#17331)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
… using kv_cache_dtype trigger (vllm-project#17331)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
… using kv_cache_dtype trigger (vllm-project#17331)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>