### Issue

When `SFTTrainer` is initialized for Qwen models (tested on the Qwen2.5 Instruct family and on `trl-internal-testing/tiny-Qwen2ForCausalLM-2.5`), the `_prepare_dataset()` function appends an extra `eos_token` to each sequence.

This happens because in the Qwen2.5 chat template, the `<|im_end|>` token is always followed by `\n`. The fix introduced in #3091 appends an extra `eos_token` whenever the tokenized text sequence does not already end with `eos_token`. Here, since the last token of the formatted chat template is `\n`, the extra `eos_token` is added, so the sequence ends with `<|im_end|>\n<|im_end|>`.

I understand this can be "fixed" by passing an already processed dataset to `SFTTrainer`, but I wanted to bring to your attention that the dataset is not processed out of the box as expected.
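The template behavior is easy to confirm directly on the tokenizer, independent of the trainer. A minimal sketch, assuming the tokenizer bundled with the same tiny test model:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
messages = [
    {"role": "user", "content": "What is better than ugly?"},
    {"role": "assistant", "content": "Beautiful."},
]

# The rendered Qwen2.5 chat template ends every turn with "<|im_end|>\n".
text = tokenizer.apply_chat_template(messages, tokenize=False)
print(repr(text[-22:]))

# So the last token id is "\n" (198), not eos_token_id (151645),
# which is what triggers the extra-EOS append from #3091.
ids = tokenizer(text)["input_ids"]
print(ids[-2:], tokenizer.eos_token_id)  # [151645, 198] 151645
```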
### Reproduction
```python
from trl import SFTConfig, SFTTrainer
from datasets import Dataset

dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What is better than ugly?"},
                {"role": "assistant", "content": "Beautiful."},
            ],
        ]
    }
)

training_args = SFTConfig(
    seed=0,
    output_dir="some_dir",
    max_length=100,
    report_to="none",
    per_device_train_batch_size=1,
)

trainer = SFTTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    args=training_args,
    train_dataset=dataset,
)

sample = next(iter(trainer.get_train_dataloader()))
print(sample["input_ids"][0])
```
Output:

```text
tensor([151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465,
        553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847,
        13, 151645, 198, 151644, 872, 198, 3838, 374, 2664,
        1091, 27261, 30, 151645, 198, 151644, 77091, 198, 46518,
        13, 151645, 198, 151645])
```
Expected output:

```text
tensor([151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465,
        553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847,
        13, 151645, 198, 151644, 872, 198, 3838, 374, 2664,
        1091, 27261, 30, 151645, 198, 151644, 77091, 198, 46518,
        13, 151645, 198])
```
`eos_token_id` is 151645, and 198 is `\n`.
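Decoding the tail of the actual output makes the duplication visible; a quick check with the same tokenizer as above:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
# The last three tokens of the actual output above.
print(repr(tokenizer.decode([151645, 198, 151645])))  # '<|im_end|>\n<|im_end|>'
```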
### System Info

```text
- Platform: Linux-5.10.234-225.921.amzn2.x86_64-x86_64-with-glibc2.35
- Python version: 3.11.11
- TRL version: 0.16.1
- PyTorch version: 2.5.1
- CUDA device(s): not available
- Transformers version: 4.49.0
- Accelerate version: 0.34.2
- Accelerate config: not found
- Datasets version: 3.5.0
- HF Hub version: 0.29.1
- bitsandbytes version: 0.45.5
- DeepSpeed version: not installed
- Diffusers version: not installed
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: not installed
- PEFT version: 0.15.1
- vLLM version: not installed
```
### Possible Workaround

We need to add the `eos_token` manually only if `maybe_apply_chat_template` did not apply a chat template (we can assume that a chat template which formats the example already ends the sequence correctly). This can be achieved by merging the `maybe_apply_chat_template` and `tokenize` mapping steps into a single map with the following function:
```python
def process_example(example, processing_class, dataset_text_field, tools):
    # Only append an EOS token when no chat template was applied;
    # a chat template is assumed to terminate the sequence correctly.
    add_eos_token = True
    if is_conversational(example):
        example = apply_chat_template(example, processing_class, tools)
        add_eos_token = False
    return tokenize(example, processing_class, dataset_text_field, add_eos_token)


def tokenize(example, processing_class, dataset_text_field, add_eos_token):
    processed = processing_class(text=example[dataset_text_field])
    if (
        add_eos_token
        and processing_class.eos_token_id is not None
        and processed["input_ids"][-1] != processing_class.eos_token_id
    ):
        processed["input_ids"] = processed["input_ids"] + [processing_class.eos_token_id]
        processed["attention_mask"] = processed["attention_mask"] + [1]
    return processed
```
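In the meantime, the user-side workaround mentioned above looks roughly like the sketch below: render the chat template and tokenize before constructing the trainer, so that `_prepare_dataset()` (which, if I read the code correctly, skips datasets that already contain an `input_ids` column) leaves the sequences untouched. The `pre_tokenize` helper is mine, not part of TRL:

```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl import SFTConfig, SFTTrainer

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

def pre_tokenize(example):
    # Render the chat template ourselves; no extra EOS is appended afterwards.
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return tokenizer(text)

dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What is better than ugly?"},
                {"role": "assistant", "content": "Beautiful."},
            ],
        ]
    }
).map(pre_tokenize, remove_columns=["messages"])

trainer = SFTTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    args=SFTConfig(output_dir="some_dir", report_to="none"),
    train_dataset=dataset,
)
```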
### Checklist

- [x] I have checked that my issue isn't already filed (see open issues)
- [x] I have included my system information
- [x] Any code provided is minimal, complete, and reproducible (more on MREs)
- [x] Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
- [x] Any traceback provided is complete