Online DPO crashes when using multiple GPUs #3063

@wilrop

Description

Reproduction

The example given in the documentation for Online DPO crashes when executed on a system with multiple GPUs:

from datasets import load_dataset
from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO", logging_steps=10)
trainer = OnlineDPOTrainer(
    model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()

Running this outputs:

Traceback (most recent call last):
  File "/home/wropke/bug.py", line 14, in <module>
    trainer.train()
  File "/home/wropke/bug/lib/python3.11/site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/wropke/bug/lib/python3.11/site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wropke/bug/lib/python3.11/site-packages/trl/trainer/online_dpo_trainer.py", line 527, in training_step
    prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)
                                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wropke/bug/lib/python3.11/site-packages/trl/trainer/online_dpo_trainer.py", line 473, in _generate
    inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wropke/bug/lib/python3.11/site-packages/trl/trainer/online_dpo_trainer.py", line 473, in <listcomp>
    inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
                                   ^^^^^^^^^^^^
  File "/home/wropke/bug/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1928, in __getattr__
    raise AttributeError(
AttributeError: 'DataParallel' object has no attribute 'config'
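
For context on the root cause: when more than one GPU is visible, transformers' Trainer wraps the model in torch.nn.DataParallel, and plain attributes such as config are not forwarded by the wrapper; they are only reachable through .module. A minimal standalone sketch (using a hypothetical Tiny module, not TRL code) showing the behavior:

import torch.nn as nn

class Tiny(nn.Module):  # hypothetical stand-in for the HF model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.config = {"is_encoder_decoder": False}  # plain attribute, like model.config

    def forward(self, x):
        return self.linear(x)

wrapped = nn.DataParallel(Tiny())
print(wrapped.module.config)  # OK: the wrapped module keeps its attributes
try:
    print(wrapped.config)     # attribute lookup on the wrapper itself fails
except AttributeError as e:
    print(e)                  # 'DataParallel' object has no attribute 'config'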

I managed to fix this issue by taking inspiration from the DPO trainer implementation (see the sketch after this list):

  1. In __init__, cache the flag with self.is_encoder_decoder = model.config.is_encoder_decoder, before the model gets wrapped
  2. Change line 473 to inputs = [self.tokenize_row(x, self.is_encoder_decoder, self.processing_class) for x in inputs]
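
Put together, the change looks roughly like this (a sketch against trl 0.15.2's online_dpo_trainer.py, not a tested patch):

# In OnlineDPOTrainer.__init__, while `model` is still the unwrapped transformers
# model, cache the flag on the trainer itself, as DPOTrainer does:
self.is_encoder_decoder = model.config.is_encoder_decoder

# In OnlineDPOTrainer._generate (line 473), read the cached flag instead of
# reaching through the possibly DataParallel-wrapped `model`:
inputs = [self.tokenize_row(x, self.is_encoder_decoder, self.processing_class) for x in inputs]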

System Info

I made a fresh virtual environment with just trl and llm-blender installed.

  • Platform: Linux-6.12.12+bpo-amd64-x86_64-with-glibc2.36
  • Python version: 3.11.2
  • PyTorch version: 2.6.0
  • CUDA device(s): NVIDIA A40, NVIDIA A40
  • Transformers version: 4.49.0
  • Accelerate version: 1.4.0
  • Accelerate config: not found
  • Datasets version: 3.3.2
  • HF Hub version: 0.29.3
  • TRL version: 0.15.2
  • bitsandbytes version: not installed
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: 0.0.2
  • OpenAI version: not installed
  • PEFT version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete
