Issue
To my knowledge, the current implementation of packing generates seq_lengths in SFTTrainer and then generates position_ids on the fly in the data collator. If this is true, there might be a bug with packing at the moment: seq_lengths, which is generated by pack_dataset in SFTTrainer (trl/trl/trainer/sft_trainer.py, line 810 at 686cd35),

dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)

is not a signature column, so it is removed after being passed to Trainer.
To elaborate, the custom DataCollatorForLanguageModeling in sft_trainer.py correctly generates the corresponding position_ids when it receives seq_lengths. But seq_lengths is removed before it can even arrive at the collator. The key issue seems to be this line in _get_data_loader in trainer.py (https://github.com/huggingface/transformers/blob/896e9cea1ade521b2648f4798218550f6c72190c/src/transformers/trainer.py#L1007):

dataset = self._remove_unused_columns(dataset, description=description)

When args.remove_unused_columns is True, _remove_unused_columns removes non-signature columns, and seq_lengths is not recognized as a signature column of the model's forward pass. Thus, when the collator receives the data, it notices that seq_lengths is missing and generates position_ids as if the entire packed row were a single sequence.
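For illustration, here is a minimal sketch of the two behaviours described above. This is not TRL's actual collator code; the helper names are made up.

import torch

def position_ids_with_seq_lengths(seq_lengths):
    # Restart the position counter at the start of every packed sequence.
    return torch.cat([torch.arange(n) for n in seq_lengths])

def position_ids_without_seq_lengths(total_length):
    # Fallback when seq_lengths is missing: treat the packed row as one sequence.
    return torch.arange(total_length)

print(position_ids_with_seq_lengths([3, 2]))    # tensor([0, 1, 2, 0, 1])
print(position_ids_without_seq_lengths(5))      # tensor([0, 1, 2, 3, 4])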
Tangentially, the code is currently incompatible with Liger kernels as well. These lines in sft_trainer.py

if args.use_liger_kernel:
    dataset = dataset.select_columns(
        {"input_ids", "position_ids", "completion_mask"}.intersection(dataset.column_names)
    )

should be adjusted to keep seq_lengths.
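One possible adjustment (an untested sketch that simply adds seq_lengths to the retained set) would be:

if args.use_liger_kernel:
    dataset = dataset.select_columns(
        {"input_ids", "seq_lengths", "position_ids", "completion_mask"}.intersection(dataset.column_names)
    )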
Investigation
This is the difference in position_ids when remove_unused_columns is set to True (top) and False (bottom) on an example dataset (batch_size = 1).

Here is an example batch (size 1) from the dataloader. It should contain only position_ids, but the collator has computed an attention_mask due to the missing seq_lengths, and position_ids increases monotonically across the whole packed row.
{'input_ids': tensor([[ 3923,  374, 4461,  ...,  288,    13, 100257]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'),
 'position_ids': tensor([[   0,    1,    2,  ..., 3362, 3363, 3364]], device='cuda:0'),
 'labels': tensor([[ 3923,  374, 4461,  ...,  288,    13, 100257]], device='cuda:0')}
The impact of this is concerning, as remove_unused_columns is True by default. At the same time, the short-term fix should be as easy as adding a warning and, perhaps, setting the default to False.
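In the meantime, a workaround (a sketch assuming the current SFTConfig arguments) is to disable column removal explicitly when packing:

from trl import SFTConfig

training_args = SFTConfig(
    packing=True,
    max_length=4096,
    remove_unused_columns=False,  # keep seq_lengths so the collator can build packing-aware position_ids
)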
Code
Here is a simple snippet that can be run to replicate this, based on the main branches of trl and transformers.
%pip install "transformers[torch]"
%pip install git+https://github.com/huggingface/trl.git
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
from trl import SFTTrainer, SFTConfig
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B",
torch_dtype="auto",
device_map="auto",
attn_implementation = "flash_attention_2"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", use_fast=True)
EOT_TOKEN = "\nResponse:"
dataset = load_dataset("GAIR/lima")
train_dataset = dataset["train"].shuffle(seed=42)
# 3. Define the formatting function
def format_lima_conversation(example):
conversation = example['conversations']
# Join turns with the EOT token. Add one at the very end.
formatted_text = f"{EOT_TOKEN}".join(conversation) + tokenizer.eos_token
return {"text": formatted_text}
# 4. Apply the formatting
train_dataset = train_dataset.map(format_lima_conversation, remove_columns=['conversations', 'source'])
training_args = SFTConfig(
dataset_text_field="text",
packing = True,
max_length = 4096,
per_device_train_batch_size = 1
# remove_unused_columns= False
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset = train_dataset
)
for i, batch in enumerate(trainer.get_train_dataloader()):
if i == 0:
print(batch)
print(len(batch['position_ids']))
print(batch['position_ids'][0])
for j in batch['position_ids'][0]:
if j == 0:
print("seq found")
Strangely, the training loss curve is better with cross-contamination. Here are training runs of instruction tuning on the LIMA dataset with remove_unused_columns set to True and to False. Manual review of the outputs seems fine for both.