
Padding bug in Mixtral modeling code when running DPO #1266

@rosario-purple

Description

System Info

  • transformers version: 4.36.2
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.4.0
  • Accelerate version: 0.26.1
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: FSDP
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 8
    - machine_rank: 0
    - num_machines: 1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - fsdp_config: {'fsdp_auto_wrap_policy': 'SIZE_BASED_WRAP', 'fsdp_backward_prefetch': 'BACKWARD_PRE', 'fsdp_cpu_ram_efficient_loading': True, 'fsdp_forward_prefetch': False, 'fsdp_min_num_params': 100000000, 'fsdp_offload_params': False, 'fsdp_sharding_strategy': 'FULL_SHARD', 'fsdp_state_dict_type': 'SHARDED_STATE_DICT', 'fsdp_sync_module_states': True, 'fsdp_use_orig_params': True}
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • PyTorch version (GPU?): 2.1.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.7.5 (cpu)
  • Jax version: 0.4.21
  • JaxLib version: 0.4.21
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Training a DPO model with Hugging Face's DPOTrainer on Mixtral (mistralai/Mixtral-8x7B-v0.1) produces this error:

  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_i\
mpl
    return self._call_impl(*args, **kwargs)
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/accelerate/utils/operations.py", line 687, in forward
    return model_forward(*args, **kwargs)
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/accelerate/utils/operations.py", line 675, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/peft/peft_model.py", line 1073, in forward
    return self.base_model(
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_i\
mpl
    return self._call_impl(*args, **kwargs)
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 103, in forward
    return self.model.forward(*args, **kwargs)
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 122\
3, in forward
    outputs = self.model(
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_i\
mpl
    return self._call_impl(*args, **kwargs)
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 104\
8, in forward
    raise ValueError(
ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour\
 for Flash Attention version of Mixtral. Make sure to  call `tokenizer.padding_side  = 'left'` before tokenizing the input.
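
For context, the setup looks roughly like this. This is only a minimal sketch of the kind of script that triggers it; the toy dataset, LoRA config, and hyperparameters are illustrative placeholders, not the exact values from my run:

    import torch
    from datasets import Dataset
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
    from trl import DPOTrainer

    model_name = "mistralai/Mixtral-8x7B-v0.1"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # has no effect on the error (see below)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )

    # Toy preference pairs in the prompt/chosen/rejected format DPOTrainer expects.
    train_dataset = Dataset.from_dict({
        "prompt": ["What is the capital of France?", "Name a prime number."],
        "chosen": ["The capital of France is Paris.", "7"],
        "rejected": ["London.", "8"],
    })

    trainer = DPOTrainer(
        model,
        ref_model=None,  # reference model handled via the PEFT adapter
        beta=0.1,
        args=TrainingArguments(
            output_dir="dpo-mixtral",
            per_device_train_batch_size=2,
            bf16=True,
        ),
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        peft_config=LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            task_type="CAUSAL_LM",
        ),
        max_length=512,
        max_prompt_length=256,
    )

    trainer.train()  # crashes with the ValueError above inside modeling_mixtral.py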

I believe this is caused by this bit of code in modeling_mixtral.py:

        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
                    " call `tokenizer.padding_side  = 'left'` before tokenizing the input. "
                )

Setting the tokenizer's padding side does not affect this. My guess is that DPO masks out different parts of the batch for the prompt tokens, chosen tokens, and rejected tokens, so this exception will always be thrown.
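
As a toy illustration of why changing tokenizer.padding_side cannot help (hand-built tensors, not taken from the trainer): if the collated batch pads shorter sequences on the right up to the batch maximum, the check quoted above fires, because it only inspects the attention mask and never consults the tokenizer:

    import torch

    # Hypothetical collated batch: one full-length row and one shorter row that
    # was padded on the right up to the longest sequence in the batch.
    attention_mask = torch.tensor([
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
    ])
    batch_size = attention_mask.shape[0]

    # Same condition as the Mixtral check quoted above.
    is_padding_right = attention_mask[:, -1].sum().item() != batch_size
    print(is_padding_right)  # True -> the ValueError is raised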

Expected behavior

The training run should not crash.
