
keep_end + max_length causes NaNs in trainer_state.json #3382

@jdebaer

Description


Reproduction

Hi - I see existing issues that are similar, but I'm not 100% sure they are identical. Apologies if this turns out to be a duplicate.

Symptom:

trainer_state.json contains entries like this:

  "eval_logits/rejected": NaN,
  "logits/chosen": NaN,

Setup/reproduce:

  • trl 0.16.0
  • load_dataset("trl-lib/ultrafeedback_binarized")
  • dataset_train = dataset_dict["train"].shuffle(seed=42)
  • model_name = "allenai/OLMo-1B-hf"
  • run DPO in 32-bit
  • max_length=1024 (important to reproduce)
    => I see these NaNs start appearing around step 150 with an effective batch size of 8 (4 gradient accumulation steps x batch size 2 on 1 GPU).
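
For reference, a minimal script along these lines reproduces it for me. The output_dir and logging_steps values are placeholders picked for illustration; the rest mirrors the setup above, with truncation_mode left at its default:

    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOConfig, DPOTrainer

    model_name = "allenai/OLMo-1B-hf"
    model = AutoModelForCausalLM.from_pretrained(model_name)  # plain 32-bit load, no quantization
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    dataset_dict = load_dataset("trl-lib/ultrafeedback_binarized")
    dataset_train = dataset_dict["train"].shuffle(seed=42)

    args = DPOConfig(
        output_dir="olmo1b-dpo-repro",
        max_length=1024,                 # important to reproduce
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,   # effective batch size 8 on 1 GPU
        logging_steps=10,
    )

    trainer = DPOTrainer(
        model=model,
        args=args,
        train_dataset=dataset_train,
        processing_class=tokenizer,
    )
    trainer.train()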

Root cause:

In dpo_trainer.py, in concatenated_forward(), when "keep_end" truncation is applied, the following code keeps only the rightmost max_length (1024) tokens:

            input_ids = input_ids[:, -self.max_length :]
            attention_mask = attention_mask[:, -self.max_length :]
            loss_mask = loss_mask[:, -self.max_length :]

In batches where certain answers are very short, this can leave input_ids completely filled with padding tokens and a loss_mask that is all zeros (i.e., the answer is so short that none of its relevant tokens are "captured" by the slice).
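
A small standalone illustration of how this can happen (dummy tensors, not the trainer's actual data): when a sample's real tokens sit near the front and it is padded out to the batch's longest sequence, keeping only the last max_length positions leaves nothing but padding.

    import torch

    max_length = 4  # stands in for max_length=1024

    # One sample whose real tokens (prompt + short answer) sit at the front,
    # followed by a long run of padding added to match the batch's longest sequence.
    input_ids      = torch.tensor([[5, 6, 7, 0, 0, 0, 0, 0]])
    attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0]])
    loss_mask      = torch.tensor([[0, 0, 1, 0, 0, 0, 0, 0]], dtype=torch.bool)

    # "keep_end" truncation keeps only the last max_length positions...
    input_ids      = input_ids[:, -max_length:]
    attention_mask = attention_mask[:, -max_length:]
    loss_mask      = loss_mask[:, -max_length:]

    # ...which here are all padding, so the loss_mask is now all zeros.
    print(loss_mask.any())  # tensor(False)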

This has the following effects:

  1. mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() -> this is what produces the NaNs in trainer_state.json: indexing with the all-zero loss_mask yields an empty tensor, and mean() of an empty tensor is NaN. (The same applies to mean_rejected_logits, depending on where the short samples land.)

Print of logits[:num_examples][loss_mask[:num_examples]] when the issue occurs:
tensor([], device='cuda:0', size=(0, 50304), grad_fn=) <<<< empty tensor

Print of logits[num_examples:][loss_mask[num_examples:]] when the issue occurs:
tensor([[ 0.6418, -2.8145, 11.0546, ..., -4.8087, -2.1290, -3.3897],
[-0.3914, -2.5756, 8.6170, ..., -3.8017, -1.5249, -2.8255],
.... (second series of logits is missing, masked out by loss_mask)

  2. For the logps the effect is different, since we have per_token_logps[~loss_mask] = 0.
    Because the loss_mask covers the whole sequence, every per-token logp gets zeroed out, so output["chosen_logps"] ends up containing a 0 (the sum over the sequence), which is later passed on to the dpo_loss() function (where I think an incorrect loss is then calculated). Both effects are demonstrated in the sketch after the prints below.

Print of the summed logps for the chosen half ([:num_examples]) when the issue occurs:
tensor([0., 0.], device='cuda:0', grad_fn=) <<< both chosen answers are "too short"

Print of the summed logps for the rejected half ([num_examples:]) when the issue occurs:
tensor([-2771.5366, 0.0000], device='cuda:0', grad_fn=) <<< one rejected answer is also "too short"
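
Both effects can be reproduced in isolation with dummy tensors (shapes and values are made up; only the indexing pattern matches the trainer):

    import torch

    seq_len, vocab_size = 4, 8

    # An all-zero loss_mask for one sample, as produced by the truncation above.
    loss_mask = torch.zeros(1, seq_len, dtype=torch.bool)
    logits = torch.randn(1, seq_len, vocab_size)
    per_token_logps = torch.randn(1, seq_len)

    # Effect 1: boolean indexing with an all-zero mask gives an empty tensor,
    # and mean() of an empty tensor is NaN -> the NaNs in trainer_state.json.
    print(logits[loss_mask].mean())  # tensor(nan)

    # Effect 2: zeroing the masked positions and summing gives exactly 0,
    # which then reaches dpo_loss() as a bogus sequence log-probability.
    per_token_logps[~loss_mask] = 0
    print(per_token_logps.sum(-1))  # tensor([0.])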


I've been thinking about ways to avoid or mitigate this.

  1. We can let it happen and issue a warning, so that the user can increase max_length, but this is of course no guarantee (and the model needs to support it).

  2. We can add a check in get_batch_loss_metrics() to 1) see if any logps in model_output are 0 (as per the above), and 2) filter out the corresponding samples (chosen/rejected) from the batch before calling dpo_loss.

  3. If we notice an "all zeros" loss_mask for a response, then I think we could set the paired chosen/rejected response's loss_mask to all zeros as well. That way (I think) no loss will be calculated and at least nothing erroneous will be backpropagated for that particular problematic pair. A rough sketch of this idea follows below.
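
A rough sketch of option 3, assuming the concatenated [chosen; rejected] layout used in concatenated_forward() and a boolean loss_mask of shape (2 * num_examples, seq_len); the helper name is hypothetical and not part of TRL:

    import torch

    def zero_out_degenerate_pairs(loss_mask: torch.Tensor, num_examples: int) -> torch.Tensor:
        """If either the chosen or the rejected completion of a pair has an
        all-zero loss_mask after truncation, zero out both sides so the pair
        contributes nothing to the loss."""
        chosen_mask = loss_mask[:num_examples]
        rejected_mask = loss_mask[num_examples:]
        # A pair is degenerate if either side has no supervised tokens left.
        degenerate = ~(chosen_mask.any(dim=-1) & rejected_mask.any(dim=-1))
        # chosen_mask / rejected_mask are views, so these writes update loss_mask in place.
        chosen_mask[degenerate] = False
        rejected_mask[degenerate] = False
        return loss_mask

This would be called right after the truncation block; note it only addresses the loss side, so the NaN in the logged mean logits would still need its own guard (e.g., only taking the mean when the mask selects at least one token).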

Thank you!

System Info

  • Platform: Linux-5.15.0-100-generic-x86_64-with-glibc2.35
  • Python version: 3.11.11
  • TRL version: 0.16.0
  • PyTorch version: 2.6.0
  • CUDA device(s): not available << I quickly ran trl env in a non-GPU container; however, I can reproduce the issue on any of our V100/A40/A100 clusters (all have CUDA 12.2)
  • Transformers version: 4.51.2
  • Accelerate version: 1.6.0
  • Accelerate config: not found
  • Datasets version: 3.5.0
  • HF Hub version: 0.30.2
  • bitsandbytes version: not installed
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: not installed
  • PEFT version: not installed
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshot, more on code blocks)
  • Any traceback provided is complete
