
[BUG] The device map of the training model in GRPO includes the device used by vLLM #3088

@maoulee

Reproduction

Question: During TRL training with a separate GPU for vLLM inference, I'm seeing unexpected GPU memory increases on the vLLM GPU, leading to OOM errors.

Description:

I'm fine-tuning an int4-quantized model with TRL's GRPO trainer (built on the transformers Trainer) and QLoRA (supported by unsloth-zoo code), on a two-GPU setup with NVIDIA A100s (40 GB):

  • GPU 0 (cuda:0): Handles the TRL training process (specifically, updating the QLoRA adapter parameters).
  • GPU 1 (cuda:1): Runs vLLM for inference (generating text samples).
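
For context, the setup looks roughly like the sketch below (a minimal sketch, not the exact script: the model id, dataset, and reward function are placeholders, and the GRPOConfig fields follow what this TRL version appears to expose):

    from datasets import load_dataset
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    from trl import GRPOConfig, GRPOTrainer

    # Placeholder model id and dataset; the real run loads an unsloth int4 model.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B-Instruct",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        device_map={"": 0},  # keep the base weights on cuda:0
    )

    def reward_len(completions, **kwargs):
        # Placeholder reward: prefer completions close to 100 characters.
        return [-abs(100 - len(c)) for c in completions]

    training_args = GRPOConfig(
        output_dir="grpo-qlora",
        use_vllm=True,                    # generate rollouts with vLLM
        vllm_device="cuda:1",             # keep vLLM on the second A100
        vllm_gpu_memory_utilization=0.6,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
    )

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_len,
        args=training_args,
        train_dataset=load_dataset("trl-lib/tldr", split="train"),
        peft_config=LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM"),
    )
    trainer.train()                       # parameter updates should stay on cuda:0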

The setup should only involve parameter updates on GPU 0, with GPU 1 dedicated solely to vLLM. Initial memory usage on both GPUs is as expected, without OOM errors.

[Screenshot: initial GPU memory usage on both GPUs]

However, during training, I observe a significant and unexpected increase in GPU memory usage on both GPUs, particularly on GPU 1 (the vLLM GPU), which should not be involved in the parameter updates. This eventually leads to an OOM error on GPU 1.
[Screenshot: GPU memory climbing on GPU 1 during training]

Problem Isolation:

  1. move_model_vllm is NOT the cause: I've confirmed through breakpoint debugging that the move_model_vllm function (presumably a custom function to update the vLLM model with new LoRA weights) is not being called when the OOM error occurs. This rules out parameter updates to vLLM as the direct cause of the memory spike on GPU 1.

  2. DataParallel is the culprit: Further debugging revealed that the model, at the point where self._get_per_token_logps is called (likely within the TRL loss calculation), has been unexpectedly wrapped in torch.nn.DataParallel. The DataParallel instance includes both GPU 0 and GPU 1 (device_ids=[0, 1]), which means parts of the model and/or data are replicated onto GPU 1, causing the memory increase (a quick check for this is sketched after this list).

[Screenshot: debugger showing the model wrapped in DataParallel with device_ids=[0, 1]]

  3. Likely root cause (transformers Trainer): I suspect the issue originates within transformers.Trainer (or a similar TRL training loop). It appears that the Trainer, upon detecting multiple GPUs (n_gpu > 1), automatically wraps the model in DataParallel without considering that one GPU is dedicated to a separate process (vLLM).
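
A quick way to confirm the wrapping at a breakpoint (a minimal sketch; model stands for whatever object reaches self._get_per_token_logps):

    import torch.nn as nn

    def describe_wrapping(model):
        # Report whether the Trainer wrapped the model and which GPUs the wrapper spans.
        if isinstance(model, nn.DataParallel):
            print(f"Wrapped in DataParallel, device_ids={model.device_ids}")
        else:
            device = next(model.parameters()).device
            print(f"Not wrapped; parameters on {device}")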

I tracked this down further and found that, in the transformers Trainer, if n_gpu is greater than 1 the model is automatically wrapped in nn.DataParallel with all visible GPUs included:

        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
            model = nn.DataParallel(model)

So I changed the above code to:

        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
            if self.use_vllm:
                # Exclude the GPU reserved for vLLM from the DataParallel replica set
                vllm_gpu_id = int(self.args.vllm_device.split(":")[1])
                all_device_ids = list(range(self.args.n_gpu))
                all_device_ids.remove(vllm_gpu_id)
                model = nn.DataParallel(model, device_ids=all_device_ids)
            else:
                model = nn.DataParallel(model)

The OOM error disappeared, and VRAM was allocated the way I had hoped:

[Screenshot: GPU memory usage after the change]
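
To sanity-check the allocation, per-GPU memory can be polled from the training process (a minimal sketch using standard PyTorch calls; it only reports this process's allocations, so the separate vLLM process is better checked with nvidia-smi):

    import torch

    # Print this process's CUDA memory footprint on each visible GPU.
    for i in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(i) / 1024**3
        reserved = torch.cuda.memory_reserved(i) / 1024**3
        print(f"cuda:{i}: allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")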

System Info

Copy-paste the following information when reporting an issue:

  • Platform: Linux-5.15.0-60-generic-x86_64-with-glibc2.31
  • Python version: 3.10.14
  • TRL version: 0.16.0.dev0+b55d9f0
  • PyTorch version: 2.5.1
  • CUDA device(s): NVIDIA A100-SXM4-40GB, NVIDIA A100-SXM4-40GB
  • Transformers version: 4.48.3
  • Accelerate version: 1.4.0
  • Accelerate config: not found
  • Datasets version: 3.0.1
  • HF Hub version: 0.29.1
  • bitsandbytes version: 0.45.3
  • DeepSpeed version: not installed
  • Diffusers version: 0.32.2
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.65.1
  • PEFT version: 0.14.0
  • vLLM version: 0.7.3

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete
