
Conversation

rohitrango
Contributor

What does this PR do?

Adds VLM support (Qwen2.5-VL) with a TP plan, DTensor policy, vLLM backend, and multi-GPU training.

Usage

uv run examples/run_vlm_grpo.py cluster.gpus_per_node=4

Convergence

Training convergence on 2 H100 GPUs happens in about 60 iterations (the highest possible reward is 5).

(convergence plots)

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

rohitrango added 17 commits July 7, 2025 15:13
assertions for non-vlm keys

needs testing on larger machine

@rohitrango
Contributor Author

@terrykong @ashors1 Created a draft PR (duplicate of #521) to see if CI passes on this instead.

@ashors1 ashors1 added the CI:L1 Run doctests, unit tests, and functional tests label Jul 11, 2025
- separated reward functions into a separate file (made composable directly
  from YAML files)
- added RefCOCO task
- added the ability to freeze Hugging Face models (language and vision tower),
  with fine-grained freezing using regexes

@rohitrango rohitrango marked this pull request as ready for review July 15, 2025 00:23
@terrykong
Contributor

adding @ashors1 @yfw to review/approve

## this will have consequences for data sharding for VLM models (split along the batch dim but from [start_patch:end_patch])
keys_to_concat = []

if key in keys_to_concat:
Contributor

keys_to_concat is always empty here?
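For context on the [start_patch:end_patch] note above, here is a minimal sketch (not code from this PR) of the patch-offset slicing it alludes to, assuming Qwen2-VL-style inputs where pixel_values is flattened over all patches in the batch and image_grid_thw stores the per-image (t, h, w) patch-grid sizes:

import torch

def slice_pixel_values(pixel_values: torch.Tensor,
                       image_grid_thw: torch.Tensor,
                       index: int) -> torch.Tensor:
    # pixel_values:   (total_patches, patch_dim), concatenated over every image in the batch
    # image_grid_thw: (num_images, 3), per-image (t, h, w) patch-grid sizes
    patches_per_image = image_grid_thw.prod(dim=-1)
    ends = torch.cumsum(patches_per_image, dim=0)
    starts = ends - patches_per_image
    # Slicing runs along the patch dimension, not the batch dimension, which is
    # why these keys cannot simply be sharded or concatenated by batch index.
    return pixel_values[starts[index]:ends[index]]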


if random.random() < img_flip_prob:
flip = True
resized_image = resized_image.transpose(Image.FLIP_LEFT_RIGHT)
Contributor

Is this always safe to do for this dataset? Do any captions rely on the positions of the original image (e.g. "A cat sitting to the left of a dog")

Contributor

@ashors1 left a comment

Thanks for your work on this PR! Two quick comments:

  1. Could you add your test case to the nightly suite: https://github.com/NVIDIA-NeMo/RL/blob/rohit/vlm_grpo/tests/test_suites/nightly.txt?
  2. Throughout the code, there are a number of different ways of getting vlm_keys or vlm_kwargs from the data. This seems slightly verbose, but perhaps I don't have a good enough understanding of the code to see why all these different methods are required. Would it be possible to streamline the process of getting the vlm keys/kwargs? If not, could we add some documentation to explain the structure of the vlm keys in the data? That might help to clarify things a bit

user_message['token_ids'] = message['input_ids'][0]
# add all keys and values to the user message, and the list of keys
user_message['vlm_keys'] = []
for key, value in message.items():
Contributor

@yfw commented Jul 18, 2025

Are the vlm_keys specific to the dataset? And are they applicable for all messages of that dataset? If so, can this be configured with the dataset? (i.e. in clevr.py and refcoco.py). This seems to assume all keys except for 'input_ids', 'attention_mask' are vlm_keys which seems less safe than if we were explicit about which keys are vlm keys.

Contributor

Can we get the whitelist of vlm_keys from processor.image_processor.model_input_names?
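A sketch of what that could look like (not the PR's implementation; the checkpoint name is only an example, and the exact contents of model_input_names vary by processor):

from transformers import AutoProcessor

def select_vlm_items(message: dict, processor) -> dict:
    # Keep only the keys the image processor declares as model inputs, instead of
    # treating everything except input_ids/attention_mask as a VLM key.
    whitelist = set(processor.image_processor.model_input_names)
    return {key: value for key, value in message.items() if key in whitelist}

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # example checkpoint
# For Qwen2-VL-style processors the whitelist would typically contain
# "pixel_values" and "image_grid_thw".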

- recycle computed vlm_kwargs for both unflattened and flattened batches
- remove potentially unsafe code for flipping images in refcoco
- rename `get_vlm_keys_from_clippedpgloss_batch` to
  `get_vlm_keys_from_flattened_batch` and move it to
batched_data_dict.py
- add vlm grpo testcase to nightly
- improve documentation in CLEVR

@yfw
Contributor

yfw commented Jul 19, 2025

I tried a run on the clevr dataset and noticed the token_mult_prob_error was a bit high. This is a measure of the difference in logprobs between vllm and dtensor. For non-vlm, we generally expect this number to be < 1.05 (for qwen 2.5-1b, we see around 1.02). A higher number usually indicates some issue with the refit so vllm and dtensor aren't exactly matching. Is there anything specific about the vlm setup that could cause this?

(screenshot: token_mult_prob_error plot)
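For readers unfamiliar with the metric: one common way such a number can be computed (this formula is an assumption, not necessarily NeMo-RL's exact definition) is the per-token multiplicative probability ratio between the two backends, averaged over generated tokens:

import torch

def token_mult_prob_error(logprobs_gen: torch.Tensor,
                          logprobs_train: torch.Tensor,
                          mask: torch.Tensor) -> torch.Tensor:
    # exp(|log p_gen - log p_train|) is the factor by which the two backends'
    # per-token probabilities differ; 1.0 means an exact match.
    ratio = torch.exp((logprobs_gen - logprobs_train).abs())
    return (ratio * mask).sum() / mask.sum()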

@rohitrango
Contributor Author

I'm not entirely sure what causes the token_mult_prob_error to be relatively higher than in the LLM-only case. I did not modify either the vllm or the dtensor model forward setup, so I'm not sure where the discrepancy comes from.

For the functional test cases, I had to choose higher thresholds for the token_mult_prob_error. There are also a few spikes like the one you've shown that I cannot quite explain (the loss at those iterations does not spike, though).

@@ -41,6 +41,18 @@
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM

# Add VL model imports
try:
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration, Qwen2_5_VLModel
Contributor

Do we need these try/excepts? Can we instead make sure the transformers version we're using has these?
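One possible shape of that (a sketch; the minimum version string is a placeholder and would need to be checked against the release that added these model classes):

from packaging import version
import transformers

MIN_TRANSFORMERS_VERSION = "4.49.0"  # placeholder minimum; verify against the actual release
if version.parse(transformers.__version__) < version.parse(MIN_TRANSFORMERS_VERSION):
    raise ImportError(
        f"transformers>={MIN_TRANSFORMERS_VERSION} is required for Qwen VL support, "
        f"found {transformers.__version__}"
    )

# With the version pinned, the import no longer needs a try/except.
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration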

@yfw
Contributor

yfw commented Jul 22, 2025

I'm not entirely sure what causes the token_mult_prob_error to be relatively higher than in the LLM-only case. I did not modify either the vllm or the dtensor model forward setup, so I'm not sure where the discrepancy comes from.

For the functional test cases, I had to choose higher thresholds for the token_mult_prob_error. There are also a few spikes like the one you've shown that I cannot quite explain (the loss at those iterations does not spike, though).

One thought is that we do the preprocessing of the image before calling the model in the dtensor path whereas vllm does preprocessing of the image internally (if I understand what is happening correctly). We may need to make sure whatever preprocessing vllm is doing matches exactly what we're doing in the dtensor path.

@rohitrango
Contributor Author

rohitrango commented Jul 22, 2025

This will take me a while to analyse since I don't know exactly how the vllm engine processes the images internally.

For the policy, the typical multimodal pipeline uses the processor to encode the chat-template dict into a sequence of text tokens, multimodal tensors indexed by the key pixel_values, and metadata image_grid_thw used to compute mRoPE embeddings (the keys would be different for videos or audio). The pixel_values entry is produced from a PIL Image, which the processor turns into patches.

For vllm, the message log is simply reformatted into the format specified in this tutorial. The same sequence of PIL Images is provided to the vLLM frontend.
I assume the token ids and multimodal tokens must be the same, but I will have to double check. In most cases the logprobs mult error is still close to 1, so it could also be numerical differences between the vllm frontend and the dtensor policy.
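To make the two paths concrete, here is a hedged sketch of both sides (the checkpoint name, prompt, and image path are illustrative; the vLLM input format follows the multimodal inference examples in the vLLM docs):

from PIL import Image
from transformers import AutoProcessor
from vllm import LLM, SamplingParams

model_name = "Qwen/Qwen2.5-VL-3B-Instruct"  # example checkpoint
image = Image.open("example.jpg")

# DTensor-policy side: preprocessing happens up front via the HF processor.
processor = AutoProcessor.from_pretrained(model_name)
messages = [{"role": "user",
             "content": [{"type": "image"},
                         {"type": "text", "text": "How many objects are in the scene?"}]}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt")
# inputs holds input_ids plus multimodal tensors such as pixel_values and
# image_grid_thw (used for mRoPE in Qwen2.5-VL).

# vLLM side: the same prompt string and raw PIL image; vLLM preprocesses internally.
llm = LLM(model=model_name)
outputs = llm.generate({"prompt": prompt, "multi_modal_data": {"image": image}},
                       SamplingParams(max_tokens=64, logprobs=1))
# Comparing inputs["input_ids"][0] with outputs[0].prompt_token_ids checks that the
# text-token side matches; the pixel path is harder to compare directly because
# vLLM computes its own pixel_values.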

@terrykong
Contributor

From today's meeting, the remaining blockers on this PR:

  • understanding the logprob error
  • address review on API VLM processor keys

@rohitrango
Contributor Author

rohitrango commented Jul 29, 2025

re: Remaining blockers:

  1. understanding the logprob error: This is something I want to chalk up to how vllm loads multimodal image embeddings in its image processor. For LLM-only models, I noted that vllm takes the same list of token_ids (a list of ints) that the policy consumes (i.e. it goes through the same text embedding layer, etc.). However, for multimodal inputs, vllm processes the images internally. There could also be differences in how sampling is done. I found the following excerpt in the vllm docs: https://docs.vllm.ai/en/v0.9.1/usage/v1_guide.html#feature-model

Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e. before applying any logits post-processing such as temperature scaling or penalty adjustments). As a result, the returned logprobs do not reflect the final adjusted probabilities used during sampling.
Support for logprobs with post-sampling adjustments is in progress and will be added in future updates.

I prefer handling this issue in a separate PR (and merging initial support first) for four reasons:

  • this discrepancy is isolated to multimodal models only, so a "fix" can be shipped independently
  • multiple VLMs converge on three different datasets despite the apparent discrepancy. It is equivalent to training GRPO with a slightly off-policy model, but it does not seem to be very unstable or destructive to the learning process
  • other PRs break multimodal support regularly (every 2-3 days), and I have to roll back or fix those changes in my PR to keep my scripts working. Merging this PR, or at least its test cases, will prevent other PRs from breaking multimodal support
  • the PR has gotten very big as it is, and adding more fixes will add additional overhead to the review process
  2. The PR has now migrated (again) to feat: GRPO + SFT Dtensor support for multimodal training #712, and is tested on 4 families of multimodal models and 3 datasets. This rolls back the passing of the vlm_kwargs list throughout the training process and instead proposes a PackedGenericDataItem to handle non-sequence data items (most of them would be multimodal tensors); a rough sketch of such a container follows below. The single implementation seems to work for multiple multimodal models without any additional modifications to the config.
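Not the actual #712 implementation, but a rough sketch of what such a container could look like (the field names and the slicing helper are hypothetical):

from dataclasses import dataclass, field
import torch

@dataclass
class PackedGenericDataItem:
    # A non-sequence data item (e.g. pixel_values) packed over all examples,
    # with per-example slice sizes along dim 0 instead of a batch dimension.
    tensor: torch.Tensor
    lengths: list[int] = field(default_factory=list)

    def slice_for(self, index: int) -> torch.Tensor:
        # Recover the chunk belonging to one example.
        start = sum(self.lengths[:index])
        return self.tensor[start:start + self.lengths[index]]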
