IterableDataset is not compatible with SFTTrainer #3030

@loricxy0707

Description

Reproduction

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

import logging
logging.basicConfig(
    format="%(asctime)s [%(levelname)6s %(name)s] [%(process)d %(processName)s | %(thread)d %(threadName)s] [%(pathname)s:%(lineno)d:%(funcName)s] %(message)s",
)

dataset = load_dataset("/mnt/workspace/.t1/CodeAlpaca-20k", split="train", streaming=True)

model_id = '/mnt/workspace/.t1/Qwen2.5-7B-Instruct'
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def format_instruction(example):

    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": example['instruction']+'\n'+example['input']},
        {"role": "assistant", "content": example['output']},
    ]

    formatted_text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
    )
    return formatted_text


response_template = "<|im_start|>assistant\n"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, padding_free=True)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=SFTConfig(
        output_dir="./tmp",
        gradient_checkpointing=True,
        per_device_train_batch_size=4,
        report_to="tensorboard",
        max_seq_length=None,
        log_level='debug',
    ),
    formatting_func=format_instruction,
    data_collator=collator,
   
)

trainer.train()

The dataset is sahil2801/CodeAlpaca-20k.
With streaming=False, no error occurs.
With streaming=True, the following error is raised when SFTTrainer is constructed:

Traceback (most recent call last):
  File "/mnt/workspace/.t1/test-sft.py", line 39, in <module>
    trainer = SFTTrainer(
  File "/usr/local/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 194, in __init__
    train_dataset = self._prepare_dataset(
  File "/usr/local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 400, in _prepare_dataset
    if "prompt" in dataset.column_names and "completion" in dataset.column_names:
TypeError: argument of type 'NoneType' is not iterable
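
The traceback suggests that `dataset.column_names` was `None` for the streamed dataset, so the `in` test on it fails. Below is a minimal sketch of one possible way around that, assuming each example is a plain dict. Both `peek_column_names` and `has_prompt_completion` are hypothetical helpers written for illustration; they are not part of TRL or `datasets`, and the generator is only a stand-in for a streamed dataset.

```python
from itertools import chain


def peek_column_names(examples):
    """Return (column_names, examples) after peeking at the first example.

    Hypothetical helper: `examples` is any iterable of dict-like rows.
    The returned iterable still yields every row, including the peeked one.
    """
    iterator = iter(examples)
    first = next(iterator, None)
    if first is None:
        # Empty stream: no columns can be inferred.
        return [], iter(())
    return list(first.keys()), chain([first], iterator)


def has_prompt_completion(column_names):
    """Guarded form of the membership test that fails in the traceback."""
    # Treat unknown (None) column names as "no columns" instead of raising.
    names = column_names or []
    return "prompt" in names and "completion" in names


# Stand-in for a streamed dataset whose column names are unknown up front.
rows = ({"instruction": f"task {i}", "input": "", "output": "..."} for i in range(3))
columns, rows = peek_column_names(rows)
```

The guard only avoids the `TypeError`; whether the rest of the trainer's dataset preparation works with a streamed dataset is a separate question that this sketch does not answer.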

System Info

TRL version: 0.15.2

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete

Labels: 🐛 bug
