feat: v0 VLM support + GRPO pipeline #655
Conversation
assertions for non-vlm keys
needs testing on larger machine
@terrykong @ashors1 Created a draft PR (duplicate of #521) to see if CI passes on this instead.
- Separated reward functions into a separate file (and made them composable directly from YAML files)
- Added the RefCOCO task
- Added the ability to freeze Hugging Face models (language and vision tower), with fine-grained freezing via regexes
nemo_rl/data/llm_message_utils.py
## this will have consequences for data sharding for VLM models (split along the batch dim but from [start_patch:end_patch])
keys_to_concat = []

if key in keys_to_concat:
keys_to_concat is always empty here?
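For context on the `[start_patch:end_patch]` comment in the snippet above, here is a minimal sketch (an illustration under assumed Qwen-style shapes, not the PR's code): VLM tensors such as `pixel_values` are flattened across all images' patches, so sharding them requires cumulative patch offsets rather than a split along the batch dimension.

```python
import torch

# Assumed (Qwen2.5-VL style) shapes, for illustration only:
#   pixel_values:   [total_patches, patch_dim]  -- patches from all images, concatenated
#   image_grid_thw: [num_images, 3]             -- (t, h, w) patch grid per image
def slice_pixel_values(pixel_values, image_grid_thw, img_start, img_end):
    """Return the pixel_values rows belonging to images [img_start, img_end)."""
    patches_per_image = image_grid_thw.prod(dim=-1)   # t * h * w for each image
    offsets = torch.cumsum(patches_per_image, dim=0)
    start_patch = 0 if img_start == 0 else int(offsets[img_start - 1])
    end_patch = int(offsets[img_end - 1])
    return pixel_values[start_patch:end_patch]
```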
nemo_rl/data/hf_datasets/refcoco.py
if random.random() < img_flip_prob:
    flip = True
    resized_image = resized_image.transpose(Image.FLIP_LEFT_RIGHT)
Is this always safe to do for this dataset? Do any captions rely on positions in the original image (e.g. "A cat sitting to the left of a dog")?
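To make the concern concrete, here is a small sketch (hypothetical helper, not taken from the PR): a horizontal flip also has to mirror any box coordinates, and "left"/"right" wording in a referring expression silently becomes wrong, which no geometric transform can repair.

```python
from PIL import Image

def flip_image_and_box(image: Image.Image, box_xyxy: tuple[float, float, float, float]):
    """Mirror an image and its (x1, y1, x2, y2) box; the caption text is NOT fixed by this."""
    width = image.width
    flipped = image.transpose(Image.FLIP_LEFT_RIGHT)
    x1, y1, x2, y2 = box_xyxy
    # x-coordinates are reflected about the image width; y-coordinates are unchanged.
    return flipped, (width - x2, y1, width - x1, y2)
```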
Thanks for your work on this PR! Two quick comments:
- Could you add your test case to the nightly suite: https://github.com/NVIDIA-NeMo/RL/blob/rohit/vlm_grpo/tests/test_suites/nightly.txt?
- Throughout the code, there are a number of different ways of getting `vlm_keys` or `vlm_kwargs` from the data. This seems slightly verbose, but perhaps I don't have a good enough understanding of the code to see why all these different methods are required. Would it be possible to streamline the process of getting the vlm keys/kwargs? If not, could we add some documentation to explain the structure of the vlm keys in the data? That might help to clarify things a bit.
examples/run_vlm_grpo.py
user_message['token_ids'] = message['input_ids'][0]
# add all keys and values to the user message, and the list of keys
user_message['vlm_keys'] = []
for key, value in message.items():
Are the vlm_keys specific to the dataset? And are they applicable for all messages of that dataset? If so, can this be configured with the dataset? (i.e. in clevr.py and refcoco.py). This seems to assume all keys except for 'input_ids', 'attention_mask' are vlm_keys which seems less safe than if we were explicit about which keys are vlm keys.
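A sketch of what that could look like (class and function names are hypothetical, not the PR's API): the dataset declares its VLM keys explicitly, and the example script consumes that declaration instead of assuming every non-`input_ids`/`attention_mask` key is a VLM key.

```python
# Hypothetical: declared alongside the dataset (e.g. in refcoco.py / clevr.py).
class RefCOCODataset:
    # Keys emitted by the processor that must travel with each user message.
    vlm_keys = ["pixel_values", "image_grid_thw"]

# Hypothetical consumption in run_vlm_grpo.py.
def attach_vlm_data(user_message: dict, message: dict, dataset) -> dict:
    user_message["vlm_keys"] = list(dataset.vlm_keys)
    for key in dataset.vlm_keys:
        user_message[key] = message[key]
    return user_message
```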
Can we get the whitelist of vlm_keys from `processor.image_processor.model_input_names`?
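For reference, a minimal sketch of that suggestion (not the PR's code): Hugging Face image processors advertise the tensor names they produce via `model_input_names`, which could serve as the whitelist.

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# For Qwen2.5-VL this is something like
# ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"].
vlm_key_whitelist = set(processor.image_processor.model_input_names)

def extract_vlm_keys(message: dict) -> list[str]:
    """Keep only the keys that the image processor actually emits."""
    return [key for key in message if key in vlm_key_whitelist]
```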
- Recycle computed vlm_kwargs for both unflattened and flattened batches
- Remove potentially unsafe code for flipping images in refcoco
- Rename `get_vlm_keys_from_clippedpgloss_batch` to `get_vlm_keys_from_flattened_batch` and move it to batched_data_dict.py
- Add the VLM GRPO test case to nightly
- Improve documentation in CLEVR
I'm not entirely sure what causes the [...]. For the functional test cases, I had to choose higher thresholds for the [...].
@@ -41,6 +41,18 @@
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM

# Add VL model imports
try:
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration, Qwen2_5_VLModel
Do we need these try/excepts? Can we instead make sure the transformers version we're using has these?
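A sketch of the alternative being suggested (the version floor is an assumption, not verified here): check or pin a transformers release that ships the VL classes, then import them unconditionally.

```python
import transformers
from packaging import version

# Assumed minimum release that includes the Qwen2.5-VL / Qwen2-VL classes.
_MIN_TRANSFORMERS = "4.49.0"

if version.parse(transformers.__version__) < version.parse(_MIN_TRANSFORMERS):
    raise ImportError(
        f"VLM support requires transformers>={_MIN_TRANSFORMERS}, "
        f"found {transformers.__version__}"
    )

from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
```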
One thought: in the dtensor path we preprocess the image before calling the model, whereas vLLM preprocesses the image internally (if I understand what is happening correctly). We may need to make sure whatever preprocessing vLLM is doing matches exactly what we're doing in the dtensor path.
This will take me a while to analyse since I don't know exactly how the vLLM engine processes the images internally. For the policy, the typical multimodal pipeline is to use the [...]. For vLLM, the message log is simply reformatted into the format specified in this tutorial, and the same sequence of PIL Images is provided to the vLLM frontend.
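To summarize the two paths described above in one place, a minimal sketch (model name and prompt are placeholders, not the PR's code): the DTensor/HF path runs the processor up front, while the vLLM path hands the raw PIL image to the engine, which then does its own preprocessing internally; that is exactly where the two pipelines could diverge.

```python
from PIL import Image
from transformers import AutoProcessor

image = Image.open("example.jpg")    # placeholder image
prompt = "<prompt text containing the model's image placeholder tokens>"

# DTensor / HF path: preprocessing happens here, before the model is ever called.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
hf_inputs = processor(text=[prompt], images=[image], return_tensors="pt")
# hf_inputs contains input_ids, attention_mask, pixel_values, image_grid_thw, ...

# vLLM path: the raw PIL image rides along with the prompt; the engine
# performs its own image preprocessing internally.
vllm_request = {
    "prompt": prompt,
    "multi_modal_data": {"image": image},
}
# outputs = llm.generate([vllm_request], sampling_params)   # llm = vllm.LLM(model=...)
```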
From today's meeting, the remaining blockers on this PR:
Re: remaining blockers:
I prefer handling this issue in a separate PR (and merging initial support first) for [...].
What does this PR do?
Adds VLM support (Qwen2.5-VL) with a TP plan, DTensor policy, the vLLM backend, and multi-GPU training.
Usage
Convergence
(Training) convergence on 2 H100 GPUs happens in about 60 iterations (the highest possible reward is 5).
Before your PR is "Ready for review"
Pre checks:
Additional Information