Add AWQ quantization inference support #1019
Conversation
Amazing work @abhinavkulkarni. The original creators of AWQ recently removed all support for the GEMM kernel. However, after extensive testing of the GEMV kernel, I believe the GEMM kernel is the best choice for the highest throughput (large context, large batch size), as the GEMV kernel struggles with performance on anything but empty context and batch size > 1. I am the author of AutoAWQ and will maintain backward compatibility with GEMM as the main packing method and GEMV as an optional packing method (about to be merged). You can either use the old commit or switch to AutoAWQ, which is easily installable from PyPI (`pip install autoawq`).
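For reference, a minimal sketch of loading a GEMM-packed AWQ checkpoint through AutoAWQ (not part of this PR; the exact `from_quantized` keyword names vary between AutoAWQ releases, so treat them as assumptions and check the version you install):

```python
# pip install autoawq
# Rough sketch only: API details differ between AutoAWQ versions.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq"  # model referenced in this thread

# fuse_layers toggles the fused Llama kernels discussed later in this thread.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
```

Whether a given checkpoint loads this way depends on its packing format (GEMM vs GEMV), so this is a starting point rather than a guarantee.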
Hey, thanks a LOT for this PR. Before diving into the details, I'll drop some benchmarks I did on a single A100 (charts for f16, AWQ, and GPTQ attached). These tell a very good story: baseline latency is 15 ms/token for AWQ, 25 ms/token for f16, and 20 ms/token for exllama. F16 seems to reach higher throughput, but that regime is not the main one for usual TGI deployments, where latency is the key. I will probably do a bit more testing on different TP sharding layouts and model sizes, but this looks really good.
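As a quick sanity check, converting the quoted per-token latencies into single-request decode throughput (pure arithmetic on the figures above, not a new measurement):

```python
# Latency figures quoted above, in milliseconds per generated token.
latencies_ms = {"awq": 15, "exllama (gptq)": 20, "f16": 25}

for name, ms in latencies_ms.items():
    print(f"{name}: {1000 / ms:.0f} tokens/s per request")
# awq: 67 tokens/s, exllama (gptq): 50 tokens/s, f16: 40 tokens/s
```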
@@ -73,3 +73,5 @@ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
# Custom 4-bit GEMM AWQ kernels
git+https://github.com/mit-han-lab/llm-awq.git@f084f40bd996f3cf3a0633c1ad7d9d476c318aaa#subdirectory=awq/kernels
Is there any way we could move to versioned releases? It works currently since it's pinned to a SHA on main, which is good enough for stability imo.
@Narsil: Not sure about licensing issues, but can we copy over the kernels to the TGI repo from the AWQ repo? That's all this `pip install` is pulling in. I saw somewhere that you have a `custom_kernels` directory in the repo.
Would recommend switching to `autoawq==0.1.0`, as this has support for Windows as well. The `mit-han-lab` kernels break on Windows without the modifications I made.
Thanks @casper-hansen.

@Narsil: Okay to have a dependency on `autoawq==0.1.0`? If yes, happy to switch from the `mit-han-lab/llm-awq` repo to `autoawq==0.1.0`.
Let's try to keep it as-is for now.

`autoawq` seems too big a dependency; most of the code is not relevant here, since we only need the kernels, not the modeling code. Having small dependencies is better. (If you release a standalone kernel-only version, I'd be happy to switch.)

Pulling is OK, but it seems `mit-han-lab` is pretty barebones, therefore I don't feel we need to duplicate it.

For the Windows support, as our main layer of distribution is Docker, it's not really that necessary for us to support Windows builds.

Also, looking at the fix PR (casper-hansen/AutoAWQ@e188472 — is that correct?), it seems the fix is just removing the unnecessary kernels.
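As a side note, a minimal sketch of how one might sanity-check that the pinned kernel dependency built correctly inside the image; it assumes the extension module is named `awq_inference_engine`, as built from the `awq/kernels` subdirectory at the pinned commit (an assumption worth verifying against that commit):

```python
# Minimal sanity check for the pinned llm-awq kernel build (assumed module name).
import torch  # load torch's CUDA runtime before the custom extension

import awq_inference_engine  # noqa: F401  -- fails here if the kernels did not build

print("AWQ 4-bit GEMM kernels importable; CUDA available:", torch.cuda.is_available())
```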
@@ -282,6 +314,20 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        elif quantize == "awq":
Same here
@Narsil: I'm keeping separate blocks for GPTQ and AWQ here as GPTQ has a lot of custom logic with regards to exllama, etc. Please resolve if this is okay with you.
Fair enough !
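To make the "separate blocks" point concrete, here is an illustrative view of the 7-tuple both branches produce (this NamedTuple does not exist in TGI; field names simply mirror the diff above). AWQ never has a `g_idx` and never uses exllama, which is why a separate branch with the same tuple shape keeps both paths readable:

```python
# Illustrative only: the shape of the weight tuple returned for quantized rows.
from typing import NamedTuple, Optional

import torch


class QuantizedRowWeights(NamedTuple):
    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: Optional[torch.Tensor]  # GPTQ activation-order indices; always None for AWQ
    bits: int                      # 4 for the AWQ checkpoints in this PR
    groupsize: int                 # e.g. 128 for the w4-g128 models referenced here
    use_exllama: bool              # exllama kernels are GPTQ-only, so False for AWQ
```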
Ok, my comments are mostly nits about code maintenance, but this PR looks great ! Thanks for it !
@Narsil @abhinavkulkarni Has this PR been tested to be compatible with the 300+ AWQ repositories that @TheBloke has pushed with AutoAWQ?
@TheBloke: Can we get some clarity on naming the quantization config for GPTQ and AWQ? You seem to name it `quant_config.json`, whereas the TGI code expects it to be named `quantize_config.json`. Is it possible for you to rename it to `quantize_config.json`?
I would advise against renaming to `quantize_config.json`; it would break the existing AutoAWQ releases. EDIT: It will also break all other repositories that have implemented AutoAWQ.
Yeah, that's because AutoGPTQ and AutoAWQ create that file with different names. AutoGPTQ creates it as quantize_config.json, but AutoAWQ creates it as quant_config.json. I'm happy to rename it across all of them, but it will require AutoAWQ updating to support that name (not quite sure why a different name was used by AutoAWQ).
Okay, @TheBloke, @casper-hansen and @Narsil, please agree on one standard. I'm happy to comply with whatever is decided. Two issues we need consensus on: the name of the quantization config file (`quant_config.json` vs `quantize_config.json`) and the parameter names inside it (e.g. `w_bit`/`q_group_size` vs `bits`/`group_size`).

Please let me know what you agree on and I'll modify the PR accordingly.
@abhinavkulkarni GPTQ doesn't use `quant_config.json`; AutoGPTQ writes `quantize_config.json`. IMHO the best option is for AutoAWQ to update to match. For reference, Transformers stores the GPTQ params under the `quantization_config` key in config.json, with names like `bits` and `group_size`.

Based on this, IMHO AutoAWQ should create a config that uses the same parameter names (`bits`, `group_size`) rather than `w_bit`/`q_group_size`.

Then that will be easily integrated into Transformers' config.json -> `quantization_config`. For backwards compatibility, AutoAWQ could also read from the old `quant_config.json` name and keys.
I know that Huggingface takes backward compatibility very seriously. Implementing this in AutoAWQ requires quite a bit of maintenance going forward, and it will break all PyPI releases of AutoAWQ (v0.0.1, v0.0.2, v0.1.0) because they will not be able to read any Huggingface repository that uses the new names. Additionally, the current `quant_config.json` name is already what the published AWQ repositories use.

I am not so sure that people will be happy with this change for the reasons that I just explained, but I would gladly accept PRs that refactor the current behavior and extract the logic into a QuantizeConfig class that implements the necessary logic for backward compatibility. However, going forward, we have to support both the old and any new format.
vLLM already allows for both filenames (`quant_config.json` and `quantize_config.json`), and both key names (`w_bit`/`bits` and `q_group_size`/`group_size`).
HF takes backwards compatibility seriously once something is in a Hugging Face project - so they definitely won't want it to change after AWQ is merged. Before it's merged, I don't think they'll care one way or the other. But they may well ask for the w_bits -> bits param change to be made in order for it to be merged, so I think you're going to end up doing at least that change eventually if you want it in Transformers.

So if this change is going to be made, I would say the time to do it is now - there's never a good time for a change that might break compatibility, but earlier is always far better than later.

If the config file name does change, I'm happy to have both files in my repos during the transition.
I agree with @TheBloke that it's better to make changes now than later. So, how about this: AutoAWQ moves to the standard file name and key names (`quantize_config.json` with `bits`/`group_size`) going forward, while readers like TGI and vLLM keep accepting the old `quant_config.json` with `w_bit`/`q_group_size` for backward compatibility.

Agreed?
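A minimal sketch of the compromise being proposed, i.e. a reader that accepts both file names and both key spellings (illustrative only; the helper name and exact fallbacks are assumptions, not TGI's or AutoAWQ's actual code):

```python
import json
import os
from typing import Tuple


def read_awq_quant_config(model_dir: str) -> Tuple[int, int]:
    """Return (bits, group_size), accepting both old and new file/key names."""
    for filename in ("quantize_config.json", "quant_config.json"):
        path = os.path.join(model_dir, filename)
        if not os.path.exists(path):
            continue
        with open(path) as f:
            cfg = json.load(f)
        bits = cfg.get("bits", cfg.get("w_bit"))
        group_size = cfg.get("group_size", cfg.get("q_group_size"))
        if bits is None or group_size is None:
            raise ValueError(f"{filename} is missing bits/group_size information")
        return int(bits), int(group_size)
    raise FileNotFoundError("No quantize_config.json or quant_config.json found")
```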
Ok, only nits are left. I will merge this into `main`. Thank you very much for this contribution !
Totally agree with all the names/config discussion, which is why I suggested reusing all the code we could for AWQ to match GPTQ. Also, in this repo, unlike in `transformers`, …
try:
    g_idx = self.get_tensor(f"{prefix}.g_idx")
except RuntimeError:
    g_idx = None
Can you switch to an `if` statement? It needs to fail if this is missing for GPTQ (which it wouldn't here).
I have pushed a new commit that switches try-catch to if-else. Thanks.
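For illustration, the general shape of the `if`-based version being asked for, as a sketch under the assumption that the branch keys off the `quantize` mode; `get_tensor` stands in for the Weights method, and this is not the PR's exact code:

```python
from typing import Callable, Optional

import torch


def load_g_idx(
    quantize: str,
    prefix: str,
    get_tensor: Callable[[str], torch.Tensor],
) -> Optional[torch.Tensor]:
    if quantize == "gptq":
        # Must fail loudly if a GPTQ checkpoint is missing g_idx,
        # which the try/except version silently swallowed.
        return get_tensor(f"{prefix}.g_idx")
    # AWQ (and unquantized) checkpoints carry no g_idx.
    return None
```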
try:
    w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
    for w2 in w[1:]:
        torch.testing.assert_close(w2, w[0])
    g_idx = w[0]
except RuntimeError:
    g_idx = None
Same here.
@@ -216,7 +233,7 @@ def get_tensor_shard(self, var, dim):
        return tensor

    def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize == "gptq":
+        if quantize in "gptq":
Suggested change:
-        if quantize in "gptq":
+        if quantize == "gptq":
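For anyone skimming, the reason the suggestion matters: `in` on a string is a substring test, not an equality or membership test, so it quietly accepts unintended values:

```python
print("pt" in "gptq")            # True  -- substring match, almost certainly not intended
print("" in "gptq")              # True  -- the empty string is a substring of everything
print("pt" == "gptq")            # False -- the comparison the code actually wants
# If the intent is "either gptq or awq", list membership is the idiomatic form:
print("awq" in ["gptq", "awq"])  # True
```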
Re the fix PR (casper-hansen/AutoAWQ@e188472) mentioned above: that part is not Windows support for the GEMM kernel; those are FasterTransformer kernels. I added Windows support here. It's super simple: MSVC does not support …
# Add AWQ quantization inference support

Fixes #781

This PR (partially) adds support for AWQ quantization for inference. More information on AWQ [here](https://arxiv.org/abs/2306.00978). In general, AWQ is faster and more accurate than GPTQ, which is currently supported by TGI.

This PR installs 4-bit GEMM custom CUDA kernels released by the AWQ authors (in `requirements.txt`, just a one-line change).

A quick way to test this PR is to bring up TGI as follows:

```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq

text-generation-launcher \
    --huggingface-hub-cache ~/.cache/huggingface/hub/ \
    --model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
    --trust-remote-code --port 8080 \
    --max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
    --quantize awq
```

Please note:

* This PR was tested with FlashAttention v2 and vLLM.
* This PR adds support for AWQ inference, not quantizing the models. That needs to be done outside of TGI, instructions [here](https://github.com/mit-han-lab/llm-awq/tree/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa).
* This PR only adds support for `FlashLlama` models for now.
* Multi-GPU setup has not been tested.
* No integration tests have been added so far; will add later if maintainers are interested in this change.
* This PR can be tested on any of the models released [here](https://huggingface.co/abhinavkulkarni?sort_models=downloads#models).

Please refer to the linked issue for benchmarks for [abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq](https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq) vs [TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).

Please note, AWQ has released faster (and in the case of Llama, fused) kernels for 4-bit GEMM, currently at the top of the `main` branch at https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit that has been tested to work. We can switch to the latest commit later on.

## Who can review?

@OlivierDehaene OR @Narsil

Co-authored-by: Abhinav M Kulkarni <abhinavkulkarni@gmail.com>
Co-authored-by: Abhinav Kulkarni <abhinav@concentric.ai>
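Once the launcher above is up, a quick way to exercise the AWQ-quantized server is a plain HTTP call to TGI's `/generate` endpoint (sketch assumes the server is listening on localhost:8080, matching the `--port` flag above):

```python
import requests

resp = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "def fibonacci(n):",
        "parameters": {"max_new_tokens": 64},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```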
Oh nice, thanks for the tip !
There is only a 10% difference in inference speed between the quantized Llama 2 7B model and the base model on a single A100 GPU. In the AutoAWQ inference example, there is a `fuse_layer` option that can significantly affect speed depending on whether it is set to true or false. I am curious whether this option is always set to true in TGI.
Is flash attention v2 installed and on?
I tested using the tgi-1.1.0 Docker image, which has Flash Attention v2 installed.
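For anyone reproducing this, a small hedged check for whether Flash Attention v2 is actually importable inside the container (run it with the image's Python, e.g. via `docker exec`):

```python
try:
    import flash_attn

    print("flash-attn version:", flash_attn.__version__)  # a 2.x version means Flash Attention v2
except ImportError:
    print("flash-attn is not installed in this environment")
```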