
kiddj
Contributor

@kiddj kiddj commented Feb 15, 2025

What does this PR do?

Fixes DeepSpeed Stage-3 compatibility by passing the wrapped model (self.model_wrapped) to unwrap_model_for_generation instead of self.model.

Previously, unwrap_model_for_generation(model) was called with the model passed in as a function argument (e.g., to compute_loss), which avoided any DeepSpeed Stage-3 conflicts. In the current implementation, the model is taken directly from self (self.model), which causes a crash under Stage-3.

This fix restores consistency by always passing the wrapped model (self.model_wrapped) rather than self.model, which prevents the Stage-3 crash.
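Why the argument matters can be sketched with a plain-Python toy (no DeepSpeed needed). It assumes, based on a reading of TRL's `remove_hooks` helper around this release, that ZeRO-3 hooks are only stripped for generation when the object passed in exposes an `optimizer` attribute: the DeepSpeed engine (`self.model_wrapped`) has one, while the bare module (`self.model`) does not, so with `self.model` the hooks would stay active during generation. `ToyModule`, `ToyEngine`, `toy_remove_hooks`, `toy_add_hooks` and `toy_unwrap_for_generation` are hypothetical stand-ins, not TRL or DeepSpeed APIs.

```python
from contextlib import contextmanager


class ToyModule:
    """Hypothetical stand-in for the bare model (self.model), carrying ZeRO-3-style hooks."""

    def __init__(self):
        self.hooks_active = True


class ToyEngine:
    """Hypothetical stand-in for the DeepSpeed engine (self.model_wrapped)."""

    def __init__(self, module):
        self.module = module
        self.optimizer = object()  # the engine owns the optimizer; the bare module does not


def toy_remove_hooks(model):
    # Assumed behaviour: without an `optimizer` attribute there is nothing to strip.
    if not hasattr(model, "optimizer"):
        return
    model.module.hooks_active = False


def toy_add_hooks(model):
    # Mirror of toy_remove_hooks: only re-attach when given the engine.
    if not hasattr(model, "optimizer"):
        return
    model.module.hooks_active = True


@contextmanager
def toy_unwrap_for_generation(model):
    """Yield the inner module, with hooks disabled only if the caller passed the engine."""
    inner = model.module if isinstance(model, ToyEngine) else model
    toy_remove_hooks(model)
    try:
        yield inner
    finally:
        toy_add_hooks(model)


engine = ToyEngine(ToyModule())

# Old call (self.model): hook removal is silently skipped.
with toy_unwrap_for_generation(engine.module) as m:
    print("self.model         -> hooks active:", m.hooks_active)

# This PR (self.model_wrapped): hooks are removed for generation, then restored.
with toy_unwrap_for_generation(engine) as m:
    print("self.model_wrapped -> hooks active:", m.hooks_active)
print("after generation   -> hooks active:", engine.module.hooks_active)
```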

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qgallouedec

@dignfei

dignfei commented Feb 15, 2025

I have already submitted a similar PR.

@qgallouedec
Member

qgallouedec commented Feb 18, 2025

Thanks for your contribution @kiddj! Can you provide code that would fail on the main branch and not on yours? I can't get one. Currently, the following works for me:

# 2871.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="2871",
    per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
    num_generations=3,  # reduce the number of generations to reduce memory usage
    max_completion_length=32,  # reduce the completion length to reduce memory usage
    max_prompt_length=32,
    bf16=True,
    report_to="none",
)

def dummy_reward_func(completions, **kwargs):
    return [0.0] * len(completions)

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=dummy_reward_func,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()

Run on 2 GPUs with:

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --num_processes 2 sandbox/2871.py
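The batch settings in this repro fit together for a reason: at the time, GRPOTrainer required the global batch (per_device_train_batch_size times the number of processes) to be divisible by num_generations, with num_generations of at least 2. A minimal sketch of that arithmetic, assuming that rule; `grpo_prompts_per_step` is a hypothetical helper, not a TRL function.

```python
def grpo_prompts_per_step(per_device_train_batch_size, num_processes, num_generations):
    """Hypothetical helper: check the assumed GRPO batch rule and count prompts per step."""
    if num_generations < 2:
        raise ValueError("num_generations must be at least 2")
    global_batch = per_device_train_batch_size * num_processes
    if global_batch % num_generations != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by num_generations={num_generations}"
        )
    # Each prompt contributes num_generations completions to the global batch.
    return global_batch // num_generations


# The repro: batch size 3 per device on 2 GPUs, 3 generations per prompt.
print(grpo_prompts_per_step(3, 2, 3))  # 2
```

So each optimization step in the repro covers 2 distinct prompts, each sampled 3 times.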

@qgallouedec qgallouedec added the 😴 stale No update from the author, will be closed soon label Feb 24, 2025
@jamesbraza
Contributor

@kiddj what was the crash you hit? Can you share a stack trace?

I just hit the AssertionError below inside grpo_trainer.py, running accelerate launch --zero3_init_flag true --zero_stage 3 with accelerate==1.4.0, deepspeed==0.16.4, and the current main branch of trl (including #2963).

I think it might be what your PR aims to resolve:

0: [rank3]: Traceback (most recent call last):
0: [rank3]:   File "/path/to/repo/train.py", line 268, in <module>
0: [rank3]:     main(script_args, training_args, model_args)
0: [rank3]:   File "/path/to/repo/train.py", line 254, in main
0: [rank3]:     trainer.train(**train_kw)
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2245, in train
0: [rank3]:     return inner_training_loop(
0: [rank3]:            ^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2556, in _inner_training_loop
0: [rank3]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
0: [rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3700, in training_step
0: [rank3]:     inputs = self._prepare_inputs(inputs)
0: [rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/trl/extras/profiling.py", line 33, in wrapper
0: [rank3]:     result = func(self, *args, **kwargs)
0: [rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/trl/trainer/grpo_trainer.py", line 670, in _prepare_inputs
0: [rank3]:     inputs = self._generate_and_score_completions(inputs)
0: [rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/trl/trainer/grpo_trainer.py", line 849, in _generate_and_score_completions
0: [rank3]:     with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
0: [rank3]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/home/james/.local/share/uv/python/cpython-3.12.9-linux-x86_64-gnu/lib/python3.12/contextlib.py", line 144, in __exit__
0: [rank3]:     next(self.gen)
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/trl/models/utils.py", line 217, in unwrap_model_for_generation
0: [rank3]:     with deepspeed.zero.GatheredParameters(model.parameters()):
0: [rank3]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 2251, in __exit__
0: [rank3]:     self.params[0].partition(param_list=self.params, has_been_updated=False)
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1394, in partition
0: [rank3]:     self._partition(param_list, has_been_updated=has_been_updated, free_data=True)
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1543, in _partition
0: [rank3]:     self._partition_param(param, has_been_updated=has_been_updated, free_data=True)
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
0: [rank3]:     ret_val = func(*args, **kwargs)
0: [rank3]:               ^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1552, in _partition_param
0: [rank3]:     assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
0:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: AssertionError:  Parameter containing:
0: tensor([[-2.9785e-02,  2.1484e-02,  4.7913e-03,  ...,  4.6875e-02,
0:           1.6357e-02,  3.2959e-02],
0:         [-7.4219e-02, -1.3611e-02,  1.2756e-02,  ...,  5.8899e-03,
0:           2.9144e-03,  7.0496e-03],
0:         [-3.4180e-02, -1.3306e-02,  7.2937e-03,  ...,  4.2343e-04,
0:          -1.8616e-03,  3.4424e-02],
0:         ...,
0:         [ 2.1104e-26, -2.4840e-26,  7.0177e-27,  ..., -2.4032e-26,
0:          -1.1471e-25, -2.4032e-26],
0:         [ 5.0083e-26, -2.2618e-26, -1.3329e-26,  ..., -6.8662e-26,
0:           1.4439e-26,  3.2110e-26],
0:         [ 1.0350e-26, -4.8872e-26,  1.9993e-26,  ...,  1.4136e-26,
0:           2.8879e-26,  7.8255e-28]], device='cuda:0', dtype=torch.bfloat16,
0:        requires_grad=True) Cannot partition a param in flight
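The failing check in the last frame can be modelled in isolation. Below is a plain-Python toy (no DeepSpeed required) of what the traceback shows: `GatheredParameters.__exit__` re-partitions the gathered parameters, and `_partition_param` asserts that none of them is still `INFLIGHT`. `ToyParamStatus`, `ToyParam`, `ToyGatheredParameters` and `toy_partition` are hypothetical stand-ins; the real classes also shard, communicate and offload. Which component leaves a parameter in flight is not modelled here; the toy simply sets that state by hand.

```python
from enum import Enum


class ToyParamStatus(Enum):
    # Mirrors the three states named by DeepSpeed's ZeroParamStatus.
    AVAILABLE = 1
    NOT_AVAILABLE = 2
    INFLIGHT = 3


class ToyParam:
    def __init__(self, name):
        self.name = name
        self.ds_status = ToyParamStatus.NOT_AVAILABLE


def toy_partition(params):
    """Re-shard gathered params, refusing any that are still being fetched."""
    for p in params:
        assert p.ds_status is not ToyParamStatus.INFLIGHT, f"{p.name} Cannot partition a param in flight"
        p.ds_status = ToyParamStatus.NOT_AVAILABLE


class ToyGatheredParameters:
    """Gather params on enter, re-partition them on exit."""

    def __init__(self, params):
        self.params = list(params)

    def __enter__(self):
        for p in self.params:
            p.ds_status = ToyParamStatus.AVAILABLE
        return self

    def __exit__(self, *exc):
        toy_partition(self.params)
        return False  # never swallow exceptions


params = [ToyParam("embed"), ToyParam("lm_head")]

# Clean case: nothing touches the params while they are gathered.
with ToyGatheredParameters(params):
    pass

# Failing case: something marks a param INFLIGHT before the context exits.
try:
    with ToyGatheredParameters(params):
        params[1].ds_status = ToyParamStatus.INFLIGHT
except AssertionError as err:
    print("AssertionError:", err)
```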

@jamesbraza
Contributor

After pulling in this PR, I no longer get the error, so I think this PR works.

@qgallouedec qgallouedec changed the title [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility 🫔 [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility Feb 28, 2025
Member

@qgallouedec qgallouedec left a comment


Thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec merged commit ad6a35b into huggingface:main Feb 28, 2025
12 of 13 checks passed
jhinpan pushed a commit to jhinpan/trl-jin that referenced this pull request Mar 12, 2025
…eed Stage-3 compatibility (huggingface#2871)

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025
…eed Stage-3 compatibility (huggingface#2871)

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Labels
😴 stale No update from the author, will be closed soon

5 participants