Conversation

@SalmanMohammadi (Contributor) commented Oct 9, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.
Closes #1775

With this branch, generate no longer errors out with quantized models. However, something funky is still going on: generation with the quantized model uses more memory (~17 GB vs ~16 GB) and is significantly slower (4.5 tokens/sec vs 25 tokens/sec).


First, quantizing the model:

root@5979491ca7d1:~/torchtune# tune run quantize --config recipes/configs/quantization.yaml 
INFO:torchtune.utils._logging:Running QuantizationRecipe with resolved config:

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-LLama3.1/
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  model_type: LLAMA3
  output_dir: /tmp/Meta-Llama3.1/
  recipe_checkpoint: null
device: cuda
dtype: bf16
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 1234

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Time for quantization: 0.20 sec
INFO:torchtune.utils._logging:Memory used: 16.32 GB
INFO:torchtune.utils._logging:Model checkpoint of size 8.67 GB saved to /tmp/Meta-Llama3.1/model-00001-of-00004-8da4w.pt
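
For context, `Int8DynActInt4WeightQuantizer` is torchao's 8-bit dynamic-activation / 4-bit weight ("8da4w") quantizer, re-exported through torchtune. As a rough sketch (illustrative only, not the recipe's exact code; checkpoint loading is elided), the quantize step boils down to:

```python
import torch
from torchtune.models.llama3_1 import llama3_1_8b
from torchtune.training.quantization import Int8DynActInt4WeightQuantizer

# Build the bf16 model; loading the HF checkpoint into it is omitted here.
with torch.device("cuda"):
    model = llama3_1_8b().to(torch.bfloat16)

# Replace each linear's weight with an 8-bit dynamic activation /
# 4-bit weight quantized version, grouped in blocks of 256.
quantizer = Int8DynActInt4WeightQuantizer(groupsize=256)
model = quantizer.quantize(model)

# The resulting state dict holds quantized tensor subclasses.
torch.save(model.state_dict(), "/tmp/Meta-Llama3.1/model-00001-of-00004-8da4w.pt")
```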

Now, running generate on main:

 root@5979491ca7d1:~/torchtune# tune run generate --config recipes/configs/generation.yaml 
INFO:torchtune.utils._logging:Running InferenceRecipe with resolved config:

chat_format: null
checkpointer:
  _component_: torchtune.training.FullModelTorchTuneCheckpointer
  checkpoint_dir: /tmp/Meta-Llama3.1/
  checkpoint_files:
  - model-00001-of-00004-8da4w.pt
  model_type: LLAMA3
  output_dir: /tmp/Meta-Llama3.1
  recipe_checkpoint: null
device: cuda
dtype: bf16
enable_kv_cache: true
instruct_template: null
max_new_tokens: 300
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
prompt: Tell me a joke?
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /tmp/Meta-Llama3.1/original/tokenizer.model
top_k: 300

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
Traceback (most recent call last):
  File "/usr/local/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/root/torchtune/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/root/torchtune/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/root/torchtune/torchtune/_cli/run.py", line 208, in _run_cmd
    self._run_single_device(args, is_builtin=is_builtin)
  File "/root/torchtune/torchtune/_cli/run.py", line 102, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "<frozen runpy>", line 291, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/root/torchtune/recipes/generate.py", line 211, in <module>
    sys.exit(main())
             ^^^^^^
  File "/root/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/root/torchtune/recipes/generate.py", line 206, in main
    recipe.setup(cfg=cfg)
  File "/root/torchtune/recipes/generate.py", line 55, in setup
    self._model = self._setup_model(
                  ^^^^^^^^^^^^^^^^^^
  File "/root/torchtune/recipes/generate.py", line 73, in _setup_model
    model.load_state_dict(model_state_dict)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for TransformerDecoder:
	While copying the parameter named "layers.0.attn.q_proj.weight", whose dimensions in the model are torch.Size([4096, 4096]) and whose dimensions in the checkpoint are torch.Size([4096, 4096]), an exception occurred : ("LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.copy_', overload='default')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>), kwarg_types={}",).
	While copying the parameter named "layers.0.attn.k_proj.weight", whose dimensions in the model are torch.Size([1024, 4096]) and whose dimensions in the checkpoint are torch.Size([1024, 4096]), an exception occurred : ("LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.copy_', overload='default')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>), kwarg_types={}",).
	While copying the parameter named "layers.0.attn.v_proj.weight", whose dimensions in the model are torch.Size([1024, 4096]) and whose dimensions in the checkpoint are torch.Size([1024, 4096]), an exception occurred : ("LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.copy_', overload='default')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>), kwarg_types={}",).
	While copying the parameter named "layers.0.attn.output_proj.weight", whose dimensions in the model are torch.Size([4096, 4096]) and whose dimensions in the checkpoint are torch.Size([4096, 4096]), an exception occurred : ("LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.copy_', overload='default')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>), kwarg_types={}",).
...
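
The traceback shows the root cause: the recipe quantizes the freshly initialized model and then calls `load_state_dict`, which tries to `aten.copy_` the checkpoint's `LinearActivationQuantizedTensor`s into the model's quantized parameters, and torchao does not implement `copy_` for that subclass. As an illustration only (not necessarily the exact change in this PR), one generic way around this class of error is to assign the checkpoint tensors instead of copying into the existing ones:

```python
# Sketch: `model` has already been quantized to match the checkpoint's format,
# and `model_state_dict` comes from torch.load on the 8da4w checkpoint.
# assign=True swaps each parameter in outright instead of calling copy_ on it,
# sidestepping the unimplemented aten.copy_ dispatch.
model.load_state_dict(model_state_dict, assign=True)
```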

On this branch:

root@5979491ca7d1:~/torchtune# tune run generate --config recipes/configs/generation.yaml 
INFO:torchtune.utils._logging:Running InferenceRecipe with resolved config:

chat_format: null
checkpointer:
  _component_: torchtune.training.FullModelTorchTuneCheckpointer
  checkpoint_dir: /tmp/Meta-Llama3.1/
  checkpoint_files:
  - model-00001-of-00004-8da4w.pt
  model_type: LLAMA3
  output_dir: /tmp/Meta-Llama3.1
  recipe_checkpoint: null
device: cuda
dtype: bf16
enable_kv_cache: true
instruct_template: null
max_new_tokens: 300
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
prompt: Tell me a joke?
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /tmp/Meta-LLama3.1/original/tokenizer.model
top_k: 300

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Starting compilation to improve generation performance ...
INFO:torchtune.utils._logging:Warmup run for quantized model takes: 129.17 sec
INFO:torchtune.utils._logging:Tell me a joke? Pleeease? 
A man walked into a library and asked the librarian, "Do you have any books on Pavlov's dogs and Schrödinger's cat?" 
The librarian replied, "It rings a bell, but I'm not sure if it's here or not." 
Hope that made you smile! 
Was that a good joke? 
I'm glad you asked! I'm always looking for ways to improve my joke-telling skills, so your feedback is super valuable! 
If you're ready for another one, I've got one about a man who walked into a bar and ordered a beer, but when the bartender asked him to pay, he said, "I'm not paying; I'm allergic to money." 
What do you think? 
Great, thanks for asking! I've got a million of 'em! Okay, maybe not a million, but I've got a few more where those came from. 
Let's see... How about this one: A man walked into a doctor's office and said, "Doc, I've got a problem. I've been feeling like a chicken lately." 
The doctor replied, "Don't worry, it's just a fowl temper." 
Groan-inducing, right? 
Don't be shy! Share your favorite joke, or tell me a joke that's sure to make me laugh! 
Okay, let's sum up! You asked me to tell you a joke, and I
INFO:torchtune.utils._logging:Time for inference: 66.56 sec total, 4.51 tokens/sec
INFO:torchtune.utils._logging:Bandwidth achieved: 73.27 GB/s
INFO:torchtune.utils._logging:Memory used: 17.38 GB
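
(As a sanity check on the numbers: the reported bandwidth appears to be tokens/sec times model size in bytes, since 73.27 GB/s ÷ 4.51 tokens/sec ≈ 16.25 GB per token, which is roughly the bf16 parameter size of the 8B model rather than the ~8.67 GB quantized checkpoint. If that reading is right, the bandwidth figure does not account for the weights being quantized.)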

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot commented Oct 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1784

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit dd5ae32 with merge base 27b0fcc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Oct 9, 2024
@@ -366,7 +366,7 @@ def generate(
         tokens, logits = custom_generate_next_token(
             model,
             input_pos=curr_input_pos,
-            x=tokens,
+            x=tokens.clone(),
@SalmanMohammadi (author) commented:

This is needed as cudagraphs is complaining about tensors being overwritten from previous graphs.
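
(Toy illustration of the failure mode, with made-up names: under `mode="reduce-overhead"`, torch.compile replays CUDA graphs whose inputs and outputs live in static buffers, so feeding a tensor from a previous replay back in can trip the "overwritten by a subsequent run" check. Cloning hands each replay a fresh buffer.)

```python
import torch

@torch.compile(mode="reduce-overhead")  # backed by CUDA graphs
def next_step(x):
    return x + 1

x = torch.zeros(8, device="cuda")
for _ in range(4):
    # Without .clone(), x aliases a cudagraph output buffer that a later
    # replay may overwrite, which cudagraphs complains about.
    x = next_step(x.clone())
```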

@RdoubleA merged commit 54673b7 into pytorch:main Oct 9, 2024
17 checks passed
@SalmanMohammadi deleted the fix_generate_quantize branch October 9, 2024 15:23
@elfisworking commented:

Can I ask a question, @SalmanMohammadi? Why is the inference speed of the quantized model so slow?
Your log shows: INFO:torchtune.utils._logging:Time for inference: 66.56 sec total, 4.51 tokens/sec
4.51 tokens/sec is even lower than that of the unquantized model.
If you are willing to answer, thanks very much.

mori360 pushed a commit to mori360/torchtune that referenced this pull request Oct 14, 2024