
👁️ [GRPO] Add VLM training capabilities to the trainer #3072


Merged
122 commits merged into huggingface:main on Jul 23, 2025

Conversation

CompN3rd
Contributor

What does this PR do?

This is an attempt at addressing #2917 .
An associated unit test has been added, and less toy-example-like trainings also appear to maximize rewards, but I don't have final models yet.

Issues regarding 4-bit quantization and applying the Liger kernel to VLM models still remain (marked with TODO in the comments), and I would appreciate input on how to tackle/further debug them.

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qgallouedec
Member

Thanks! Have you tried to fine-tune a VLM with the trainer? Do you have results to share?

@CompN3rd
Contributor Author

CompN3rd commented Mar 14, 2025

Well, actual fine-tuning is still in progress and riddled with OOM issues and the quantization bug referenced in the unit test, and the full training script relies on a private dataset, but I can at least give a bit more information.

The training task I am currently looking at is fine-tuning a VLM (the language model part) on image captioning to maximize a CLIP cosine similarity score.

So the reward module looks like this

from dataclasses import dataclass

import torch
from torch.nn.functional import cosine_similarity
from transformers import CLIPModel
from transformers.models.clip.modeling_clip import _get_vector_norm
from transformers.utils import ModelOutput


@dataclass
class CLIPRewardModelOutput(ModelOutput):
    logits: torch.FloatTensor
    """The reward logits for the Trainer."""


class CLIPRewardModel(CLIPModel):
    """Inherits from CLIPModel (i.e. PreTrainedModel), such that the forward computation gives the rl reward,
    but type-based training logic (accelerator.prepare) is still possible"""

    def forward(self, *, input_ids, attention_mask, pixel_values) -> CLIPRewardModelOutput:
        """Mainly copy-paste from CLIPModel.forward up to the logit and loss computation"""

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=False,
            output_hidden_states=False,
            interpolate_pos_encoding=False,
            return_dict=True,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / _get_vector_norm(image_embeds)
        text_embeds = text_embeds / _get_vector_norm(text_embeds)

        # cosine similarity of the vectors as reward logits
        return CLIPRewardModelOutput(logits=cosine_similarity(image_embeds, text_embeds, dim=-1).unsqueeze(-1))

@CompN3rd
Contributor Author

CompN3rd commented Mar 14, 2025

Then we have a small modification to the trainer via subclassing (which is why I proposed to split off the relevant code section into its own member function):

from typing import Any

import torch
from transformers import PreTrainedTokenizerBase
from trl import GRPOTrainer


class GRPOVlmClipTrainer(GRPOTrainer):
    def _prepare_inputs_for_reward_module(
        self,
        *,
        inputs: dict[str, torch.Tensor | Any],
        reward_processing_class: PreTrainedTokenizerBase,
        prompts: list[str],
        completions: list[str],
        images=None,
    ) -> dict[str, torch.Tensor | Any]:
        # disregard prompts, only prepare completions (captions) and images
        reward_inputs = reward_processing_class(
            images=images,
            text=completions,
            return_tensors="pt",
            padding=True,
            padding_side="right",
            add_special_tokens=True,
            truncation=True,
            max_length=77,
        )
        reward_inputs = super(GRPOTrainer, self)._prepare_inputs(reward_inputs)

        return reward_inputs
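
For completeness, a rough sketch of how these two pieces might be wired together. This is only an illustration: the dataset name, CLIP checkpoint, and config values are placeholders, not the private setup described above.

from datasets import load_dataset
from transformers import CLIPProcessor
from trl import GRPOConfig

# Hypothetical captioning dataset with "prompt" and "image" columns (PIL images).
dataset = load_dataset("my-org/captioning-dataset", split="train")

# CLIPRewardModel and GRPOVlmClipTrainer are the classes defined above.
reward_model = CLIPRewardModel.from_pretrained("openai/clip-vit-base-patch32")
reward_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

trainer = GRPOVlmClipTrainer(
    model="HuggingFaceTB/SmolVLM-Instruct",
    reward_funcs=reward_model,                      # reward model used through the custom prepare method
    reward_processing_classes=reward_processor,     # CLIP processor for completions + images
    args=GRPOConfig(output_dir="vlm-clip-grpo"),
    train_dataset=dataset,
)
trainer.train()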

@CompN3rd
Contributor Author

CompN3rd commented Mar 14, 2025

Finally this leads to reward curves like this, which seem to indicate that it generally optimizes in the right direction.
[image: reward curve from the training run]

@MohamedAliRashad

@CompN3rd If you can give me a simple guide on how to use your PR, I can help you with testing

@CompN3rd
Contributor Author

CompN3rd commented Mar 17, 2025

@CompN3rd If you can give me a simple guide on how to use your PR, I can help you with testing

Sure, if you want to get started with a semi-realistic example, I'd suggest starting with the setup from the unit test, which should be able to run on a 24 GB GPU (test_grpo_trainer.py, l. 900-987)

    @require_flash_attn
    @require_bitsandbytes
    @require_peft
    @require_torch_accelerator
    def test_vlm_training(self):
        model_name = "HuggingFaceTB/SmolVLM-Instruct"
        .....

The biggest question there is why 8-bit quantization works but 4-bit quantization breaks the test (or whether that is somehow expected behavior), so any input in that regard would be valuable.

Other than that, if you have access to GPUs with more VRAM, you could rewrite the test configuration to work without quantization, or alternatively replace the model with a smaller one...
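
For reference, the two quantization setups being compared look roughly like this (a sketch; the exact configuration lives in the unit test):

import torch
from transformers import BitsAndBytesConfig

# 8-bit quantization: works in the VLM test.
bnb_config_8bit = BitsAndBytesConfig(load_in_8bit=True)

# 4-bit quantization: currently breaks the test for reasons still unclear (see the TODOs in the PR).
bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)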

@MohamedAliRashad

@CompN3rd I have tried this preprocessing function:

def format_data(row):
    base64_image = encode_image(row["image"])
    prompt = "Extract all text from the given image and format it using Markdown syntax. Preserve headings, lists, bold/italic text, and other structural elements. Ensure the output is clean and readable in Markdown format."
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": f"data:image/jpeg;base64,{base64_image}",
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    return inputs

and it gave me KeyError: 'prompt' (I am training Qwen/Qwen2.5-VL-3B-Instruct)

@CompN3rd
Contributor Author

CompN3rd commented Mar 19, 2025

@MohamedAliRashad If I understand this correctly, it is probably because the processor in your case already returns tokenized input_ids and probably pixel_values, or whatever fields are associated with image/video processing.

That is not the type of data the GRPOTrainer expects (even before this PR). The internal preprocessing function of the trainer accesses the "prompt" key of the input dictionary (and this PR adds "image", which is expected to be a raw NumPy or PIL image, not a base64 string).
It then calls the processor internally and goes from there.

TL;DR: your data preprocessing probably interferes with the input data preparation done in the trainer class.
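
In other words, each dataset example should stay "raw" and let the trainer call the processor; roughly like this (the prompt text is just an illustration):

from PIL import Image

example = {
    "prompt": "Extract all text from the given image and format it using Markdown syntax.",
    "image": Image.open("page.jpg"),  # a raw PIL (or NumPy) image, not a base64 string
}
# The trainer tokenizes/processes these fields itself; pre-tokenizing them
# in your own preprocessing is what leads to the KeyError: 'prompt' above.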

@nph4rd

nph4rd commented Mar 19, 2025

@CompN3rd - so how would one preprocess the data, or tell the trainer how to process it? For example, as far as I understand, Qwen2.5-VL uses qwen-vl-utils' process_vision_info.

Based on your changes, what would be the best approach to use that during the input preparation?

@MohamedAliRashad

@CompN3rd I changed the preprocessing to be closer to what you have in the test file and it worked wonderfully.
I did full fine-tuning of Qwen2.5-VL-3B and it worked on an 80 GB GPU.

@nph4rd

nph4rd commented Mar 19, 2025

@MohamedAliRashad - do you mind sharing the setup/code you used for that?

@nph4rd

nph4rd commented Mar 20, 2025

I just tested the following with this dummy dataset using 4 A100 80GB:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import copy

model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    use_cache=False,
)
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")

dataset = load_dataset("agentsea/vqa-test-formatted", split="train")
dataset = dataset.remove_columns(["completion"])

def preprocess_vision_info(examples):
    examples_copy = copy.deepcopy(examples)
    batch_size = len(examples["prompt"])
    examples["image"] = []
    for i in range(batch_size):
        prompt_data = examples_copy["prompt"][i]
        image_data = examples_copy["image"][i]
        for message in prompt_data:
            for content in message["content"]:
                if isinstance(content, dict) and content.get("type") == "image":
                    content["image"] = image_data
        processed_images, _ = process_vision_info(prompt_data)
        examples["image"].extend(processed_images)
    return examples

dataset = dataset.with_transform(preprocess_vision_info)


def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="Qwen2.5-VL-3B-GRPO",
    logging_steps=1,
    use_vllm=True,
    bf16=True,
    gradient_checkpointing=True,
    per_device_train_batch_size=1,
    num_generations=3,
    max_prompt_length=None,
    vllm_device="cuda:3",
)


trainer = GRPOTrainer(
    model=model,
    processing_class=processor,
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

However I'm encountering this error:

ValueError: Attempted to assign 5185 + 5185 + 5185 = 15555 multimodal tokens to 31107 placeholders

Upon further inspection I found that the code works if I make the following change in this line specifically:

from:

            prompt_inputs = self.processing_class(
                text=prompts_text,
                images=images,
                return_tensors="pt",
                padding=True,
                padding_side="left",
                add_special_tokens=False,
            )

to:

            prompt_inputs = self.processing_class(
                text=prompts_text.copy(), # send a copy instead
                images=images,
                return_tensors="pt",
                padding=True,
                padding_side="left",
                add_special_tokens=False,
            )

What is happening is that the processor class is mutating the input here. So, vLLM complains because it's receiving the modified prompts_text. You can test the side-effect in this script.

I don't think this is an issue that should be handled by either TRL or vLLM. I think it should be handled at the source, in the processor's code.

I can see that the SmolVLM processor class doesn't have this kind of side-effect. But Qwen2.5-VL's does, so I do wonder how @MohamedAliRashad made it work. I presume it's because prompts_text is a string and not a list when using 1 GPU with num_generations=1?


EDIT: fwiw - I raised huggingface/transformers#36865 + opened huggingface/transformers#36866
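
For anyone who wants to reproduce the side-effect without running the full trainer, a minimal sketch (assuming the Qwen2.5-VL processor, its usual image placeholder string, and any local test image):

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
image = Image.open("test.jpg")

texts = ["<|vision_start|><|image_pad|><|vision_end|>Describe the image."]
before = list(texts)
processor(text=texts, images=[image], return_tensors="pt", padding=True)
# The list passed as `text` now contains the expanded image placeholders,
# which is what vLLM ends up receiving downstream.
print(texts == before)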

@MohamedAliRashad

@nph4rd The error you are seeing is because of your context size limit.
Qwen (unlike other models) doesn't produce a fixed number of tokens for images of different shapes. The number of tokens changes with the size of the input image; if I am not mistaken, every 28x28 pixels is one token for them.

What you need to do is resize your images so they fit within your acceptable context window. Also, I didn't use process_vision_info and it worked fine for me, so you may consider removing it and sending the PIL images as they are.
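
A sketch of the kind of resize this suggests (the maximum side length here is an arbitrary choice; tune it to your context window):

from PIL import Image

def resize_for_context(image: Image.Image, max_side: int = 1024) -> Image.Image:
    """Downscale so the longer side is at most max_side, keeping the aspect ratio."""
    scale = max_side / max(image.size)
    if scale < 1.0:
        image = image.resize((int(image.width * scale), int(image.height * scale)), Image.LANCZOS)
    return image

# e.g. applied to the dataset before training
dataset = dataset.map(lambda ex: {"image": resize_for_context(ex["image"])})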

@CompN3rd
Contributor Author

@qgallouedec Let me know if there are refactoring or api changes necessary to make this ready for merging. Would be happy to make those adjustments.

@nph4rd

nph4rd commented Mar 25, 2025

@MohamedAliRashad / @CompN3rd - thanks for the comments. I don't understand why it would work with the change I shared but not without it though? 🤔 With that change I didn't have to resize the images for it to work.

Another thing I found is that when I set log_completions=True, the training was stuck at this line:

table["image"] = [wandb.Image(img) for img in gather_object(images)]

Specifically, the gather_object(images) was timing out. This might be my image sizes again, but I thought I'd let you know in case you hadn't tested log_completions.

@CompN3rd
Contributor Author

@nph4rd Thanks for testing it out. I concur with @MohamedAliRashad's observations: I could produce such errors mostly by having too small a context window.

As for the log_completions error, I had a version where not all processes participated in the communication, which obviously failed.
So far, I have a local test with 2 GPUs, which worked well, as well as a cloud test, but that was only a single A40 GPU.
Both produced images in Weights & Biases, but I admit I haven't run multi-node tests.

[image: Screenshot_20250325-175437.png, completions logged to Weights & Biases]

@nph4rd left a comment

@CompN3rd - left some comments regarding supporting multiple images per prompt and updating to the changes in f713f61

Happy to help with changes and/or testing.

ordered_set_of_prompts = [
    {"multi_modal_data": {"image": image}, "prompt": prompt}
    for image, prompt in zip(ordered_set_of_images, ordered_set_of_prompts)
]
with profiling_context(self, "vLLM.generate"):
    completion_ids = self.vllm_client.generate(

The changes in f713f61 imply that it's now not possible to send the multi_modal_data without adapting the new vllm server. I think the following would have to change:

def _generate_and_score_completions(
    self, inputs: dict[str, Union[torch.Tensor, Any]]
) -> dict[str, Union[torch.Tensor, Any]]:
    device = self.accelerator.device
    prompts = [x["prompt"] for x in inputs]
    images = [x["image"] for x in inputs if "image" in x]

I noticed the DPO trainer already supports VLMs but expects the column to be called "images", as a list of PIL images:

processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False)

I think this should be "images" too, for consistency, and to support the case of multiple images per prompt.

Also, note that for that case the vLLM server would have to be able to receive the limit_mm_per_prompt argument.
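
To make the suggestion concrete, a hypothetical example in that DPO-style format (column named "images", holding a list of PIL images):

from PIL import Image

example = {
    "prompt": [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "image"},
            {"type": "text", "text": "Compare the two screenshots."},
        ]},
    ],
    "images": [Image.open("before.png"), Image.open("after.png")],  # one or more images per prompt
}

On the vLLM side, the server would then also need something like limit_mm_per_prompt={"image": 2} passed through when the engine is created.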

@sunildkumar

Eagerly awaiting this (#2734 - 2 months and counting)! @CompN3rd Let me know if and how I can help. I've been training VLMs with GRPO for a while now, just not on TRL main.

@sunildkumar

@qgallouedec – I hope you don’t mind the tag. It’s been a couple of weeks since @CompN3rd has engaged with this PR, and I really appreciate the work that’s been done so far. I’d love to help move it forward if that’s appropriate.

I’m not entirely sure what the etiquette is in cases like this—would it be okay to open a follow-up PR branching off of this one, or would you recommend waiting longer? Apologies if this is a naive question, and thank you in advance for any guidance.

@qgallouedec
Member

Thanks again for your work on this, and sorry for the slow response; rest assured we're doing our best. It's a valuable feature and makes a lot of sense to include. That said, it requires thorough review, testing, and documentation before merging, and at the moment we don't have the capacity to give it the attention it needs. I'll make sure to revisit it as soon as I can.

In the meantime, keeping the PR open is a great idea. It allows the community to test it, report any issues, and benefit from the feature.

And to your question — yes, feel free to open a follow-up PR based on this one. That’s totally fine and actually very helpful. No need to wait.

@sunildkumar

@qgallouedec - totally understood. Thanks for your advice!

@Benjoyo

Benjoyo commented May 3, 2025

Anyone actively working on GRPO support for VLMs still? 🙏

@@ -0,0 +1,704 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
Collaborator

What is the purpose of this script, and does it need to be included?

Member

Here, the idea is only to generate the test dataset. I find it quite practical to keep these scripts somewhere, similar to the script for generating the tiny models. This script doesn't need to be there per se. Think of them as useful scripts for developers, a bit like https://github.com/huggingface/transformers/tree/main/utils

@edbeeching (Collaborator) left a comment

LGTM barring minor comments and some missing tests.

@Revist

Revist commented Jul 22, 2025

Hi guys, thank you for the great work!

I am trying to use this PR with "llava-hf/llava-v1.6-mistral-7b-hf", however I get the error "jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'text'". Is this caused by a wrong dataset format or by a bug in the PR?

accelerate launch   --config_file=examples/accelerate_configs/deepspeed_zero3.yaml   examples/scripts/grpo_vlm.py   --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf   --output_dir grpo-Qwen2.5-VL-3B-Instruct_test   --learning_rate 1e-5   --gradient_checkpointing   --torch_dtype bfloat16   --max_prompt_length 2048   --max_completion_length 1024   --use_vllm   --vllm_mode colocate   --use_peft   --lora_target_modules "q_proj", "v_proj"   --log_completions

@qgallouedec merged commit 56f4201 into huggingface:main on Jul 23, 2025
9 of 10 checks passed
@EauDeData

Thank you for this incredibly important merge. I think the whole DL community is happy today :)

From the initial PR I thought we did not have documentation, yet it passed the documentation check on merge; where can we find a clearer example of how to use it? I cannot find the updated docs...

Again, sincere thanks for your work.

@kashif
Collaborator

kashif commented Jul 23, 2025

thanks @EauDeData you can find the docs here: https://huggingface.co/docs/trl/main/en/grpo_trainer#vision-language-model-vlm-training

@ghubnerr
Contributor

Hi everyone! Sorry to bring this back - I noticed that the VLM support requires a very strict format for the element spec, where it expects a dict containing the "prompt" and "image" keys. This removes the user's control over where to insert the <start_of_image> tag, for example. I can tell that this decision was made because the maybe_apply_chat_template function returns a string, processed by the processing class's apply_chat_template method (with tokenize=False), which makes images incompatible with that format.

With the AutoProcessor, one can actually return tensors using the apply_chat_template method, which lets you control the image placement better. An example with Gemma3Processor:

prompt = "This <img> is the image."
messages = [
    {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
    {"role": "user", "content": [
        {"type": "text", "text": prompt}
        {"type": "image", "image": Image.open(BytesIO(image_bytes)),
    ]},

Which could be called like this:

formatted_mm_tokens = processor.apply_chat_template(
    conversation=messages,
    add_generation_prompt=True,
    do_pan_and_scan=True,
    tokenize=True,  # <-- Do tokenize
)

If this format is provided, this could potentially replace these lines inside an if statement.

The AutoProcessor automatically identifies an image in the content and creates a multi-modal token array. It would be really great to have this sort of control. I'd be happy to work on this later -- I'm currently in the middle of an internship, but will be available soon.

@qgallouedec
Member

Hi, thanks for reporting; contributions to fix this are very welcome.

marcandrelarochelle pushed a commit to marcandrelarochelle/trl that referenced this pull request Jul 29, 2025
@MohamedAliRashad

Is this PR in a working state right now, or did it have breaking changes?
