📨 [SFT] Tokenize directly when applying the chat template #3572

Merged 3 commits into main on Jun 13, 2025

Conversation

@qgallouedec (Member) commented on Jun 12, 2025

What this PR does

This PR simplifies how we tokenize conversational data. Previously, we used:

text = tokenizer.apply_chat_template(messages, tokenize=False)
input_ids = tokenizer(text)["input_ids"]

Now, we directly do:

input_ids = tokenizer.apply_chat_template(messages, tokenize=True)

Why

This change enables future support for return_assistant_tokens_mask, which is useful for training on assistant-only tokens.
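For illustration, here is roughly how the mask could be obtained once supported. This is a hedged sketch of the transformers `apply_chat_template` API, not code from this PR; note that a meaningful mask requires a chat template containing `{% generation %}` markers, otherwise the mask comes back all zeros:

processed = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_assistant_tokens_mask=True,
)
input_ids = processed["input_ids"]
assistant_masks = processed["assistant_masks"]  # 1 for assistant tokens, 0 elsewhere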

Impact

  • User-facing: No changes.

  • Internally: Minimal changes, mostly cosmetic (see table below):

    • attention_mask is no longer added by default. It was always filled with 1s and later replaced by the collator anyway (see the sketch after this list).
    • The "messages" column is now preserved when remove_unused_columns=False, rather than being replaced with "text". This improves clarity.

Column changes summary

| Branch | Input Column | Output Columns                           |
|--------|--------------|------------------------------------------|
| main   | "text"       | ["input_ids", "attention_mask", "text"]  |
| PR     | "text"       | ["input_ids", "text"]                    |
| main   | "messages"   | ["input_ids", "attention_mask", "text"]  |
| PR     | "messages"   | ["input_ids", "messages"]                |

Functional Equivalence

All previous dataset preparation workflows remain functionally equivalent. A full equivalence test (comparing token IDs, position_ids, and completion_mask where applicable) was run across the following configurations:

  • Conversational & standard LM
  • Prompt-completion (conversational & standard)
  • With/without packing
  • With/without additional columns

No regressions were found.
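
The script used for the equivalence test is reproduced below. It re-implements the dataset preparation of both branches and asserts that their outputs match: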

from datasets import Dataset, load_dataset
from accelerate import PartialState
import warnings
from trl.data_utils import (
    maybe_convert_to_chatml,
    is_conversational,
    apply_chat_template,
    pack_dataset,
    truncate_dataset,
)
from trl import SFTConfig
from transformers import AutoTokenizer


def prepare_dataset_old(
    dataset,
    processing_class,
    args,
    packing,
    formatting_func,
    dataset_name,
):
    # If the dataset is already preprocessed (tokenized), skip the processing steps.
    column_names = list(next(iter(dataset)).keys())
    is_processed = "input_ids" in column_names

    # Build the kwargs for the `map` function
    map_kwargs = {}
    if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
        map_kwargs["num_proc"] = args.dataset_num_proc

    with PartialState().main_process_first():
        # Apply the formatting function if any
        if formatting_func is not None and is_processed:
            warnings.warn(
                "You passed a dataset that is already processed (contains an `input_ids` field) together with a "
                "formatting function. Therefore `formatting_func` will be ignored. Either remove the "
                "`formatting_func` or pass a dataset that is not already processed.",
                UserWarning,
            )

        if formatting_func is not None and not is_processed:
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"

            def _func(example):
                return {"text": formatting_func(example)}

            try:
                dataset = dataset.map(_func, batched=False, **map_kwargs)
            except Exception as e:
                warnings.warn(
                    f"Failed to apply the formatting function due to the following error: {e}. This may be "
                    "because the function is designed for batched input. Please update it to process one example "
                    "at a time (i.e., accept and return a single example). For now, we will attempt to apply the "
                    "function in batched mode, but note that batched formatting is deprecated and will be removed "
                    "in version 0.21.",
                    DeprecationWarning,
                )
                dataset = dataset.map(_func, batched=True, **map_kwargs)

        if not is_processed:
            # Convert the dataset to ChatML if needed
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML"
            column_names = next(iter(dataset)).keys()
            dataset = dataset.map(
                maybe_convert_to_chatml,
                remove_columns="conversations" if "conversations" in column_names else None,
                **map_kwargs,
            )

            # Apply the chat template if needed
            first_example = next(iter(dataset))
            if is_conversational(first_example):
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
                column_names = first_example.keys()
                dataset = dataset.map(
                    apply_chat_template,
                    fn_kwargs={"tokenizer": processing_class},
                    remove_columns="messages" if "messages" in column_names else None,  # renamed to "text"
                    **map_kwargs,
                )
                # Subsequent tokenization won't add special tokens (mostly for bos).
                # See https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
                add_special_tokens = False
            # When dataset is not conversational, we need to add the EOS token at the end of each example
            # We don't need to do this for conversational datasets as this is already handled by the
            # `apply_chat_template` function.
            else:
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset"

                def add_eos(example, eos_token):
                    if "text" in example and not example["text"].endswith(eos_token):  # language modeling case
                        example["text"] = example["text"] + eos_token
                    elif "completion" in example and not example["completion"].endswith(eos_token):
                        example["completion"] = example["completion"] + eos_token
                    return example

                dataset = dataset.map(
                    add_eos,
                    fn_kwargs={"eos_token": processing_class.eos_token},
                    remove_columns="messages" if "messages" in column_names else None,  # renamed to "text"
                    **map_kwargs,
                )
                # Subsequent tokenization will add special tokens (mostly for bos).
                # See https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
                add_special_tokens = True

            # Tokenize the dataset
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

            def tokenize(example, processing_class, dataset_text_field, add_special_tokens):
                if "prompt" in example:  # prompt-completion case
                    processed_prompt = processing_class(
                        text=example["prompt"],
                        add_special_tokens=add_special_tokens,
                    )
                    processed = processing_class(
                        text=example["prompt"] + example["completion"], add_special_tokens=add_special_tokens
                    )

                    # Check if the tokenized prompt starts with the tokenized prompt+completion
                    prompt_ids = processed_prompt["input_ids"]
                    prompt_completion_ids = processed["input_ids"]
                    if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids:
                        warnings.warn(
                            "Mismatch between tokenized prompt and the start of tokenized prompt+completion. "
                            "This may be due to unexpected tokenizer behavior, whitespace issues, or special "
                            "token handling. Verify that the tokenizer is processing text consistently."
                        )

                    # Create a completion mask
                    completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids))
                    processed = {**processed, "completion_mask": completion_mask}

                else:  # language modeling case
                    processed = processing_class(
                        text=example[dataset_text_field], add_special_tokens=add_special_tokens
                    )
                return processed

            dataset = dataset.map(
                tokenize,
                fn_kwargs={
                    "processing_class": processing_class,
                    "dataset_text_field": args.dataset_text_field,
                    "add_special_tokens": add_special_tokens,
                },
                **map_kwargs,
            )

        # Pack or truncate
        if packing:
            if args.max_length is None:
                raise ValueError("When packing is enabled, `max_length` can't be `None`.")
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Packing {dataset_name} dataset"
            dataset = dataset.select_columns("input_ids")
            # Packing adds new column "position_ids" needed for document aware flash attention
            dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
        elif args.max_length is not None:
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
            dataset = truncate_dataset(dataset, args.max_length, map_kwargs)
        # For Liger kernel, ensure only input_ids is present
        if args.use_liger_kernel:
            dataset = dataset.select_columns({"input_ids", "position_ids"}.intersection(dataset.column_names))

    return dataset


def prepare_dataset_new(
    dataset,
    processing_class,
    args,
    packing,
    formatting_func,
    dataset_name,
):
    # If the dataset is already preprocessed (tokenized), skip the processing steps.
    column_names = list(next(iter(dataset)).keys())
    is_processed = "input_ids" in column_names

    # Build the kwargs for the `map` function
    map_kwargs = {}
    if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
        map_kwargs["num_proc"] = args.dataset_num_proc

    with PartialState().main_process_first():
        # Apply the formatting function if any
        if formatting_func is not None and is_processed:
            warnings.warn(
                "You passed a dataset that is already processed (contains an `input_ids` field) together with a "
                "formatting function. Therefore `formatting_func` will be ignored. Either remove the "
                "`formatting_func` or pass a dataset that is not already processed.",
                UserWarning,
            )

        if formatting_func is not None and not is_processed:
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"

            def _func(example):
                return {"text": formatting_func(example)}

            try:
                dataset = dataset.map(_func, batched=False, **map_kwargs)
            except Exception as e:
                warnings.warn(
                    f"Failed to apply the formatting function due to the following error: {e}. This may be "
                    "because the function is designed for batched input. Please update it to process one example "
                    "at a time (i.e., accept and return a single example). For now, we will attempt to apply the "
                    "function in batched mode, but note that batched formatting is deprecated and will be removed "
                    "in version 0.21.",
                    DeprecationWarning,
                )
                dataset = dataset.map(_func, batched=True, **map_kwargs)

        if not is_processed:
            # Convert the dataset to ChatML if needed
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML"
            column_names = next(iter(dataset)).keys()
            dataset = dataset.map(
                maybe_convert_to_chatml,
                remove_columns="conversations" if "conversations" in column_names else None,
                **map_kwargs,
            )

            # Apply the chat template if needed
            first_example = next(iter(dataset))
            if not is_conversational(first_example):
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset"

                def add_eos(example, eos_token):
                    if "text" in example and not example["text"].endswith(eos_token):  # language modeling case
                        example["text"] = example["text"] + eos_token
                    elif "completion" in example and not example["completion"].endswith(eos_token):
                        example["completion"] = example["completion"] + eos_token
                    return example

                dataset = dataset.map(
                    add_eos,
                    fn_kwargs={"eos_token": processing_class.eos_token},
                    remove_columns="messages" if "messages" in column_names else None,  # renamed to "text"
                    **map_kwargs,
                )

            # Tokenize the dataset
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

            def tokenize(example, processing_class, dataset_text_field):
                if "prompt" in example:  # prompt-completion case
                    if is_conversational(example):
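                        # apply_chat_template defaults to tokenize=True, so these calls return token ids directly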
                        prompt_ids = processing_class.apply_chat_template(example["prompt"])
                        prompt_completion_ids = processing_class.apply_chat_template(example["prompt"] + example["completion"])
                    else:
                        prompt_ids = processing_class(text=example["prompt"]).input_ids
                        prompt_completion_ids = processing_class(text=example["prompt"] + example["completion"]).input_ids

                    # Check if the tokenized prompt starts with the tokenized prompt+completion
                    if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids:
                        warnings.warn(
                            "Mismatch between tokenized prompt and the start of tokenized prompt+completion. "
                            "This may be due to unexpected tokenizer behavior, whitespace issues, or special "
                            "token handling. Verify that the tokenizer is processing text consistently."
                        )

                    # Create a completion mask
                    completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids))
                    processed = {"input_ids": prompt_completion_ids, "completion_mask": completion_mask}

                else:  # language modeling case
                    if is_conversational(example):
                        processed = {"input_ids": processing_class.apply_chat_template(example["messages"])}
                    else:
                        processed = {"input_ids": processing_class(text=example[dataset_text_field]).input_ids}
                return processed

            dataset = dataset.map(
                tokenize,
                fn_kwargs={
                    "processing_class": processing_class,
                    "dataset_text_field": args.dataset_text_field,
                },
                **map_kwargs,
            )

        # Pack or truncate
        if packing:
            if args.max_length is None:
                raise ValueError("When packing is enabled, `max_length` can't be `None`.")
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Packing {dataset_name} dataset"
            dataset = dataset.select_columns("input_ids")
            # Packing adds new column "position_ids" needed for document aware flash attention
            dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
        elif args.max_length is not None:
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
            dataset = truncate_dataset(dataset, args.max_length, map_kwargs)
        # For Liger kernel, ensure only input_ids is present
        if args.use_liger_kernel:
            dataset = dataset.select_columns({"input_ids", "position_ids"}.intersection(dataset.column_names))

    return dataset

from itertools import product
configs = ["conversational_language_modeling", "standard_language_modeling", "standard_prompt_completion", "conversational_prompt_completion"]
packing_strategies = ["ffd", "wrapped", False]
add_columns = [True, False]

for (config, packing_strategy, add_column) in product(configs, packing_strategies, add_columns):
    print(f"Testing config: {config}, packing_strategy: {packing_strategy}, add_column: {add_column}")
    args = SFTConfig(dataset_num_proc=1, max_length=17, packing_strategy=packing_strategy, use_liger_kernel=False)
    dataset = load_dataset("trl-internal-testing/zen", config, split="train")
    if add_column:
        dataset = dataset.add_column("col_name", range(len(dataset)))  # Add an extra column to test handling of additional columns
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

    kwargs = {
        "dataset": dataset,
        "processing_class": tokenizer,
        "args": args,
        "packing": bool(packing_strategy),
        "formatting_func": None,  # No custom formatting function provided
        "dataset_name": "test_dataset",
    }

    old = prepare_dataset_old(**kwargs)
    new = prepare_dataset_new(**kwargs)
    assert old["input_ids"] == new["input_ids"]
    if "position_ids" in old.column_names:
        assert old["position_ids"] == new["position_ids"]
    if "completion_mask" in old.column_names:
        assert old["completion_mask"] == new["completion_mask"]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec merged commit 72c91e7 into main Jun 13, 2025
11 checks passed
@qgallouedec qgallouedec deleted the tokenize-directly branch June 13, 2025 14:03