
@mobicham (Contributor) commented Jun 6, 2025

[Follow-up to #19147 due to DCO rebasing issues]

Purpose

The goal of this small PR is to fix loading of torchao models in which not all layers have been quantized.

The current implementation doesn't keep track of the skipped layers defined in config["modules_to_not_convert"]. As a result, loading a quantized VL model whose vision head is not quantized crashes.

The PR also includes logic to skip layers defined in module_fqn_to_config (modules mapped to None). Currently, if a module is skipped in module_fqn_to_config, loading the model in vLLM crashes.
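A minimal sketch of how the skipped layers could be collected from a serialized config. It assumes a simplified, hypothetical dict layout in which `modules_to_not_convert` is a list of module names and `module_fqn_to_config` maps fully qualified names to a per-module config, with `None` meaning "leave unquantized". Only those two key names (and the `_default` fallback key used in the example below) come from this PR; the real TorchAO/Transformers serialization is more nested, and the `visual` module name is illustrative.

```python
from typing import Any


def collect_skip_modules(config: dict[str, Any]) -> list[str]:
    """Return module names that should stay unquantized.

    Assumes a simplified, hypothetical config layout, not the exact
    TorchAO serialization format.
    """
    # Layers explicitly excluded from quantization (e.g. a VL model's vision tower)
    skip = list(config.get("modules_to_not_convert") or [])

    # Layers whose per-module entry is None were skipped at quantization time
    fqn_to_config = config.get("module_fqn_to_config") or {}
    for fqn, module_cfg in fqn_to_config.items():
        # "_default" is the fallback entry, not a real module name
        if fqn != "_default" and module_cfg is None and fqn not in skip:
            skip.append(fqn)
    return skip


example = {
    "modules_to_not_convert": ["visual"],  # hypothetical vision-tower name
    "module_fqn_to_config": {
        "_default": {"group_size": 128},
        "model.layers.0.self_attn.o_proj": None,
        "model.layers.13.self_attn.o_proj": None,
    },
}
print(collect_skip_modules(example))
# ['visual', 'model.layers.0.self_attn.o_proj', 'model.layers.13.self_attn.o_proj']
```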

It also includes a quick fix that improves loading speed by avoiding the creation of an nn.Linear with the full tensor shape.
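The loading-speed fix can be sketched as follows, assuming PyTorch. The helper name `make_dummy_linear` is hypothetical and this is not the PR's exact code; it only illustrates the idea described here: build a 1x1 `nn.Linear` so no full-size weight is allocated, then attach the real weight and set `in_features`/`out_features` by hand.

```python
import torch
from torch import nn


def make_dummy_linear(weight: torch.Tensor) -> nn.Linear:
    """Wrap an existing (out_features, in_features) weight in an nn.Linear.

    Hypothetical helper: builds a 1x1 layer so no full-size weight is
    allocated, then attaches the real weight and fixes the shape metadata.
    """
    # Only a 1x1 weight is allocated here, instead of out x in
    dummy = nn.Linear(1, 1, bias=False)
    # Reuse the existing storage; nn.Parameter does not copy the data
    dummy.weight = nn.Parameter(weight, requires_grad=False)
    # nn.Linear stores its weight as (out_features, in_features)
    dummy.out_features, dummy.in_features = weight.shape
    return dummy


w = torch.randn(8, 4)
layer = make_dummy_linear(w)
x = torch.ones(2, 4)
print(layer.in_features, layer.out_features)  # 4 8
print(torch.allclose(layer(x), x @ w.T))      # True
```

The 1x1 initialization still runs `reset_parameters`, but on a single element, so its cost no longer scales with the real weight size.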

Test Plan

Dependencies

USE_CPP=0 pip install git+https://github.com/pytorch/ao -v --no-build-isolation --use-pep517;
pip install git+https://github.com/mobiusml/gemlite/;

Code

Loading a VL model with unquantized vision modules

import torch
from vllm import LLM

# Qwen2.5-VL checkpoint whose vision modules were not quantized
model_id = "mobiuslabsgmbh/Qwen2.5-VL-7B-Instruct_gemlite-ao_a16w4_gs_128_pack_32bit"
processor_args = {
    'limit_mm_per_prompt': {"image": 3},
    'mm_processor_kwargs': {"min_pixels": 28 * 28, "max_pixels": 1280 * 28 * 28},
    'disable_mm_preprocessor_cache': False,
}

llm = LLM(model=model_id, gpu_memory_utilization=0.9, dtype=torch.float16, max_model_len=4096,
          max_num_batched_tokens=4096, **processor_args)

Skip module example

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# int4 weight-only quantization for every layer by default
config = Int4WeightOnlyConfig(group_size=128)
# Mapping a module to None leaves it unquantized (skipped)
config2 = None

quant_config = ModuleFqnToConfig({"_default": config, "model.layers.0.self_attn.o_proj": config2, "model.layers.13.self_attn.o_proj": config2})

quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)

print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Save with safe_serialization=False (torch pickle format instead of safetensors)
q_model_id = 'quant_model_test'
quantized_model.save_pretrained(q_model_id, safe_serialization=False)
tokenizer.save_pretrained(q_model_id)
######################################################################################

import torch
from vllm import LLM
from vllm.sampling_params import SamplingParams
llm = LLM(model="quant_model_test", dtype=torch.bfloat16) 
sampling_params = SamplingParams(max_tokens=1024, temperature=0.5, repetition_penalty=1.1, ignore_eos=False)
messages = [{"content": "You are a helpful assistant", "role":"system"}, {"content":"Solve this equation x^2 + 1 = -1.", "role":"user"}]
outputs = llm.chat(messages, sampling_params, chat_template=llm.get_tokenizer().chat_template)
print(outputs[0].outputs[0].text)

Test Result

The model should load successfully.

@jerryzh168

mobicham added 2 commits June 6, 2025 09:43
…atures) for faster loading

Signed-off-by: mobicham <hicham@mobiuslabs.com>
Signed-off-by: mobicham <hicham@mobiuslabs.com>
@gemini-code-assist bot (Contributor) left a comment

Hello @mobicham, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello team,

Gemini here with a summary of this pull request. This PR, titled "Fix TorchAOConfig skip layers", aims to address issues encountered when loading models quantized with TorchAO where not all layers have been converted. Specifically, it fixes crashes that occur when loading models that define certain modules to be skipped from quantization, either via the modules_to_not_convert list or by setting their configuration to None in module_fqn_to_config. The PR ensures that vLLM correctly identifies and handles these unquantized layers during loading. Additionally, a minor optimization is included to improve the loading speed for TorchAO models by avoiding the creation of large dummy nn.Linear modules during parameter quantization.

Highlights

  • TorchAO Quantization Fix: Resolves issues loading TorchAO quantized models that contain layers explicitly marked to be skipped from quantization.
  • Support for Skipped Modules: Adds logic to correctly handle modules listed in modules_to_not_convert and modules with a None configuration in module_fqn_to_config, ensuring they are treated as unquantized.
  • Loading Speed Optimization: Improves the speed of loading TorchAO models by optimizing the creation of dummy nn.Linear modules used during parameter quantization.
  • New Test Case: Adds a test case specifically for loading a Qwen-VL model quantized with TorchAO, which helps validate the fix for models with unquantized components like vision heads.

Changelog

  • tests/quantization/test_torchao.py
    • Added a new test function test_qwenvl_int8wo_model_loading_with_params (lines 62-74) to test loading a Qwen-VL model quantized with TorchAO, targeting the scenario with potentially skipped layers.
  • vllm/model_executor/layers/quantization/torchao.py
    • Modified the TorchAOConfig constructor to accept an optional skip_modules list (lines 23-25, 41).
    • Updated the from_config class method to parse modules_to_not_convert and identify modules set to None in module_fqn_to_config, adding them to the skip_modules list (lines 78-85).
    • In get_quant_method, added a check to return UnquantizedLinearMethod() if the current module's prefix is in the skip_modules list (lines 96-97).
    • Ensured the skip_modules list is passed down when creating nested TorchAOConfig instances within get_quant_method (line 105).
    • Optimized torchao_quantize_param_data by creating a small nn.Linear(1, 1) and manually setting in_features and out_features instead of using the full parameter shape directly (lines 129-131).
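The `get_quant_method` change listed above can be sketched with a small stand-in class. `SketchTorchAOConfig` and `with_module_config` are hypothetical names, and the string return values merely stand in for vLLM's `UnquantizedLinearMethod()` and the TorchAO linear method; only the substring check and the idea of passing `skip_modules` down come from this PR.

```python
from typing import Optional


class SketchTorchAOConfig:
    """Hypothetical stand-in for vLLM's TorchAOConfig; only the skip logic is modeled."""

    def __init__(self, torchao_config: object, skip_modules: Optional[list[str]] = None):
        self.torchao_config = torchao_config
        # Empty by default, so fully quantized checkpoints behave as before
        self.skip_modules = skip_modules or []

    def get_quant_method(self, prefix: str) -> str:
        # Substring check: any skip entry found in the layer prefix
        # marks the layer as unquantized
        if any(s in prefix for s in self.skip_modules):
            return "unquantized"  # stands in for UnquantizedLinearMethod()
        return "torchao"  # stands in for the TorchAO linear method

    def with_module_config(self, module_config: object) -> "SketchTorchAOConfig":
        # Carry the skip list into the nested per-module config
        return SketchTorchAOConfig(module_config, self.skip_modules)


cfg = SketchTorchAOConfig(object(), skip_modules=["visual"])
print(cfg.get_quant_method("visual.blocks.0.attn.qkv"))                      # unquantized
print(cfg.get_quant_method("language_model.model.layers.0.mlp.down_proj"))  # torchao
```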

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively addresses the issue of loading TorchAO models with partially quantized layers, particularly for Vision-Language models. The changes to handle modules_to_not_convert and module_fqn_to_config for skipping layers are well-implemented. The added test case for a Qwen-VL model is a good addition, and the optimization in torchao_quantize_param_data to reduce memory allocation during dummy linear layer creation is a nice improvement.

I have one suggestion regarding the string matching logic for skipping modules, which could be made more robust to prevent potential over-matching. Overall, this is a valuable fix.

Summary of Findings

  • Module Skipping Logic Robustness: The logic for determining whether to skip a module (any(s in prefix for s in self.skip_modules)) uses a general substring check. This could potentially lead to over-matching if a skip pattern is a substring of an unrelated module's FQN (e.g., skipping "layer.1" might unintentionally affect "layer.10"). A more precise FQN-aware prefix matching or exact matching would be more robust.
  • Test Coverage: A new test case (test_qwenvl_int8wo_model_loading_with_params) was added, which is good for verifying the fix for VL models with unquantized vision modules.
  • Performance Improvement: The change in torchao_quantize_param_data to initialize nn.Linear with minimal dimensions (1,1) and then update in_features and out_features is a good optimization to reduce temporary memory allocation.
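The over-matching concern can be illustrated with a short sketch. `is_skipped_substring` mirrors the quoted `any(s in prefix for s in self.skip_modules)` check; `is_skipped_components` is one hypothetical, FQN-aware alternative that only matches skip entries as whole dotted components. Neither function name comes from vLLM.

```python
def is_skipped_substring(prefix: str, skip_modules: list[str]) -> bool:
    # Plain substring check, as quoted in the review
    return any(s in prefix for s in skip_modules)


def is_skipped_components(prefix: str, skip_modules: list[str]) -> bool:
    # Match each skip entry only as a contiguous run of whole dotted components
    parts = prefix.split(".")
    for s in skip_modules:
        target = s.split(".")
        n = len(target)
        if any(parts[i:i + n] == target for i in range(len(parts) - n + 1)):
            return True
    return False


skip = ["model.layers.1"]
print(is_skipped_substring("model.layers.10.mlp.down_proj", skip))   # True (over-match)
print(is_skipped_components("model.layers.10.mlp.down_proj", skip))  # False
print(is_skipped_components("model.layers.1.mlp.down_proj", skip))   # True
```

A strict `prefix.startswith(s + ".")` check would be even tighter, but it would stop matching if the runtime prefix carries an extra wrapper name (for example a hypothetical `language_model.` in front of the checkpoint FQN), which the component-wise match still tolerates.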

Merge Readiness

The pull request is well-structured and addresses the core issues effectively. However, there's one medium-severity concern regarding the robustness of the module skipping logic that should be discussed and potentially addressed. Once that point is clarified or resolved, the PR should be in good shape for merging. As an AI, I am not authorized to approve pull requests; please ensure further review and approval from the maintainers.

Signed-off-by: mobicham <hicham@mobiuslabs.com>

github-actions bot commented Jun 6, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mobicham added 4 commits June 6, 2025 11:36
Signed-off-by: mobicham <hicham@mobiuslabs.com>
Signed-off-by: mobicham <hicham@mobiuslabs.com>
Signed-off-by: mobicham <hicham@mobiuslabs.com>
@houseroad (Collaborator)

cc: @jerryzh168

@jerryzh168 (Contributor)

@mobicham thanks for the fix, can you talk a bit more about the qkv fusion you mentioned before? I still didn't quite get it.

@drisspg (Contributor) left a comment

The q,k,v fusion issue you mentioned makes sense, does this PR fix that?

@mobicham (Contributor, Author) commented Jun 7, 2025

@drisspg only if:

  • They have the same module-level quant settings.
  • The ao implementation correctly implements slice and copy_.

Otherwise, there's no clean way to merge qkv if they don't have the same quant settings. Moreover, the merging doesn't happen in TorchAOConfig; it happens in the QKV linear modules.

The main focus of this PR, though, is handling layers that were not quantized: it simply checks whether the module prefix matches one of the skipped layers defined in the config.
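To illustrate why slice and copy_ matter for the QKV merge described in this comment, here is a simplified sketch with plain dense tensors, assuming PyTorch. It is not vLLM's actual weight loader: `load_shard` and the shard sizes are hypothetical. A fused QKV weight is allocated once and each checkpoint shard is copied into its own row range; a quantized tensor subclass would have to support the same narrow-then-`copy_` pattern, with q, k and v sharing the same quant settings.

```python
import torch


def load_shard(fused: torch.Tensor, shard: torch.Tensor, offset: int) -> None:
    # View of the rows reserved for this shard (a slice, no copy yet)
    target = fused.narrow(0, offset, shard.shape[0])
    # In-place write; a quantized subclass must implement this as well
    target.copy_(shard)


hidden = 8
q_rows, kv_rows = 8, 4  # hypothetical GQA-style projection sizes
q = torch.randn(q_rows, hidden)
k = torch.randn(kv_rows, hidden)
v = torch.randn(kv_rows, hidden)

fused_qkv = torch.empty(q_rows + 2 * kv_rows, hidden)
offset = 0
for shard in (q, k, v):
    load_shard(fused_qkv, shard, offset)
    offset += shard.shape[0]

print(torch.equal(fused_qkv, torch.cat([q, k, v], dim=0)))  # True
```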

@drisspg (Contributor) commented Jun 7, 2025

Ohhh yea, fwiw I am planning on writing a little doc in AO on how to write a subclass that will work with vLLM, and the slice and copy is the main point.

I guess it feels like if someone skipped one of the q, k, v projections then we should skip (via ModuleFqnToConfig or the skip list) the stacked variant. Was more curious if this is tested and expected to work in this PR.

@mobicham (Contributor, Author) commented Jun 7, 2025

> I guess it feels like if someone skipped one of the q, k, v projections then we should skip (via ModuleFqnToConfig or the skip list) the stacked variant. Was more curious if this is tested and expected to work in this PR.

I see! Unfortunately, that can't work with the current code because the metadata won't match during the slice/copy.
However, it may be worth mentioning this in the doc, or throwing a warning if q, k, v don't share the same config. The same goes for the MLP layers, by the way.
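The warning suggested here could look roughly like the sketch below. Everything in it is hypothetical: `FUSED_GROUPS` lists Llama-style module names, the configs are plain placeholder values, and `check_fused_configs` is not an existing vLLM or torchao function. It compares the effective config of each member of a fused group (falling back to `_default`, as in the `ModuleFqnToConfig` example above) and warns when they differ.

```python
import warnings
from typing import Any

# Hypothetical mapping: fused vLLM layer -> checkpoint modules it is built from
FUSED_GROUPS = {
    "self_attn.qkv_proj": ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"),
    "mlp.gate_up_proj": ("mlp.gate_proj", "mlp.up_proj"),
}


def check_fused_configs(fqn_to_config: dict[str, Any], layer_prefix: str) -> list[str]:
    """Return (and warn about) fused layers whose members use different configs."""
    default = fqn_to_config.get("_default")
    mismatched = []
    for fused, members in FUSED_GROUPS.items():
        # Modules without an explicit entry fall back to "_default"
        configs = [fqn_to_config.get(f"{layer_prefix}.{m}", default) for m in members]
        if any(c != configs[0] for c in configs[1:]):
            mismatched.append(f"{layer_prefix}.{fused}")
            warnings.warn(
                f"{layer_prefix}.{fused}: members have different quantization "
                "configs; the fused weight may fail to load (slice/copy mismatch)"
            )
    return mismatched


cfg = {"_default": "int4", "model.layers.0.self_attn.k_proj": None}
print(check_fused_configs(cfg, "model.layers.0"))  # ['model.layers.0.self_attn.qkv_proj']
```

Skipping a module that is not fused, such as `o_proj` in the skip example above, does not trigger the warning.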

@houseroad (Collaborator)

Actually DCO is only a soft requirement. :-)

Signed-off-by: mobicham <hicham@mobiuslabs.com>
@mgoin added the bug, quantization, and ready labels Jun 9, 2025
Signed-off-by: mobicham <hicham@mobiuslabs.com>
@mobicham (Contributor, Author)

Anything else to have this merged? Thank you!

@houseroad (Collaborator) left a comment

Could you rebase to the latest main and fix the pre-commit linter?

mobicham and others added 3 commits June 12, 2025 08:03
Signed-off-by: mobicham <hicham@mobiuslabs.com>
Signed-off-by: mobicham <hicham@mobiuslabs.com>
@mobicham (Contributor, Author)

@houseroad sorry didn't see that, fixed, thank you!

@houseroad merged commit 96846bb into vllm-project:main Jun 12, 2025
71 checks passed
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Signed-off-by: mobicham <hicham@mobiuslabs.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jun 30, 2025
Signed-off-by: mobicham <hicham@mobiuslabs.com>
wseaton pushed a commit to wseaton/vllm that referenced this pull request Jun 30, 2025
Signed-off-by: mobicham <hicham@mobiuslabs.com>
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
Signed-off-by: mobicham <hicham@mobiuslabs.com>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Signed-off-by: mobicham <hicham@mobiuslabs.com>
Labels: bug, quantization, ready

5 participants