😷 Fix SFT masking EOS when equal to PAD #3200
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
@@ -106,7 +106,6 @@ def test_sft_trainer_transformers(self, model_name, packing):
         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
```
we don't need this anymore
I will do the same in our SFT script in open-r1 once this is merged.
```python
# Model
if args.model_init_kwargs is not None and not isinstance(model, str):
    warnings.warn(
        "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
        "The `model_init_kwargs` will be ignored."
    )
if isinstance(model, str):
    model = self._create_model_from_path(model, args)

# PEFT configuration and model wrapping
if peft_config is not None:
    model = self._prepare_peft_model(model, peft_config, args)
```
|
This is moved down so that the user doesn't have to wait for the model to be loaded to get an error if the pad token is not correctly specified.
```python
if processing_class.pad_token is None:
    processing_class.pad_token = processing_class.eos_token  # required for padding when collating data
```
```python
# Dataset
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
if preprocess_dataset:
    train_dataset = self._prepare_dataset(
        train_dataset, processing_class, args, args.packing, formatting_func, "train"
    )
    if eval_dataset is not None:
        packing = args.packing if args.eval_packing is None else args.eval_packing
        if isinstance(eval_dataset, dict):
            eval_dataset = {
                key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
                for key, dataset in eval_dataset.items()
            }
        else:
            eval_dataset = self._prepare_dataset(
                eval_dataset, processing_class, args, packing, formatting_func, "eval"
            )
```
This is also moved down so that the user doesn't have to wait for the dataset to be processed to get an error if the pad token is not correctly specified.
Great detective work getting to the bottom of this @qgallouedec! Logic LGTM, with a question about what happens if the user provides a `pad_token` string that splits into multiple token IDs.
```python
# Get the pad token: if not provided, use the one from the processing class or the eos token
# if the processing class does not have a pad token.
pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
```
What happens here if `pad_token` is not a single token in the vocab? I.e. if the user passes `hello` and `convert_tokens_to_ids` gives 2 token IDs?
In that case `processing_class.convert_tokens_to_ids(pad_token)` would return `None`, and then this exception is raised:
```python
if pad_token_id is None:
    raise ValueError(
        f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
        f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
        "in the vocabulary before using it as a padding token."
    )
```
```python
>>> from trl import SFTTrainer, SFTConfig
>>> from datasets import load_dataset
>>> dataset = load_dataset("trl-lib/Capybara", split="train")
>>> trainer = SFTTrainer(
...     model="Qwen/Qwen2.5-0.5B",
...     args=SFTConfig(pad_token="this is a bit long for a pad token"),
...     train_dataset=dataset,
... )
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/fsx/qgallouedec/trl/trl/trainer/sft_trainer.py", line 263, in __init__
    raise ValueError(
ValueError: The specified `pad_token` ('this is a bit long for a pad token') is not found in the vocabulary of the given `processing_class` (Qwen2TokenizerFast). Ensure that the `pad_token` exists in the vocabulary before using it as a padding token.
```
What does this PR do?
Fixes:
This PR fixes the bug, observed many times, in which the SFT-trained model seems to have unlearned how to generate the EOS token. This is due to the masking logic we have here: if EOS = PAD, then all EOS tokens are masked in the loss.
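To make the failure mode concrete, here is a minimal, hypothetical sketch of this kind of pad-based label masking (not the exact TRL code); the token IDs are made up for illustration:

```python
import torch

eos_token_id = 151643          # illustrative EOS id
pad_token_id = eos_token_id    # common setup: tokenizer.pad_token = tokenizer.eos_token

input_ids = torch.tensor([[9906, 1917, eos_token_id, pad_token_id, pad_token_id]])
labels = input_ids.clone()

# Pad-based masking: ignore every position whose *token value* equals the pad token.
labels[labels == pad_token_id] = -100

print(labels)
# tensor([[ 9906,  1917,  -100,  -100,  -100]])
# The genuine EOS at position 2 is masked too, so the model gets no loss signal for emitting EOS.
```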
To solve the problem, we adopt a method that doesn't rely on the token value to determine whether a position should be masked in the loss. This is based on the addition of our own collator.
Why did I choose to add a new collator?
On the one hand, to fix the issue, but also to prepare for the future, as it gives us more control (see, e.g., trl/trl/trainer/dpo_trainer.py, lines 135 to 141 in 9f3702f).
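For reference, a minimal, hypothetical sketch of the pad-agnostic idea (not the actual collator added in this PR): labels are built from where padding was appended, never from the token value, so a genuine EOS keeps contributing to the loss even when EOS == PAD.

```python
import torch

def collate(examples, pad_token_id):
    # Pad a batch of variable-length examples; mask positions by *where* padding
    # was added (sequence length), not by comparing token values to the pad id.
    lengths = [len(ex["input_ids"]) for ex in examples]
    max_len = max(lengths)

    input_ids, attention_mask, labels = [], [], []
    for ex, length in zip(examples, lengths):
        pad = max_len - length
        input_ids.append(ex["input_ids"] + [pad_token_id] * pad)
        attention_mask.append([1] * length + [0] * pad)
        # Only the appended padding positions are ignored in the loss.
        labels.append(ex["input_ids"] + [-100] * pad)

    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "labels": torch.tensor(labels),
    }

# Usage: the trailing EOS (id 151643 here, equal to the pad id) stays in the labels.
batch = collate(
    [{"input_ids": [9906, 1917, 151643]}, {"input_ids": [9906, 151643]}],
    pad_token_id=151643,
)
print(batch["labels"])
# tensor([[  9906,   1917, 151643],
#         [  9906, 151643,   -100]])
```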