[GRPO] Add VLM training capabilities to the trainer #3072
Conversation
Thanks! Have you tried to fine-tune a VLM with the trainer? Do you have results to share?
Well, actual fine-tuning is still in progress and riddled with OOM issues and the quantization bug referenced in the unittest as well, and the full training script relies on a private dataset, but I can at least give a bit more information. The training task I am currently looking at is fine-tuning a VLM (the language-model part) to do image captioning while maximizing a CLIP cosine-similarity score. So the reward module looks like this:

```python
from dataclasses import dataclass

import torch
from torch.nn.functional import cosine_similarity
from transformers import CLIPModel
from transformers.models.clip.modeling_clip import _get_vector_norm  # private helper
from transformers.utils import ModelOutput


@dataclass
class CLIPRewardModelOutput(ModelOutput):
    logits: torch.FloatTensor
    """The reward logits for the Trainer."""


class CLIPRewardModel(CLIPModel):
    """Inherits from CLIPModel (i.e. PreTrainedModel), so that the forward computation yields the RL reward
    while type-based training logic (accelerator.prepare) still works."""

    def forward(self, *, input_ids, attention_mask, pixel_values) -> CLIPRewardModelOutput:
        """Mostly copy-paste from CLIPModel.forward, up to the logit and loss computation."""
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=False,
            output_hidden_states=False,
            interpolate_pos_encoding=False,
            return_dict=True,
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / _get_vector_norm(image_embeds)
        text_embeds = text_embeds / _get_vector_norm(text_embeds)

        # cosine similarity of the vectors as reward logits
        return CLIPRewardModelOutput(logits=cosine_similarity(image_embeds, text_embeds, dim=-1).unsqueeze(-1))
```
Then we have a small modification to the trainer via subclassing (which is why I proposed to split off the relevant code section into its own member function):

```python
from typing import Any

import torch
from transformers import PreTrainedTokenizerBase

from trl import GRPOTrainer


class GRPOVlmClipTrainer(GRPOTrainer):
    def _prepare_inputs_for_reward_module(
        self,
        *,
        inputs: dict[str, torch.Tensor | Any],
        reward_processing_class: PreTrainedTokenizerBase,
        prompts: list[str],
        completions: list[str],
        images=None,
    ) -> dict[str, torch.Tensor | Any]:
        # disregard the prompts; only prepare the completions (captions) and the images
        reward_inputs = reward_processing_class(
            images=images,
            text=completions,
            return_tensors="pt",
            padding=True,
            padding_side="right",
            add_special_tokens=True,
            truncation=True,
            max_length=77,  # CLIP's text context length
        )
        reward_inputs = super(GRPOTrainer, self)._prepare_inputs(reward_inputs)
        return reward_inputs
```
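For context, here is roughly how the two pieces can be wired together. This is a sketch only: the checkpoint names and `caption_dataset` are placeholders, not part of this PR.

```python
from transformers import AutoProcessor, CLIPProcessor
from trl import GRPOConfig

reward_model = CLIPRewardModel.from_pretrained("openai/clip-vit-base-patch32")
reward_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

trainer = GRPOVlmClipTrainer(
    model="HuggingFaceTB/SmolVLM-Instruct",
    processing_class=AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct"),
    reward_funcs=reward_model,                   # CLIP cosine-similarity reward defined above
    reward_processing_classes=reward_processor,  # consumed by _prepare_inputs_for_reward_module
    args=GRPOConfig(output_dir="vlm-captioning-grpo"),
    train_dataset=caption_dataset,  # placeholder: rows with "prompt" and "image" columns
)
trainer.train()
```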
@CompN3rd If you can give me a simple guide on how to use your PR, I can help you with testing.
Sure, if you want to get started with a semi-realistic example, I'd suggest starting with the setup from the unittest, which should be able to run on a 24 GB GPU:

```python
@require_flash_attn
@require_bitsandbytes
@require_peft
@require_torch_accelerator
def test_vlm_training(self):
    model_name = "HuggingFaceTB/SmolVLM-Instruct"
    ...
```

The biggest question there is why 8-bit quantization works but 4-bit quantization breaks the test (or whether that is somehow expected behavior), so any input in that regard would be valuable. Other than that, if you have access to GPUs with more VRAM, you could rewrite the test configuration to work without quantization, or you could alternatively replace the model with a smaller one...
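For reference, the two bitsandbytes setups being compared look roughly like this (a sketch; the exact values used in the unittest may differ):

```python
import torch
from transformers import BitsAndBytesConfig

# 8-bit quantization: this variant currently lets the test pass.
bnb_8bit = BitsAndBytesConfig(load_in_8bit=True)

# 4-bit (NF4) quantization: this variant currently breaks the test, for reasons still unclear.
bnb_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
```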
@CompN3rd I have tried this preprocessing function:

```python
def format_data(row):
    base64_image = encode_image(row["image"])
    prompt = (
        "Extract all text from the given image and format it using Markdown syntax. "
        "Preserve headings, lists, bold/italic text, and other structural elements. "
        "Ensure the output is clean and readable in Markdown format."
    )
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": f"data:image/jpeg;base64,{base64_image}",
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    return inputs
```

and it gave me `KeyError: 'prompt'` (I am training …
@MohamedAliRashad If I understand this correctly, it is probably because your `format_data` already applies the processor and returns tokenized model inputs, so the rows no longer contain a raw `prompt` column. That is not the type of data the trainer expects; it wants the raw conversation and image and applies the chat template and processor itself. Tl;dr: your data preprocessing probably interferes with the input data preparation done in the trainer class.
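For illustration, a minimal example of the raw format the trainer does expect (a sketch along the lines of the unittest setup; the prompt text and the blank image are placeholders):

```python
from PIL import Image

example = {
    # Conversational prompt left un-templated and un-tokenized: the trainer applies the chat
    # template and the processor itself.
    "prompt": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Extract all text from the given image as Markdown."},
            ],
        }
    ],
    # The image stays a plain PIL image in its own column.
    "image": Image.new("RGB", (224, 224)),
}
```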
@CompN3rd - so how would one preprocess the data or tell the trainer how to process it? For example, as far as I understand, Qwen2.5-VL uses qwen-vl-utils' `process_vision_info`. Based on your changes, what would be the best approach to use that during the input preparation?
@CompN3rd I changed the preprocessing to be closer to what you have in the test file and it worked wonderfully.
@MohamedAliRashad - do you mind sharing the setup/code you used for that?
I just tested the following with this dummy dataset, using 4× A100 80GB:

```python
import copy

import torch
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from trl import GRPOConfig, GRPOTrainer

model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    use_cache=False,
)
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")

dataset = load_dataset("agentsea/vqa-test-formatted", split="train")
dataset = dataset.remove_columns(["completion"])


def preprocess_vision_info(examples):
    examples_copy = copy.deepcopy(examples)
    batch_size = len(examples["prompt"])
    examples["image"] = []
    for i in range(batch_size):
        prompt_data = examples_copy["prompt"][i]
        image_data = examples_copy["image"][i]
        for message in prompt_data:
            for content in message["content"]:
                if isinstance(content, dict) and content.get("type") == "image":
                    content["image"] = image_data
        processed_images, _ = process_vision_info(prompt_data)
        examples["image"].extend(processed_images)
    return examples


dataset = dataset.with_transform(preprocess_vision_info)


def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
    output_dir="Qwen2.5-VL-3B-GRPO",
    logging_steps=1,
    use_vllm=True,
    bf16=True,
    gradient_checkpointing=True,
    per_device_train_batch_size=1,
    num_generations=3,
    max_prompt_length=None,
    vllm_device="cuda:3",
)

trainer = GRPOTrainer(
    model=model,
    processing_class=processor,
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

However, I'm encountering this error:
Upon further inspection I found that the code works if I make the following change in this line specifically, from:

```python
prompt_inputs = self.processing_class(
    text=prompts_text,
    images=images,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False,
)
```

to:

```python
prompt_inputs = self.processing_class(
    text=prompts_text.copy(),  # send a copy instead
    images=images,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False,
)
```

What is happening is that the processor class is mutating the input here, so vLLM complains because it's receiving the already-modified `prompts_text`. I don't think this is an issue that should be handled by either TRL or vLLM; I think it should be handled at the source, in the processor's code. I can see that the SmolVLM processor class doesn't have this kind of side effect, but Qwen2.5-VL's does, so I do wonder how @MohamedAliRashad made it work. I presume it's because …

EDIT: fwiw, I raised huggingface/transformers#36865 and opened huggingface/transformers#36866
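To illustrate the side effect with a self-contained toy (this is not the actual Qwen2.5-VL processor code, just a stand-in that mutates its argument the same way):

```python
def fake_processor_call(text):
    # Stand-in for a processor that expands image placeholders in-place,
    # mutating the caller's list as a side effect.
    for i, t in enumerate(text):
        text[i] = t.replace("<image>", "<|image_pad|>" * 4)
    return text

prompts_text = ["<image> Describe the picture."]
fake_processor_call(prompts_text)
print(prompts_text)  # mutated: placeholders already expanded before vLLM sees them

prompts_text = ["<image> Describe the picture."]
fake_processor_call(prompts_text.copy())
print(prompts_text)  # unchanged: the defensive copy shields the original
```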
@nph4rd The error you are seeing is because of your context size limit. What you need to do is resize your images so that they fit within your acceptable context window, and also I didn't use …
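For example, something along these lines to shrink the images before they reach the processor (a rough sketch on top of the script above; `MAX_SIDE` is just a placeholder):

```python
MAX_SIDE = 512  # placeholder: pick a size that keeps the visual tokens within your context window


def shrink_image(example):
    image = example["image"]
    image.thumbnail((MAX_SIDE, MAX_SIDE))  # resizes in place, preserving the aspect ratio
    example["image"] = image
    return example


dataset = dataset.map(shrink_image)
```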
@qgallouedec Let me know if there are refactoring or API changes necessary to make this ready for merging. I'd be happy to make those adjustments.
@MohamedAliRashad / @CompN3rd - thanks for the comments. I don't understand why it would work with the change I shared but not without it, though? 🤔 With that change I didn't have to resize the images for it to work.

Another thing I found is that when I set `log_completions=True`, the run stalled at trl/trainer/grpo_trainer.py, line 1026 (at e94b5fa): specifically, the `gather_object(images)` call was timing out. This might be my image sizes again, but I thought I'd let you know in case you hadn't tested `log_completions`.
@nph4rd Thanks for testing it out. I concur with @MohamedAliRashad's observations; I could produce such errors mostly by having too small a context window. As for the …
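For what it's worth, these are the GRPOConfig knobs I would check first in that situation (a sketch; the values are placeholders):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="vlm-grpo-debug",   # placeholder
    max_prompt_length=None,        # None avoids truncating prompts that carry image tokens
    max_completion_length=256,     # bounds the length of the generated completions
)
```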
Review comment on `trl/trainer/grpo_trainer.py` (outdated diff):

```python
ordered_set_of_prompts = [
    {"multi_modal_data": {"image": image}, "prompt": prompt}
    for image, prompt in zip(ordered_set_of_images, ordered_set_of_prompts)
]
with profiling_context(self, "vLLM.generate"):
    completion_ids = self.vllm_client.generate(
```
Review comment on `trl/trainer/grpo_trainer.py` (outdated diff):

```python
def _generate_and_score_completions(
    self, inputs: dict[str, Union[torch.Tensor, Any]]
) -> dict[str, Union[torch.Tensor, Any]]:
    device = self.accelerator.device
    prompts = [x["prompt"] for x in inputs]
    images = [x["image"] for x in inputs if "image" in x]
```
I noticed the DPO trainer already supports VLMs but expects the column to be called "images", as a list of PIL images (trl/trainer/dpo_trainer.py, line 642 at 26d8675):

```python
processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False)
```
I think this should be "images" too, for consistency, and to support the case of multiple images per prompt. Also, note that for that case the vLLM server would have to be able to receive the `limit_mm_per_prompt` argument.
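As a sketch of what that could look like (hypothetical example, not part of this PR; the blank images are placeholders, and the vLLM call shows the engine-side argument the server wrapper would need to forward):

```python
from PIL import Image
from vllm import LLM

image_a = Image.new("RGB", (224, 224))
image_b = Image.new("RGB", (224, 224))

# Hypothetical multi-image row, mirroring the DPO trainer's "images" column.
row = {
    "prompt": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "image"},
                {"type": "text", "text": "What differs between the two images?"},
            ],
        }
    ],
    "images": [image_a, image_b],
}

# Engine-side: allow more than one image per prompt.
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", limit_mm_per_prompt={"image": 2})
```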
@qgallouedec – I hope you don’t mind the tag. It’s been a couple of weeks since @CompN3rd last engaged with this PR, and I really appreciate the work that’s been done so far. I’d love to help move it forward if that’s appropriate. I’m not entirely sure what the etiquette is in cases like this: would it be okay to open a follow-up PR branching off of this one, or would you recommend waiting longer? Apologies if this is a naive question, and thank you in advance for any guidance.
Thanks again for your work on this, and sorry for the slow response; rest assured we're doing our best. It's a valuable feature and makes a lot of sense to include. That said, it requires thorough review, testing, and documentation before merging, and at the moment we don’t have the capacity to give it the attention it needs. I’ll make sure to revisit it as soon as I can. In the meantime, keeping the PR open is a great idea. It allows the community to test it, report any issues, and benefit from the feature. And to your question: yes, feel free to open a follow-up PR based on this one. That’s totally fine and actually very helpful. No need to wait.
@qgallouedec - totally understood. Thanks for your advice!
Anyone actively working on GRPO support for VLMs still? 🙏
Review comment on the newly added dataset-generation script (a new 704-line file starting with the standard HuggingFace copyright header):
What is the purpose of this script, and does it need to be included?
Here, the idea is only to generate the test dataset. I find it quite practical to keep these scripts somewhere, similar to the script for generating the tiny models. This script doesn't need to be there per se. Think of them as useful scripts for developers, a bit like https://github.com/huggingface/transformers/tree/main/utils
LGTM barring minor comments and some missing tests.
Hi guys, thank you for the great work! I am trying to use this PR with "llava-hf/llava-v1.6-mistral-7b-hf", but I get the error "jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'text'". Is this caused by a wrong dataset format, or is it a bug in the PR?
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Thank you for this incredibly important merge. I think the whole DL community is happy today :) From the initial PR I thought we did not have documentation, yet it passed the documentation check on merge; where can we find a clearer example of how to use it? I cannot find the updated docs... Again, sincere thanks for your work.
Thanks @EauDeData, you can find the docs here: https://huggingface.co/docs/trl/main/en/grpo_trainer#vision-language-model-vlm-training
Hi everyone! Sorry to bring this back. I noticed that the VLM support requires a very strict format for the element spec, where it expects a dict containing the prompt and the image. With a conversation where the prompt references the image inline, e.g.

```python
prompt = "This <img> is the image."

messages = [
    {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
    {"role": "user", "content": [
        {"type": "text", "text": prompt},
        {"type": "image", "image": Image.open(BytesIO(image_bytes))},
    ]},
]
```

which could be called like this:

```python
formatted_mm_tokens = processor.apply_chat_template(
    conversation=messages,
    add_generation_prompt=True,
    do_pan_and_scan=True,
    tokenize=True,  # <-- do tokenize
)
```

If this format is provided, this could potentially replace these lines inside an if statement. The …
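To make the proposal concrete, a purely hypothetical sketch of what such a branch could look like (`example` and `processor` stand in for the trainer's inputs; none of this is code from the PR):

```python
def prepare_prompt_inputs(example, processor):
    # If the dataset already ships pre-tokenized multimodal inputs, use them directly;
    # otherwise fall back to the trainer's own chat templating.
    if isinstance(example["prompt"], dict) and "input_ids" in example["prompt"]:
        return example["prompt"]  # assumption: already the processor's tokenized output
    return processor.apply_chat_template(
        example["prompt"],
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
```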
Hi, thanks for reporting. Contributions are very welcome to fix this.
Is this PR in a working state right now, or did it have breaking changes?
What does this PR do?
This is an attempt at addressing #2917.
An associated unit test has been added, and less "toy-example" trainings seem to maximize rewards as well, but I don't have final models yet.
Issues regarding 4-bit quantization and applying the Liger kernel to VLM models still remain (marked with TODO in comments), and I would appreciate input on how to tackle/further debug them.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.