generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Closed
Labels
🐛 bugSomething isn't workingSomething isn't working
Description
Reproduction
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
import logging
logging.basicConfig(
format=f"%(asctime)s [%(levelname)6s %(name)s] [%(process)d %(processName)s | %(thread)d %(threadName)s] [%(pathname)s:%(lineno)d:%(funcName)s] %(message)s",
)
dataset = load_dataset("/mnt/workspace/.t1/CodeAlpaca-20k", split="train",streaming=True)
model_id = '/mnt/workspace/.t1/Qwen2.5-7B-Instruct'
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
def format_instruction(example):
conversation = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": example['instruction']+'\n'+example['input']},
{"role": "assistant", "content": example['output']},
]
formatted_text = tokenizer.apply_chat_template(
conversation,
tokenize=False,
)
return formatted_text
response_template = "<|im_start|>assistant\n"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, padding_free=True)
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=SFTConfig(
output_dir="./tmp",
gradient_checkpointing=True,
per_device_train_batch_size=4,
report_to="tensorboard",
max_seq_length=None,
log_level='debug',
),
formatting_func=format_instruction,
data_collator=collator,
)
trainer.train()
dataset is sahil2801/CodeAlpaca-20k
when streaming=False, no error occurs,
when streaming=True, the following error occurs.
Traceback (most recent call last):
File "/mnt/workspace/.t1/test-sft.py", line 39, in <module>
trainer = SFTTrainer(
File "/usr/local/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 194, in __init__
train_dataset = self._prepare_dataset(
File "/usr/local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 400, in _prepare_dataset
if "prompt" in dataset.column_names and "completion" in dataset.column_names:
TypeError: argument of type 'NoneType' is not iterable
System Info
TRL version: 0.15.2
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
- Any traceback provided is complete
Metadata
Metadata
Assignees
Labels
🐛 bugSomething isn't workingSomething isn't working