
SFTTrainer._prepare_dataset() adds an extra eos_token for Qwen2.5 #3318

@aalekseev1

Description

Issue

When initializing SFTTrainer for Qwen models (tested on the Qwen2.5 instruct family and on trl-internal-testing/tiny-Qwen2ForCausalLM-2.5), the _prepare_dataset() function appends an extra eos_token to the sequence.

This happens because in the Qwen2.5 chat template the <|im_end|> token is always followed by \n. The fix introduced in #3091 appends an eos_token whenever the tokenized text sequence doesn't already end with one. Here, since the last token of the formatted chat template is \n, the extra eos_token is appended, and the sequence ends with <|im_end|>\n<|im_end|>.

I understand this can be "fixed" by passing an already processed dataset to SFTTrainer, but I wanted to bring to your attention that the dataset isn't processed out of the box as expected.
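
To make the cause concrete, here is a minimal sketch (assuming the tiny test model ships the standard Qwen2.5 chat template) showing that the rendered template ends with \n rather than with the eos_token:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
messages = [
    {"role": "user", "content": "What is better than ugly?"},
    {"role": "assistant", "content": "Beautiful."},
]

# The rendered template terminates with '<|im_end|>\n', not '<|im_end|>'.
text = tokenizer.apply_chat_template(messages, tokenize=False)
print(repr(text[-12:]))

# Hence the last token id is '\n' (198), not eos (151645), which trips the #3091 check.
ids = tokenizer.apply_chat_template(messages, tokenize=True)
print(ids[-1] == tokenizer.eos_token_id)  # False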

Reproduction

from trl import SFTConfig, SFTTrainer
from datasets import Dataset

dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What is better than ugly?"},
                {"role": "assistant", "content": "Beautiful."},
            ],
        ]
    }
)

training_args = SFTConfig(
    seed=0,
    output_dir="some_dir",
    max_length=100,
    report_to="none",
    per_device_train_batch_size=1,
)

trainer = SFTTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    args=training_args,
    train_dataset=dataset,
)

sample = next(iter(trainer.get_train_dataloader()))
print(sample["input_ids"][0])

Output:

tensor([151644,   8948,    198,   2610,    525,   1207,  16948,     11,   3465,
           553,  54364,  14817,     13,   1446,    525,    264,  10950,  17847,
            13, 151645,    198, 151644,    872,    198,   3838,    374,   2664,
          1091,  27261,     30, 151645,    198, 151644,  77091,    198,  46518,
            13, 151645,    198, 151645])

Expected Output:

tensor([151644,   8948,    198,   2610,    525,   1207,  16948,     11,   3465,
           553,  54364,  14817,     13,   1446,    525,    264,  10950,  17847,
            13, 151645,    198, 151644,    872,    198,   3838,    374,   2664,
          1091,  27261,     30, 151645,    198, 151644,  77091,    198,  46518,
            13, 151645,    198])

eos_token_id is 151645, and 198 is \n.
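
Decoding the tail makes the duplication visible (same tokenizer as above; a quick check, not part of the reproduction):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
print(repr(tokenizer.decode([151645, 198, 151645])))  # actual tail: '<|im_end|>\n<|im_end|>'
print(repr(tokenizer.decode([151645, 198])))          # expected tail: '<|im_end|>\n'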

System Info

- Platform: Linux-5.10.234-225.921.amzn2.x86_64-x86_64-with-glibc2.35
- Python version: 3.11.11
- TRL version: 0.16.1
- PyTorch version: 2.5.1
- CUDA device(s): not available
- Transformers version: 4.49.0
- Accelerate version: 0.34.2
- Accelerate config: not found
- Datasets version: 3.5.0
- HF Hub version: 0.29.1
- bitsandbytes version: 0.45.5
- DeepSpeed version: not installed
- Diffusers version: not installed
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: not installed
- PEFT version: 0.15.1
- vLLM version: not installed

Possible Workaround

The eos_token should be added manually only when maybe_apply_chat_template did not apply a chat template (we can assume that an example formatted with the chat template is already terminated correctly). This can be achieved by merging the maybe_apply_chat_template and tokenize mapping steps into a single map with a function along these lines:

from trl import apply_chat_template, is_conversational


def tokenize(example, processing_class, dataset_text_field, add_eos_token):
    processed = processing_class(text=example[dataset_text_field])
    # Only append eos_token when it was requested and the sequence
    # doesn't already end with it.
    if (
        add_eos_token
        and processing_class.eos_token_id is not None
        and processed["input_ids"][-1] != processing_class.eos_token_id
    ):
        processed["input_ids"] = processed["input_ids"] + [processing_class.eos_token_id]
        processed["attention_mask"] = processed["attention_mask"] + [1]
    return processed


def process_example(example, processing_class, dataset_text_field, tools):
    add_eos_token = True
    if is_conversational(example):
        # The chat template already terminates the sequence, so skip the manual append.
        example = apply_chat_template(example, processing_class, tools)
        add_eos_token = False
    return tokenize(example, processing_class, dataset_text_field, add_eos_token)
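
A usage sketch for the function above (functools.partial, the toy dataset, and the tokenizer choice are my own illustration, not TRL internals):

from functools import partial

from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What is better than ugly?"},
                {"role": "assistant", "content": "Beautiful."},
            ],
        ]
    }
)

processed = dataset.map(
    partial(process_example, processing_class=tokenizer, dataset_text_field="text", tools=None),
    remove_columns=dataset.column_names,
)
# The sequence now ends with [..., 151645, 198]: no duplicated eos_token.
print(processed[0]["input_ids"][-3:])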

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
  • Any traceback provided is complete
