[SFT][DPO] Qwen2/2.5+FlashAttention2+Validation steps don't work #3306

@benjamin-marie

Reproduction

Here is my training config:

import multiprocessing

from trl import DPOConfig

training_arguments = DPOConfig(
    output_dir=output_dir,
    eval_strategy="steps",
    do_eval=True,
    optim="paged_adamw_8bit",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,
    per_device_eval_batch_size=1,
    log_level="debug",
    save_strategy="steps",
    save_steps=200,
    logging_steps=25,
    learning_rate=lr,
    bf16=True,
    beta=0.1,
    eval_steps=25,
    # num_train_epochs=1,
    max_steps=200,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    max_length=mseqlen,
    max_prompt_length=mseqlen,
    model_adapter_name="DPO",
    ref_adapter_name="reference",
    dataset_num_proc=multiprocessing.cpu_count(),
)
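
For context, here is roughly how the trainer is set up around this config. The model id, adapter path, and dataset below are placeholders rather than the exact values from my script; the relevant points are flash_attention_2, the two adapters, and padding_side = 'left':

import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer

model_name = "Qwen/Qwen2.5-7B-Instruct"  # placeholder model id
adapter_path = "./sft_adapter"           # placeholder: LoRA adapter produced by SFT

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"  # set explicitly, as the Qwen2 + FA2 error message asks

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# Load the same adapter twice: a trainable policy adapter ("DPO") and a frozen
# reference adapter ("reference"), matching model_adapter_name/ref_adapter_name above.
model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True, adapter_name="DPO")
model.load_adapter(adapter_path, adapter_name="reference")

dataset = load_dataset("trl-lib/ultrafeedback_binarized")  # placeholder preference dataset

trainer = DPOTrainer(
    model,
    args=training_arguments,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=tokenizer,
)
trainer.train()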

Output:

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py:965, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    962     set_attribute_for_modules(self, "_is_top_level_module", False)
    964 try:
--> 965     output = func(self, *args, **kwargs)
    966     if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):
    967         output = output.to_tuple()

File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:519, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, cache_position, **flash_attn_kwargs)
    516 if position_ids is None:
    517     position_ids = cache_position.unsqueeze(0)
--> 519 causal_mask = self._update_causal_mask(
    520     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    521 )
    523 hidden_states = inputs_embeds
    525 # create position embeddings to be shared across the decoder layers

File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:591, in Qwen2Model._update_causal_mask(self, attention_mask, input_tensor, cache_position, past_key_values, output_attentions)
    589     is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
    590     if is_padding_right:
--> 591         raise ValueError(
    592             "You are attempting to perform batched generation with padding_side='right'"
    593             " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
    594             " call `tokenizer.padding_side  = 'left'` before tokenizing the input. "
    595         )
    596 if attention_mask is not None and 0.0 in attention_mask:
    597     return attention_mask

ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to  call `tokenizer.padding_side  = 'left'` before tokenizing the input. 

Training runs fine until it reaches the first validation step.
I explicitly set tokenizer.padding_side = 'left', but it is apparently overridden somewhere before evaluation.
Note that padding_side = 'right' worked fine before this Transformers change:
huggingface/transformers@96f01a3
That commit rejects FlashAttention 2 with right padding for Qwen2.
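
For reference, the check added by that commit reduces to the following standalone sketch (logic copied from the traceback above): a batch counts as right-padded when the last attention-mask column is not all ones.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0],   # this sequence ends in padding
                               [1, 1, 1, 1]])
# Same test as in Qwen2Model._update_causal_mask for flash_attention_2:
is_padding_right = attention_mask[:, -1].sum().item() != attention_mask.size(0)
print(is_padding_right)  # True -> the ValueError above is raised

So even though the tokenizer is configured with padding_side = 'left', the evaluation batches that reach the model apparently still contain right padding.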

System Info

  • Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.35
  • Python version: 3.11.10
  • TRL version: 0.16.1
  • PyTorch version: 2.6.0+cu126
  • CUDA device(s): NVIDIA A40
  • Transformers version: 4.51.3
  • Accelerate version: 1.6.0
  • Accelerate config: not found
  • Datasets version: 3.5.0
  • HF Hub version: 0.30.2
  • bitsandbytes version: 0.45.5
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: not installed
  • PEFT version: 0.15.2
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
  • Any traceback provided is complete
