[AMD] [Quantization] Add override flag for attention dtype instead of using kv_cache_dtype trigger #17331
Conversation
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge. 🚀
I think generally we prefer CLI flags and config over environment variables. Am I missing any context for why this PR is needed?
Is this just for output scaling? So it decouples that from the kvcache dtype? Either way, could you add more details to the description of the PR?
I think it's so any FP8 model will work instead of just those with FP8 KV cache.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
@ProExpertProg Please take another look, also had to fix
vllm/config.py (Outdated)
@@ -4363,7 +4367,8 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
             " if you want it to be supported.",
             vllm_config.model_config.model)
     finally:
-        _current_vllm_config = old_vllm_config
+        if was_raised:
+            _current_vllm_config = old_vllm_config
Why is this necessary? Shouldn't we always restore the current config?
Because _current_vllm_config is always getting overwritten with old_vllm_config when set_current_vllm_config is called, whether there was an exception or not.
Yeah I understand that. But why would we not want to restore old config if there was no exception?
Ohhh I see now, I think you might be using this incorrectly. set_vllm_config is meant to be used as a context manager:
with set_vllm_config(...):
...
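For context, here is a minimal sketch of the context-manager pattern being described, using simplified stand-ins rather than vLLM's actual implementation. Because the finally block restores the previous value unconditionally, the was_raised guard in the diff above should not be needed:

```python
from contextlib import contextmanager

_current_vllm_config = None  # module-level "current config" slot


@contextmanager
def set_current_vllm_config(vllm_config):
    global _current_vllm_config
    old_vllm_config = _current_vllm_config
    try:
        _current_vllm_config = vllm_config
        yield
    finally:
        # Restore unconditionally so a temporary config never leaks out of
        # the `with` block, whether or not an exception was raised.
        _current_vllm_config = old_vllm_config


# Intended usage: the override is only visible inside the block.
with set_current_vllm_config({"example": True}):
    pass  # code here sees the temporary config
```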
@@ -766,9 +767,15 @@ def forward(
                 query.dtype,
                 seq_lens,
                 make_attn_mask=causal_mask)  # type: ignore

+        vllm_config = get_current_vllm_config()
We shouldn't be reading config in the forward method. Instead it should be read during init
rocm_flash_attn doesn't seem to have any other access to the VllmConfig object. Is there another way for it to get access to the value it needs?
Sorry, I meant the backend/impl's __init__ (mentioned in the meeting).
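A rough sketch of what this asks for, with illustrative class and field names (the actual backend class and the final flag name may differ): the config is read once in __init__, while the model is being built under set_current_vllm_config, and forward() only uses the cached value.

```python
from vllm.config import get_current_vllm_config


class ROCmAttentionImplSketch:
    def __init__(self):
        # Construction happens inside `with set_current_vllm_config(...)`,
        # so the global lookup is valid here and runs exactly once.
        self.vllm_config = get_current_vllm_config()

    def forward(self, query, key, value):
        # No config lookup on the hot path; only the cached value is used.
        dtype_override = self.vllm_config.model_config.override_attention_dtype
        ...
```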
vllm/config.py (Outdated)
@@ -397,6 +397,8 @@ class ModelConfig:
     available.\n
     - "vllm" will use the vLLM model implementation.\n
     - "transformers" will use the Transformers model implementation."""
+    use_fp8_scales: bool = True
I think this needs a better name. One idea is override-attention-dtype, and then it's specified as fp8 on the CLI/in the config.
What do you mean by "and then it's specified as fp8 on the CLI/in the config"?
It's a string property that specifies the datatype (so not limited to fp8) - explained in the meeting.
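A hypothetical usage sketch of the string-valued flag through the offline Python API; the exact plumbing from engine arguments into ModelConfig is an assumption here, and the model name is just a placeholder.

```python
from vllm import LLM

llm = LLM(
    model="<some-fp8-quantized-model>",  # placeholder, not a real checkpoint
    override_attention_dtype="fp8",      # string-valued, so other dtypes could be supported later
)
```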
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
@ProExpertProg Please take another look, I was able to remove the changes to set_current_vllm_config after adding the call to get_current_vllm_config in init. I renamed the use_fp8_scales flag in favor of the override flag.
This looks great and is much cleaner. My only remaining concern is that we should really warn the user if the flag is ignored. If somebody specifies --override-attention-dtype=fp8 on NVIDIA or when not using the ROCmFlash backend, we should print a warning saying the flag is not actually doing anything.
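One way such a warning could be wired up, sketched as a standalone helper; the helper name and where it would be called from are assumptions, not part of this PR.

```python
import logging

logger = logging.getLogger(__name__)


def warn_if_override_ignored(override_attention_dtype, using_rocm_flash_backend):
    """Warn when --override-attention-dtype is set but the selected attention
    backend will never read it (e.g. on NVIDIA or a non-ROCmFlash backend)."""
    if override_attention_dtype is not None and not using_rocm_flash_backend:
        logger.warning(
            "override_attention_dtype=%s was specified but the current "
            "attention backend does not use it; the flag has no effect.",
            override_attention_dtype)
```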
vllm/config.py (Outdated)
@@ -407,6 +407,8 @@ class ModelConfig:
     available.\n
     - "vllm" will use the vLLM model implementation.\n
     - "transformers" will use the Transformers model implementation."""
+    override_attention_dtype: str = "fp8"
This should be None by default:

-    override_attention_dtype: str = "fp8"
+    override_attention_dtype: Optional[str] = None
@@ -580,6 +581,7 @@ def __init__(
             logger.debug("Using naive (SDPA) attention in ROCmBackend")

         self.aiter_kv_scales_initialized = False
+        self.vllm_config = get_current_vllm_config()
No need to save the whole config, just do self.force_fp8_attention = ...
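The suggestion, sketched as a small helper that could be called from __init__; the config field name follows the discussion above, and the merged code may differ.

```python
from vllm.config import get_current_vllm_config


def should_force_fp8_attention() -> bool:
    """Derive the flag once at init time; store the result on the impl as
    self.force_fp8_attention instead of keeping the whole VllmConfig."""
    model_config = get_current_vllm_config().model_config
    return (model_config is not None
            and model_config.override_attention_dtype == "fp8")
```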
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
         use_fp8_scales = (layer._q_scale and layer._k_scale
                           and layer._v_scale and layer._prob_scale
-                          and self.kv_cache_dtype == "fp8")
+                          and self.force_fp8_attention)
Should we check here if the KV cache is in fp8 already?
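One possible answer to that question, sketched as a helper that keeps the existing KV-cache check and folds in the new override; whether the merged code combines the two conditions this way is an assumption.

```python
def should_use_fp8_scales(layer, kv_cache_dtype: str,
                          force_fp8_attention: bool) -> bool:
    # All per-layer scales must be present, and fp8 must be requested either
    # via an fp8 KV cache or via the override flag.
    return bool(layer._q_scale and layer._k_scale
                and layer._v_scale and layer._prob_scale
                and (kv_cache_dtype == "fp8" or force_fp8_attention))
```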
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
LGTM! Sorry for the delay
This adds an override flag for the attention dtype in VllmConfig instead of using the kv_cache_dtype flag as the trigger, so any FP8 model will work, not just those with an FP8 KV cache.