
Conversation

@abhinavkulkarni (Contributor) commented Sep 13, 2023

Add AWQ quantization inference support

Fixes #781

This PR (partially) adds support for AWQ quantization for inference. More information on AWQ is available at https://arxiv.org/abs/2306.00978. In general, AWQ is faster and more accurate than GPTQ, which is currently supported by TGI.

This PR installs the 4-bit GEMM custom CUDA kernels released by the AWQ authors (in requirements.txt, just a one-line change).

A quick way to test this PR is to bring up TGI as follows:

```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq

text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
--trust-remote-code --port 8080 \
--max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
--quantize awq
```
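
Once the server is up, a quick smoke test against TGI's standard /generate endpoint (the prompt and parameters below are arbitrary examples):

```
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"def fibonacci(n):","parameters":{"max_new_tokens":64}}' \
    -H 'Content-Type: application/json'
```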

Please note:

  • This PR was tested with FlashAttention v2 and vLLM.
  • This PR adds support for AWQ inference, not for quantizing models. That needs to be done outside of TGI; instructions are at https://github.com/mit-han-lab/llm-awq/tree/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa.
  • This PR only adds support for FlashLlama models for now.
  • Multi-GPU setup has not been tested.
  • No integration tests have been added so far; I will add them later if maintainers are interested in this change.
  • This PR can be tested on any of the models released at https://huggingface.co/abhinavkulkarni?sort_models=downloads#models.

Please refer to the linked issue for benchmarks of abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq (https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq) vs TheBloke/Llama-2-7b-Chat-GPTQ (https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).

Please note, AWQ has released faster (and, in the case of Llama, fused) kernels for 4-bit GEMM, currently at the top of the main branch at https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit that has been tested to work. We can switch to the latest commit later on.

Who can review?

@OlivierDehaene OR @Narsil

@casper-hansen

Amazing work @abhinavkulkarni. The original creators of AWQ recently removed all support for the GEMM kernel. However, after extensive testing of the GEMV kernel, I believe the GEMM kernel is the best choice for the highest throughput (large context, large batch size), as the GEMV kernel struggles with performance on anything but an empty context, and at batch size > 1.

I am the author of AutoAWQ and will maintain backward compatibility with GEMM as the main packing method and GEMV as an optional packing method (about to be merged). You can either use the old commit or switch to AutoAWQ, which has a PyPI package that is easily installable:

pip install autoawq

@Narsil (Collaborator) commented Sep 22, 2023

Hey thanks a LOT for this PR.

Before diving into the details, I'll drop some benchmarks I did on a single A100:

f16: [benchmark screenshot, 2023-09-22 18:47]

awq: [benchmark screenshot, 2023-09-22 18:45]

gptq: [benchmark screenshot, 2023-09-22 18:49]

These tell a very good story.

Baseline latency is 15 ms/token for AWQ, 25 ms/token for f16, and 20 ms/token for exllama (GPTQ).
But more importantly, the baseline holds much better for AWQ at higher throughput, giving more throughput at constant latency.

F16 seems to reach a higher peak throughput, but that regime is not the main one for usual TGI deployments, where latency is key.

I will probably do a bit more testing on different TP sharding layouts and model sizes, but this looks really good.

@@ -73,3 +73,5 @@ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
# Custom 4-bit GEMM AWQ kernels
git+https://github.com/mit-han-lab/llm-awq.git@f084f40bd996f3cf3a0633c1ad7d9d476c318aaa#subdirectory=awq/kernels
@Narsil (Collaborator):

Is there any way we could move to versioned releases? It works currently since it's pinned to a SHA on main, which is good enough for stability IMO.

@abhinavkulkarni (Contributor, Author):

@Narsil: Not sure about licensing issues, but could we copy the kernels over from the AWQ repo into the TGI repo? That's all this pip install is pulling in. I saw somewhere that you have a custom_kernels directory in the repo.

@casper-hansen:

Would recommend switching to autoawq==0.1.0 as this has support for Windows as well. The mit-han-lab kernels break on Windows without the modifications I made.

@abhinavkulkarni (Contributor, Author) commented Sep 23, 2023:

Thanks @casper-hansen.

@Narsil: Is it okay to have a dependency on autoawq==0.1.0? If yes, I'm happy to switch from the mit-han-lab/llm-awq repo to autoawq==0.1.0.

@Narsil (Collaborator):

Let's try to keep it as-is for now.

autoawq seems too big a dependency; most of its code is not relevant here, since we only need the kernels, not the modeling code. Having small dependencies is better. (If you release a standalone kernels-only package, I'd be happy to switch.)

Pulling the kernels in is OK, but the mit-han-lab repo is pretty barebones, so I don't feel we need to duplicate it.

As for Windows support, since our main distribution layer is Docker, it's not really necessary for us to support Windows builds.
Also, looking at the fix PR casper-hansen/AutoAWQ@e188472 (is that correct?), it seems the fix is just removing the unnecessary kernels.

@@ -282,6 +314,20 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
elif quantize == "awq":
@Narsil (Collaborator):

Same here

@abhinavkulkarni (Contributor, Author) commented Sep 23, 2023:

@Narsil: I'm keeping separate blocks for GPTQ and AWQ here since GPTQ has a lot of custom logic around exllama, etc. Please resolve this thread if that's okay with you.

@Narsil (Collaborator):

Fair enough !

@Narsil (Collaborator) commented Sep 22, 2023

Ok, my comments are mostly nits about code maintenance, but this PR looks great !

Thanks for it !

@casper-hansen commented Sep 23, 2023

@Narsil @abhinavkulkarni Has this PR been tested to be compatible with the 300+ AWQ repositories that @TheBloke has pushed with AutoAWQ?

  • Most notably, I want to ensure that the models can be loaded in TGI with the quant_config.json that is provided. Example: https://huggingface.co/TheBloke/Llama-2-13B-AWQ
  • Is this PR compatible with Windows? The original AWQ GEMM kernels are not compatible with Windows; I modified them to support Windows in AutoAWQ.

@abhinavkulkarni (Contributor, Author)

@TheBloke: Can we get some clarity on naming the quantization config for GPTQ and AWQ?

You seem to name it as quantize_config.json for GPTQ and quant_config.json for AWQ.

TGI code expects it to be named quantize_config.json.

Is it possible for you to rename it to quantize_config.json across all your AWQ repos? Or create a symlink in these repos so that both the files are present?

@casper-hansen commented Sep 23, 2023

I would advise against renaming to quantize_config.json as that makes all his repositories incompatible with AutoAWQ, which was used to create the 300+ repositories.

EDIT: It will also break all other repositories that have implemented AutoAWQ.

@TheBloke

Yeah, that's because AutoGPTQ and AutoAWQ create that file with different names. AutoGPTQ creates it as quantize_config.json, but AutoAWQ creates it as quant_config.json.

I'm happy to rename it across all my repos, but it will require AutoAWQ to be updated to support that name (not quite sure why a different name was used by AutoAWQ).

@abhinavkulkarni (Contributor, Author)

Okay, @TheBloke, @casper-hansen and @Narsil, please agree on one standard.

I'm happy to comply with whatever is decided. Two issues we need consensus on:

  1. Whether to have quantize_config.json for GPTQ and quant_config.json for AWQ, or just one file for both.

  2. Uniformity in how we refer to quantization parameters: for example, GPTQ's quantize_config.json uses bits, but AWQ's quant_config.json uses w_bit; likewise groupsize for GPTQ vs. group_size for AWQ.

Please let me know what you agree on and I'll modify the PR accordingly.

@TheBloke commented Sep 23, 2023

@abhinavkulkarni GPTQ doesn't use groupsize, it's group_size there too.

IMHO the best option is for AutoAWQ to update to use quantize_config.json and match the param names of AutoGPTQ. Ultimately the hope is that AutoAWQ will be integrated into Transformers, like AutoGPTQ was. If/when that happens, I imagine the Transformers team will ask for the same param names to be used, though they won't care about the filename because the data will need to go into config.json.

Here are the params that Transformers expects in config.json for a GPTQ model:

    "quantization_config": {
        "bits": 4,
        "group_size": 128,
        "damp_percent": 0.1,
        "desc_act": true,
        "sym": true,
        "true_sequential": true,
        "quant_method": "gptq"
    }

Based on this, IMHO AutoAWQ should create a quantize_config.json matching this format:

"bits": 4,
"group_size": 128,
"zero_point": true,
"version": "GEMM"

Then that will be easily integrated into Transformers' config.json -> quantization_config section in the future, with the addition of a "quant_method": "awq" key as well.

For backwards compatibility, AutoAWQ could also read from quant_config.json and still support the existing w_bit key name, so that support for existing models is still provided. Though I would update all my repo config files to the new format once such a change was made.

@casper-hansen

I know that Hugging Face takes backward compatibility very seriously. Implementing this in AutoAWQ requires quite a bit of maintenance going forward, and it will break all PyPI releases of AutoAWQ (v0.0.1, v0.0.2, v0.1.0) because they will not be able to read any Hugging Face repository. Additionally, oobabooga/text-generation-webui will also need to update its version because the change will brick the functionality in the PR. I am also not sure whether this breaks the vLLM support for AWQ.

I am not so sure that people will be happy with this change for the reasons I just explained, but I would gladly accept PRs that refactor the current behavior and extract the logic into a QuantizeConfig class that implements the necessary backward-compatibility logic. However, going forward, we would have to support quant_config.json for multiple versions and raise warnings in AutoAWQ before we phase quant_config.json out (just like they do in Hugging Face libraries). We can finally phase it out in the next major release, v0.2.0.

@TheBloke commented Sep 23, 2023

vLLM already allows for both filenames:

```python
@classmethod
def get_config_filenames(cls) -> List[str]:
    return [
        "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
        "quantize_config.json",  # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
    ]
```

And both key names:

```python
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
```

HF takes backwards compatibility seriously once something is in a Hugging Face project, so they definitely won't want it to change after AWQ is merged. Before it's merged, I don't think they'll care one way or the other. But they may well ask for the w_bit -> bits param change to be made in order for it to be merged, so I think you're going to end up doing at least that change eventually if you want it in Transformers.

So if this change is going to be made, I would say the time to do it is now. There's never a good time for a change that might break compatibility, but earlier is always far better than later.

If the config file name does change, I'm happy to have both quantize_config.json and quant_config.json in all my repos for a while, which would allow ongoing support for prior AutoAWQ versions and any un-updated clients. Then I could phase it out in a few weeks, once everyone is up to date. A bit like how I've handled the many breaking llama.cpp changes.

@abhinavkulkarni (Contributor, Author)

I agree with @TheBloke that it's better to make changes now than later.

So, how about this:

  1. @TheBloke, please write a "translator" to convert AutoAWQ's quant_config.json into quantize_config.json and keep both files in the model repo (a rough sketch follows below this comment). Please back-process all of your published models to include both files.

  2. @casper-hansen: Over a period of time, please deprecate quant_config.json and start publishing quantize_config.json in the format the HF team expects.

Agreed?
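
For concreteness, here is a minimal sketch of what such a translator could look like. The key names (w_bit/bits, group_size/q_group_size, zero_point, version) are assumptions based on this discussion and the vLLM snippet quoted above, not an official spec, and the script is not part of this PR:

```python
import json
from pathlib import Path

def translate_quant_config(model_dir: str) -> None:
    """Hypothetical sketch: derive quantize_config.json from quant_config.json."""
    old = json.loads((Path(model_dir) / "quant_config.json").read_text())

    new = {
        # Accept either the old AWQ key names or the proposed GPTQ-style ones.
        "bits": old.get("w_bit", old.get("bits")),
        "group_size": old.get("q_group_size", old.get("group_size")),
        "zero_point": old.get("zero_point", True),
        "version": old.get("version", "GEMM"),
        "quant_method": "awq",
    }

    # Keep both files in the repo for backward compatibility, as discussed above.
    (Path(model_dir) / "quantize_config.json").write_text(json.dumps(new, indent=2))

translate_quant_config("path/to/local/model/repo")
```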

@Narsil changed the base branch from main to dev on September 25, 2023 at 07:57
@Narsil (Collaborator) commented Sep 25, 2023

Ok, only nits are left. I will merge this into dev, add the nits I wanted to share, and add an integration test.

Thank you very much for this contribution !

@Narsil merged commit c35f39c into huggingface:dev on Sep 25, 2023
@Narsil (Collaborator) commented Sep 25, 2023

Totally agree with all the names/config discussion, which is why I suggested reusing all the code we could for awq to match gptq.

Also, in this repo, unlike in transformers, we might want to make the necessary changes to support the existing repos (provided they are not too hard to support, i.e. only a few lines of code).

Comment on lines +154 to +157
```python
try:
    g_idx = self.get_tensor(f"{prefix}.g_idx")
except RuntimeError:
    g_idx = None
```
@Narsil (Collaborator):

Can you switch to an if statement ?

It needs to fail if this is missing for GPTQ (which it wouldn't here).

@abhinavkulkarni (Contributor, Author) commented Sep 25, 2023:

I have pushed a new commit that switches the try/except to an if/else. Thanks.
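
For readers following the thread, the shape of the change is roughly the sketch below: branch on the quantization method instead of swallowing the RuntimeError, so a missing g_idx still fails loudly for GPTQ. This is only an illustration; the actual commit may be structured differently:

```python
# Illustrative sketch, not the exact code from the commit.
if quantize == "gptq":
    # GPTQ checkpoints must ship g_idx, so let a missing tensor raise.
    g_idx = self.get_tensor(f"{prefix}.g_idx")
else:
    # AWQ checkpoints have no g_idx.
    g_idx = None
```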

Comment on lines +204 to +210
```python
try:
    w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
    for w2 in w[1:]:
        torch.testing.assert_close(w2, w[0])
    g_idx = w[0]
except RuntimeError:
    g_idx = None
```
@Narsil (Collaborator):

Same here.


@@ -216,7 +233,7 @@ def get_tensor_shard(self, var, dim):
return tensor

def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
if quantize in "gptq":
@Narsil (Collaborator):

Suggested change
if quantize in "gptq":
if quantize == "gptq":
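
(Context for the suggestion: between two strings, Python's `in` operator is a substring test, so the `in "gptq"` form accepts values it shouldn't. A quick illustration, not code from the PR:)

```python
# `in` between strings is a substring check, not an equality test.
print("gpt" in "gptq")   # True  -> quantize == "gpt" would wrongly take the GPTQ branch
print("q" in "gptq")     # True  -> so would quantize == "q"
print("gptq" == "gptq")  # True  -> the suggested equality check expresses the intent
```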


@casper-hansen

> Also, looking at the fix PR casper-hansen/AutoAWQ@e188472 (is that correct?), it seems the fix is just removing the unnecessary kernels.

That part is not Windows support for the GEMM kernel; those are FasterTransformer kernels.

Here is where I added Windows support: casper-hansen/AutoAWQ@4c39a76. It's super simple: MSVC does not support `__asm__ __volatile__`, so just remove the underscores and it works.
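
Schematically, the fix described above just swaps the GNU spellings for the plain ones (the PTX instruction shown here is only a placeholder, not necessarily one used in the AWQ kernels):

```cpp
// GNU-style spelling, which MSVC rejects per the comment above:
__asm__ __volatile__("cp.async.commit_group;\n" ::);

// Plain spelling, which works on Windows as well, per the comment above:
asm volatile("cp.async.commit_group;\n" ::);
```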

Narsil added a commit that referenced this pull request Sep 25, 2023
# Add AWQ quantization inference support

Fixes
#781

This PR (partially) adds support for AWQ quantization for inference.
More information on AWQ [here](https://arxiv.org/abs/2306.00978). In
general, AWQ is faster and more accurate than GPTQ, which is currently
supported by TGI.

This PR installs 4-bit GEMM custom CUDA kernels released by AWQ authors
(in `requirements.txt`, just one line change).

Quick way to test this PR would be bring up TGI as follows:

```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq

text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
--trust-remote-code --port 8080 \
--max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
--quantize awq
```

Please note:
* This PR was tested with FlashAttention v2 and vLLM.
* This PR adds support for AWQ inference, not quantizing the models.
That needs to be done outside of TGI, instructions

[here](https://github.com/mit-han-lab/llm-awq/tree/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa).
* This PR only adds support for `FlashLlama` models for now.
* Multi-GPU setup has not been tested. 
* No integration tests have been added so far, will add later if
maintainers are interested in this change.
* This PR can be tested on any of the models released

[here](https://huggingface.co/abhinavkulkarni?sort_models=downloads#models).

Please refer to the linked issue for benchmarks for

[abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq](https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq)
vs

[TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).

Please note, AWQ has released faster (and in case of Llama, fused)
kernels for 4-bit GEMM, currently at the top of the `main` branch at
https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit
that has been tested to work. We can switch to latest commit later on.

## Who can review?

@OlivierDehaene OR @Narsil

---------



# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Abhinav M Kulkarni <abhinavkulkarni@gmail.com>
Co-authored-by: Abhinav Kulkarni <abhinav@concentric.ai>
@Narsil (Collaborator) commented Sep 26, 2023

Oh nice, thanks for the tip!

@Archmilio commented Oct 8, 2023

There is only about a 10% difference in inference speed between the quantized Llama 2 7B model and the base model on a single A100 GPU. In the AutoAWQ inference example, there is a fuse_layers option that can significantly affect speed depending on whether it is set to true or false. I am curious whether this option is always set to true in TGI.
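
(For reference, the AutoAWQ example being referred to looks roughly like the sketch below; the exact from_quantized signature and defaults may differ between AutoAWQ versions, so treat this as an illustration rather than the canonical API:)

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "TheBloke/Llama-2-13B-AWQ"  # example AWQ repo mentioned earlier in the thread

# fuse_layers toggles the fused attention/MLP path and can change speed significantly.
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)
```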

@RonanKMcGovern

Is flash attention v2 installed and on?

@Archmilio

I tested using the tgi-1.1.0 Docker image, which has Flash Attention v2 installed.
