
Packing is currently broken? #3705

@jiosephlee

Issue

To my knowledge, the current implementation of packing produces seq_lengths in SFTTrainer and then generates position_ids on the fly in the data collator.

If this is true, there might be a bug with packing at the moment: seq_lengths, which is generated by pack_dataset in SFTTrainer,

dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)

is not a signature column, so it's removed after being passed to Trainer.

To elaborate, the custom DataCollatorForLanguageModeling in sft_trainer.py correctly generates the corresponding position_ids when it receives seq_lengths. But seq_lengths is removed before it can even arrive at the collator.

The key issue seems to be this line in _get_data_loader in trainer.py:

https://github.com/huggingface/transformers/blob/896e9cea1ade521b2648f4798218550f6c72190c/src/transformers/trainer.py#L1007
dataset = self._remove_unused_columns(dataset, description=description)

When args.remove_unused_columns is True, _remove_unused_columns proceeds to remove non-signature columns, and seq_lengths is not recognized as a signature column in the forward pass of the model. Thus, when the collator receives the data, seq_lengths is missing and it generates position_ids as if the entire packed example were a single sequence.
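
To make the failure mode concrete, here is a minimal sketch of the two behaviours (the helper names are illustrative, not TRL's actual code): with seq_lengths, position_ids restart at 0 for every packed sub-sequence; without it, the collator falls back to one contiguous range.

import torch

def position_ids_from_seq_lengths(seq_lengths):
    # With seq_lengths available, restart the position counter at 0 for every
    # packed sub-sequence, e.g. [3, 2] -> [0, 1, 2, 0, 1].
    return torch.cat([torch.arange(n) for n in seq_lengths])

def position_ids_fallback(total_length):
    # What effectively happens once seq_lengths has been stripped: one contiguous
    # range, i.e. the whole packed example is treated as a single sequence.
    return torch.arange(total_length)

print(position_ids_from_seq_lengths([3, 2]))  # tensor([0, 1, 2, 0, 1])
print(position_ids_fallback(5))               # tensor([0, 1, 2, 3, 4])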

Tangentially, the code is currently incompatible with Liger kernels as well. These lines in sft_trainer.py

if args.use_liger_kernel:
    dataset = dataset.select_columns(
        {"input_ids", "position_ids", "completion_mask"}.intersection(dataset.column_names)
    )

should be adjusted to keep seq_lengths.
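
For instance, something along these lines (a sketch of the suggested change, not a tested patch):

if args.use_liger_kernel:
    # Keep seq_lengths so the collator can still rebuild per-sequence position_ids.
    dataset = dataset.select_columns(
        {"input_ids", "seq_lengths", "position_ids", "completion_mask"}.intersection(dataset.column_names)
    )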

Investigation

This is the difference in position_ids when remove_unused_columns is set to True (Top) and False (Bottom) on an example dataset (batch_size = 1).

[Screenshots: position_ids with remove_unused_columns=True (top) vs. False (bottom)]

Here is an example batch (size 1) from the data_loader. With packing, it should contain position_ids that restart at each sequence boundary (and no dense attention_mask), but because seq_lengths never reached the collator, position_ids increase monotonically across the whole packed example and a full attention_mask has been computed instead.

{'input_ids': tensor([[  3923,    374,   4461,  ...,    288,     13, 100257]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'), 'position_ids': tensor([[   0,    1,    2,  ..., 3362, 3363, 3364]], device='cuda:0'), 'labels': tensor([[  3923,    374,   4461,  ...,    288,     13, 100257]],
       device='cuda:0')}

The impact of this is concerning, as remove_unused_columns is True by default. At the same time, the short-term fix should be as easy as adding a warning and, perhaps, setting the default to False.
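
Until something like that lands, the user-side workaround is simply to opt out of column pruning in the config, e.g. (a minimal sketch; other arguments as in your setup):

from trl import SFTConfig

training_args = SFTConfig(
    packing=True,
    max_length=4096,
    # Keep non-signature columns such as seq_lengths so they reach the collator.
    remove_unused_columns=False,
)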

Code

Here is a simple snippet that can be run to replicate this based on the main branches of trl and transformers.

%pip install "transformers[torch]"
%pip install git+https://github.com/huggingface/trl.git

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", use_fast=True)

EOT_TOKEN = "\nResponse:"
dataset = load_dataset("GAIR/lima")
train_dataset = dataset["train"].shuffle(seed=42)

# Define the formatting function
def format_lima_conversation(example):
    conversation = example["conversations"]
    # Join turns with the EOT token and append the EOS token at the very end.
    formatted_text = f"{EOT_TOKEN}".join(conversation) + tokenizer.eos_token
    return {"text": formatted_text}

# Apply the formatting
train_dataset = train_dataset.map(
    format_lima_conversation, remove_columns=["conversations", "source"]
)

training_args = SFTConfig(
    dataset_text_field="text",
    packing=True,
    max_length=4096,
    per_device_train_batch_size=1,
    # remove_unused_columns=False,
)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# Inspect the first batch: with packing intact, position_ids should reset to 0
# at every packed sequence boundary.
for i, batch in enumerate(trainer.get_train_dataloader()):
    if i == 0:
        print(batch)
        print(len(batch["position_ids"]))
        print(batch["position_ids"][0])
        for j in batch["position_ids"][0]:
            if j == 0:
                print("seq found")

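The same check can be condensed by counting zeros in position_ids, which gives the number of packed sub-sequences the collator actually recognized. It should print 1 with the default remove_unused_columns=True, and a much larger number with False:

# Reuses the `trainer` built above; only the first batch is inspected.
for batch in trainer.get_train_dataloader():
    n_subsequences = (batch["position_ids"][0] == 0).sum().item()
    print(f"packed sub-sequences detected: {n_subsequences}")
    break
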
Strangely, the training loss curve is better with cross-contamination. Here are training runs of instruction tuning on the LIMA dataset with remove_unused_columns set to True and False. A manual review of the outputs looks fine for both.

[Screenshot: training loss curves for remove_unused_columns=True vs. False]
