
😷 Fix SFT masking EOS when equal to PAD #3200


Merged
merged 17 commits into main on Apr 2, 2025

Conversation

@qgallouedec (Member) commented Mar 31, 2025

What does this PR do?

Fixes:

This PR fixes the bug, observed many times, in which the SFT model seems to have unlearned how to generate the EOS token. It comes from the masking logic here:

labels = batch["input_ids"].clone()
if self.tokenizer.pad_token_id is not None:
    labels[labels == self.tokenizer.pad_token_id] = -100

So, if EOS == PAD, then every EOS token is masked in the loss.
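
As a toy illustration (made-up token ids, not taken from the PR), with pad_token_id == eos_token_id == 2:

import torch

input_ids = torch.tensor([[5, 6, 7, 2, 2, 2]])  # the 2 at position 3 is a real EOS; the rest is padding
labels = input_ids.clone()
labels[labels == 2] = -100  # the masking logic above
print(labels)  # tensor([[   5,    6,    7, -100, -100, -100]]) -> the real EOS is masked as well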

To solve the problem, we adopt a method that doesn't rely on the token value to determine whether it should be masked in the loss.

This is based on the addition of our own Collator.
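
For illustration, here is a minimal sketch of the idea (not necessarily the exact collator added in this PR): labels are built from the unpadded sequences, and only the positions the collator itself pads are set to -100, so no comparison against the pad token id is ever made.

import torch
from torch.nn.utils.rnn import pad_sequence


class PositionBasedCollatorSketch:
    """Hypothetical collator: masks by padding position, not by token value."""

    def __init__(self, pad_token_id: int):
        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        input_ids = [torch.tensor(ex["input_ids"]) for ex in examples]
        attention_mask = [torch.ones_like(ids) for ids in input_ids]
        labels = [ids.clone() for ids in input_ids]  # full sequences, EOS included
        return {
            "input_ids": pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id),
            "attention_mask": pad_sequence(attention_mask, batch_first=True, padding_value=0),
            # Only the padding added here gets -100, even if pad_token_id == eos_token_id
            "labels": pad_sequence(labels, batch_first=True, padding_value=-100),
        }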

Why did I choose to add a new collator?

On the one hand, to fix the issue; on the other, to prepare for the future, as it gives us more control to:

  1. natively support multi-modal data, as we already do in DPO, see here:

     if "pixel_values" in examples[0]:
         pixel_values = [torch.tensor(example["pixel_values"]) for example in examples]
     if "pixel_attention_mask" in examples[0]:
         pixel_attention_mask = [torch.tensor(example["pixel_attention_mask"]) for example in examples]
     if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
         ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples])
         ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples])

  2. train on completion only, without relying on a kind of reverse chat template (see the sketch below).
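
To illustrate point 2, a sketch of completion-only masking, assuming the dataset preparation step produces a hypothetical per-token completion_mask column (1 for completion tokens, 0 for prompt tokens); this is not the PR's actual implementation:

import torch


def mask_non_completion_tokens(labels: torch.Tensor, completion_mask: torch.Tensor) -> torch.Tensor:
    # Prompt positions get -100 so the loss ignores them, with no token-value
    # comparison and no reverse chat template.
    return labels.masked_fill(completion_mask == 0, -100)


labels = torch.tensor([[11, 12, 13, 21, 22, 2]])  # made-up ids: 3 prompt tokens + 3 completion tokens
completion_mask = torch.tensor([[0, 0, 0, 1, 1, 1]])
print(mask_non_completion_tokens(labels, completion_mask))  # tensor([[-100, -100, -100,   21,   22,    2]])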

@qgallouedec (Member, Author):

Experiments:

The following code was run on both the main branch and this branch, with and without packing.

from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from accelerate import PartialState
from transformers import AutoTokenizer


def main():
    dataset = load_dataset("trl-lib/Capybara", split="train")
    model_id = "meta-llama/Llama-3.2-3B"

    def func(example):
        messages = example["messages"]
        messages = [f"{message['role']}: {message['content']}" for message in messages]
        text = "\n".join(messages)
        return {"text": text}

    with PartialState().main_process_first():
        dataset = dataset.map(func, remove_columns=dataset.column_names)

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # The bug occurs when the pad token is set to the eos token.
    # We intentionally keep this line to verify that the fix works.
    tokenizer.pad_token = tokenizer.eos_token

    trainer = SFTTrainer(
        model=model_id,
        args=SFTConfig(
            output_dir="Llama-3.2-3B-556-2-fix-pack",
            max_length=4096,
            gradient_checkpointing=True,
            per_device_train_batch_size=4,
            logging_steps=5,
            save_steps=20,
            bf16=True,
            dataset_num_proc=16,
            num_train_epochs=1,
            packing=True,
        ),
        train_dataset=dataset,
        processing_class=tokenizer,
    )
    trainer.train()


if __name__ == "__main__":
    main()
Launched with: accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml sandbox/3200.py

The learning curves match, as expected:

No packing: (screenshot of the training loss curves)

Packing: (screenshot of the training loss curves)

The length distribution after training, which validates that the bug is fixed: (completion_lengths plot)

@qgallouedec qgallouedec marked this pull request as ready for review April 1, 2025 18:05

@qgallouedec changed the title from "Fix eos sft" to "😷 Fix SFT masking EOS when equal to PAD" on Apr 1, 2025
@@ -106,7 +106,6 @@ def test_sft_trainer_transformers(self, model_name, packing):

model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
Member Author (qgallouedec):

we don't need this anymore

Member:

I will do the same in our SFT script in open-r1 once this is merged

Comment on lines -172 to -184
# Model
if args.model_init_kwargs is not None and not isinstance(model, str):
    warnings.warn(
        "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
        "The `model_init_kwargs` will be ignored."
    )
if isinstance(model, str):
    model = self._create_model_from_path(model, args)

# PEFT configuration and model wrapping
if peft_config is not None:
    model = self._prepare_peft_model(model, peft_config, args)

Member Author (qgallouedec):

This is moved down so that the user doesn't have to wait for the model to be loaded before getting an error if the pad token is not correctly specified.

Comment on lines -188 to -207
if processing_class.pad_token is None:
    processing_class.pad_token = processing_class.eos_token  # required for padding when collating data

# Dataset
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
if preprocess_dataset:
    train_dataset = self._prepare_dataset(
        train_dataset, processing_class, args, args.packing, formatting_func, "train"
    )
    if eval_dataset is not None:
        packing = args.packing if args.eval_packing is None else args.eval_packing
        if isinstance(eval_dataset, dict):
            eval_dataset = {
                key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
                for key, dataset in eval_dataset.items()
            }
        else:
            eval_dataset = self._prepare_dataset(
                eval_dataset, processing_class, args, packing, formatting_func, "eval"
            )
Member Author (qgallouedec):

This is also moved down so that the user doesn't have to wait for the dataset to be processed to get an error if the pad token is not correctly specified.

@lewtun (Member) left a comment:

Great detective work getting to the bottom of this @qgallouedec! Logic LGTM, with a question about what happens if the user provides a pad_token string that splits into multiple token IDs.


# Get the pad token: if not provided, use the one from the processing class or the eos token
# if the processing class does not have a pad token.
pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
Member:

What happens here if pad_token is not a single token in the vocab? I.e., if the user passes "hello" and convert_tokens_to_ids gives 2 token IDs?

@qgallouedec (Member, Author) commented Apr 2, 2025:

In that case, processing_class.convert_tokens_to_ids(pad_token) returns None, and this exception is raised:

if pad_token_id is None:
    raise ValueError(
        f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
        f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
        "in the vocabulary before using it as a padding token."
    )
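
For reference, a quick standalone check of this behaviour (a sketch; it assumes the Qwen2.5 tokenizer used in the example below, which defines no unk token, so looking up a string that is not a single token in the vocabulary returns None):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
print(tok.convert_tokens_to_ids("<|endoftext|>"))                        # an existing token -> its integer id
print(tok.convert_tokens_to_ids("this is a bit long for a pad token"))   # not in the vocab -> None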

Member Author (qgallouedec):

>>> from trl import SFTTrainer, SFTConfig
>>> from datasets import load_dataset
>>> dataset = load_dataset("trl-lib/Capybara", split="train")
>>> trainer = SFTTrainer(
...     model="Qwen/Qwen2.5-0.5B",
...     args=SFTConfig(pad_token="this is a bit long for a pad token"),
...     train_dataset=dataset,
... )
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/fsx/qgallouedec/trl/trl/trainer/sft_trainer.py", line 263, in __init__
    raise ValueError(
ValueError: The specified `pad_token` ('this is a bit long for a pad token') is not found in the vocabulary of the given `processing_class` (Qwen2TokenizerFast). Ensure that the `pad_token` exists in the vocabulary before using it as a padding token.

@qgallouedec qgallouedec merged commit 485852c into main Apr 2, 2025
8 of 10 checks passed
@qgallouedec qgallouedec deleted the fix-eos-sft branch April 2, 2025 15:56
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025