
Conversation

@varun-sundar-rabindranath (Contributor) commented Jul 18, 2025

Purpose

Mistral-Small-3.1-24B-Instruct-2503 is a multi-modal model with three sub-modules:

  • language_model
  • multi_modal_projector
  • vision_tower

The logits_processor module is contained within the language_model module. On main, because of this nested structure, the LoRA layer-replacement logic fails to replace the LogitsProcessor with LogitsProcessorWithLoRA, which results in incorrect outputs when LoRA is enabled.
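
As a rough illustration of the failure mode (a minimal sketch, not the code touched by this PR): a replacement pass that only inspects the top-level model misses a logits_processor that lives one level down, while walking every parent module finds it. The names `target_cls` and `wrap_fn` below are placeholders, not vLLM identifiers:

```python
# Minimal sketch, assuming PyTorch modules; not vLLM's actual implementation.
# A top-level check like `getattr(model, "logits_processor", None)` misses the
# nested case (model.language_model.logits_processor); walking every parent
# module and replacing the matching child on that parent does not.
import torch.nn as nn


def replace_nested_module(model: nn.Module, target_cls, wrap_fn):
    """Replace the first submodule of type `target_cls` anywhere in the tree
    with `wrap_fn(child)`, attached to the child's own parent module.
    Returns the dotted name of the replaced module, or None if not found."""
    for parent_name, parent in model.named_modules():
        for child_name, child in parent.named_children():
            if isinstance(child, target_cls):
                setattr(parent, child_name, wrap_fn(child))
                return f"{parent_name}.{child_name}" if parent_name else child_name
    return None
```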

Test Plan

lm_eval results with and without LoRA.

Test Result

base_command: vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 --max-model-len 10K --port 9010 --enforce-eager -tp 2

lm_eval command: lm_eval --model local-completions --model_args model=mistralai/Mistral-Small-3.1-24B-Instruct-2503,base_url=http://0.0.0.0:9010/v1/completions,num_concurrent=500,tokenized_requests=False --tasks gsm8k --num_fewshot 5
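
As an optional sanity check before launching lm_eval (not part of this PR's test plan; it assumes the server from the base command above is already up on port 9010), a single request against the OpenAI-compatible completions endpoint:

```python
# Hypothetical smoke test: hit vLLM's OpenAI-compatible /v1/completions
# endpoint started by the `vllm serve` command above.
import requests

resp = requests.post(
    "http://0.0.0.0:9010/v1/completions",
    json={
        "model": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        "prompt": "Question: What is 12 * 7?\nAnswer:",
        "max_tokens": 16,
        "temperature": 0.0,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```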

main:
command = base_command

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8976|±  |0.0083|
|     |       |strict-match    |     5|exact_match|↑  |0.8886|±  |0.0087|

command = base_command + --enable-lora

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7885|±  |0.0112|
|     |       |strict-match    |     5|exact_match|↑  |0.7794|±  |0.0114|

This PR:
command = base_command

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8976|±  |0.0083|
|     |       |strict-match    |     5|exact_match|↑  |0.8901|±  |0.0086|

command = base_command + --enable-lora

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8984|±  |0.0083|
|     |       |strict-match    |     5|exact_match|↑  |0.8901|±  |0.0086|


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@varun-sundar-rabindranath (Contributor, Author)

cc @mgoin @jeejeelee

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request fixes an issue with LoRA layer replacement for nested multi-modal models like Mistral-Small-3.1. The changes correctly locate the logits_processor within its parent submodule and also properly discover embedding_modules defined in submodules.

The logic seems sound and directly addresses the described bug. I've found one area for improvement in vllm/lora/utils.py regarding a redundant function call, which I've commented on. Otherwise, the changes look good.
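
For readers unfamiliar with the second point, a rough sketch of what discovering embedding_modules across submodules can look like (an assumed shape for illustration only, not vLLM's actual code):

```python
# Illustrative only: merge `embedding_modules` dicts that may be declared on
# nested submodules (e.g. on the inner language_model) rather than only on the
# top-level model class.
import torch.nn as nn


def collect_embedding_modules(model: nn.Module) -> dict:
    merged: dict = {}
    for _, module in model.named_modules():
        candidate = getattr(module, "embedding_modules", None)
        if isinstance(candidate, dict):
            merged.update(candidate)
    return merged
```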

@mgoin added the bug (Something isn't working) label Jul 18, 2025
@mgoin added this to the v0.10.0 milestone Jul 18, 2025
@mgoin (Member) left a comment

Surprising we haven't run into this before; this fix makes a lot of sense. Thanks a lot!

@mgoin added the multi-modality (Related to multi-modality (#4194)) and ready (ONLY add when PR is ready to merge/full CI is needed) labels Jul 18, 2025
@jeejeelee (Collaborator)

Which model class does this model use? Is it PixtralForConditionalGeneration? But it looks like this model doesn't support LoRA yet, see:

| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ | ✅︎ |

@jeejeelee (Collaborator) left a comment

IIRC, for multi-modal models we don't explicitly set embedding_modules, so LogitsProcessorWithLoRA never gets instantiated; maybe I'm remembering wrong.

@varun-sundar-rabindranath (Contributor, Author)

Hi @jeejeelee, here is the model definition:

model : Mistral3ForConditionalGeneration(
  (vision_tower): PixtralHFVisionModel(
    (patch_conv): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (ln_pre): RMSNorm(hidden_size=1024, eps=1e-05)
    (transformer): PixtralHFTransformer(
      (layers): ModuleList(
        (0-23): 24 x PixtralHFTransformerBlock(
          (attention_norm): RMSNorm(hidden_size=1024, eps=1e-05)
          (attention): PixtralHFAttention(
            (qkv_proj): QKVParallelLinear(in_features=1024, output_features=768, bias=False, tp_size=4, gather_output=False)
            (o_proj): RowParallelLinear(input_features=256, output_features=1024, bias=False, tp_size=4, reduce_results=True)
          )
          (feed_forward): PixtralHFMLP(
            (gate_up_proj): MergedColumnParallelLinear(in_features=1024, output_features=2048, bias=False, tp_size=4, gather_output=False)
            (down_proj): RowParallelLinear(input_features=1024, output_features=1024, bias=False, tp_size=4, reduce_results=True)
            (act_and_mul): SiluAndMul()
          )
          (ffn_norm): RMSNorm(hidden_size=1024, eps=1e-05)
        )
      )
    )
    (patch_positional_embedding): PixtralRotaryEmbedding()
  )
  (multi_modal_projector): Mistral3MultiModalProjector(
    (norm): RMSNorm(hidden_size=1024, eps=1e-05)
    (patch_merger): Mistral3PatchMerger(
      (merging_layer): Linear(in_features=4096, out_features=1024, bias=False)
    )
    (linear_1): ColumnParallelLinear(in_features=1024, output_features=1280, bias=False, tp_size=4, gather_output=False)
    (act): GELU(approximate='none')
    (linear_2): RowParallelLinear(input_features=1280, output_features=5120, bias=False, tp_size=4, reduce_results=True)
  )
  (language_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): VocabParallelEmbeddingWithLoRA(
        (base_layer): VocabParallelEmbedding(num_embeddings=32832, embedding_dim=5120, org_vocab_size=131072, num_embeddings_padded=131328, tp_size=4)
      )
      (layers): ModuleList(
        (0-39): 40 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (qkv_proj): MergedQKVParallelLinearWithLoRA(
              (base_layer): QKVParallelLinear(in_features=5120, output_features=1536, bias=False, tp_size=4, gather_output=False)
            )
            (o_proj): RowParallelLinearWithLoRA(
              (base_layer): RowParallelLinear(input_features=1024, output_features=5120, bias=False, tp_size=4, reduce_results=True)
            )
            (rotary_emb): RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=131072, base=1000000000.0, is_neox_style=True)
            (attn): Attention(head_size=128, num_heads=8, num_kv_heads=2, scale=0.08838834764831845, backend=FlashAttentionImpl)
          )
          (mlp): LlamaMLP(
            (gate_up_proj): MergedColumnParallelLinearWithLoRA(
              (base_layer): MergedColumnParallelLinear(in_features=5120, output_features=16384, bias=False, tp_size=4, gather_output=False)
            )
            (down_proj): RowParallelLinearWithLoRA(
              (base_layer): RowParallelLinear(input_features=8192, output_features=5120, bias=False, tp_size=4, reduce_results=True)
            )
            (act_fn): SiluAndMul()
          )
          (input_layernorm): RMSNorm(hidden_size=5120, eps=1e-05)
          (post_attention_layernorm): RMSNorm(hidden_size=5120, eps=1e-05)
        )
      )
      (norm): RMSNorm(hidden_size=5120, eps=1e-05)
    )
    (lm_head): ParallelLMHead(num_embeddings=32832, embedding_dim=5120, org_vocab_size=131072, num_embeddings_padded=131328, tp_size=4)
    (logits_processor): LogitsProcessorWithLoRA(
      (base_layer): LogitsProcessor(vocab_size=131328, org_vocab_size=131072, scale=1.0, logits_as_input=False)
    )
  )
)

It uses Mistral3ForConditionalGeneration. It has some Pixtral* blocks in it, but I am not sure how closely it is related to PixtralForConditionalGeneration.
However, the fix targets the LlamaForCausalLM language-model block: the replacement logic couldn't detect the logits_processor inside that block.
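
A quick way to confirm the fix took effect (a hypothetical check, not something from this PR) is to assert the type of the nested logits_processor after the model is built with LoRA enabled:

```python
# Hypothetical post-construction check; `model` is assumed to be an already
# built Mistral3ForConditionalGeneration with LoRA enabled.
def assert_logits_processor_wrapped(model) -> None:
    lp = model.language_model.logits_processor
    if type(lp).__name__ != "LogitsProcessorWithLoRA":
        raise AssertionError(
            f"expected LogitsProcessorWithLoRA, got {type(lp).__name__}")
```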

@jeejeelee (Collaborator) commented Jul 18, 2025

The difference is between the Mistral format and the HF Transformers format.
Thanks for this fix. If I get a chance later, I'll look into why LogitsProcessorWithLoRA gets instantiated.

Varun Sundar Rabindranath added 7 commits July 18, 2025 13:48
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com> (same sign-off on all 7 commits)
@mgoin enabled auto-merge (squash) July 18, 2025 22:29
@simon-mo disabled auto-merge July 19, 2025 04:14
@simon-mo merged commit 9ffe905 into vllm-project:main Jul 19, 2025
67 of 70 checks passed