Enable completion-only loss in SFTTrainer when using Liger Kernel #3674


Merged
merged 2 commits into from
Jul 2, 2025

Conversation

kswhitecross (Contributor)

What does this PR do?

Fixes #3484, where the `completion_mask` column is stripped from `trainer.train_dataset` and `trainer.eval_dataset` when `args.use_liger_kernel == True`. This makes it impossible to use `SFTTrainer` to train on prompt-completion style datasets with the memory-efficient Liger kernel without preprocessing the dataset manually.

This issue can be minimally reproduced with the following example:

from trl import SFTTrainer, SFTConfig
from datasets import Dataset

train_dataset = Dataset.from_dict({
    "prompt": ["What is the capital of France?", "What is the capital of Germany?"],
    "completion": [" The capital of France is Paris.", " The capital of Germany is Berlin."]
})

training_args = SFTConfig(
    use_liger_kernel=True,
    completion_only_loss=True
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=train_dataset,
    args=training_args
)

print(trainer.train_dataset)

which prints:

Dataset({
    features: ['input_ids'],  # this is missing 'completion_mask', so loss will be computed on the prompt as well!
    num_rows: 2
})

With this PR, the output will instead be

Dataset({
    features: ['completion_mask', 'input_ids'],
    num_rows: 2
})

enabling the masking out of prompt tokens from the loss.
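To illustrate why the restored `completion_mask` matters, here is a minimal sketch of how such a mask can restrict a language-modeling loss to completion tokens. This is illustrative only, not TRL's actual Liger-kernel code path; the tensor shapes and values are made up.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size = 8
logits = torch.randn(5, vocab_size)      # fake per-token logits for a 5-token sequence
labels = torch.tensor([1, 4, 2, 7, 3])   # fake target token ids
# 0 marks prompt tokens, 1 marks completion tokens.
completion_mask = torch.tensor([0.0, 0.0, 1.0, 1.0, 1.0])

# Per-token cross-entropy, no reduction, so we can mask before averaging.
per_token_loss = F.cross_entropy(logits, labels, reduction="none")

# Average only over completion tokens; prompt tokens contribute nothing.
masked_loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
```

Without the mask (the buggy behavior this PR fixes), the average would also include the prompt positions, so the model would be trained to reproduce the prompt as well as the completion.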

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes?
  • [ ] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@shirinyamani shirinyamani self-requested a review July 2, 2025 17:25
@shirinyamani (Member) left a comment:

Hi @kswhitecross, very nice catch! Thanks for your contribution!
I ran `make precommit`; let's see if the CI passes!

I've also tested using

from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

dataset = load_dataset("trl-lib/tldr", split="train")

training_args = SFTConfig(
    use_liger_kernel=True,
    completion_only_loss=True
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    args=training_args
)

print(trainer.train_dataset)
trainer.train()

run command:

accelerate launch pr3674.py

my env setup by trl env:

- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.10.16
- TRL version: 0.20.0.dev0+78de9d6
- PyTorch version: 2.7.0
- accelerator(s): NVIDIA H100 80GB HBM3
- Transformers version: 4.52.4
- Accelerate version: 1.7.0
- Accelerate config: 
  - compute_environment: LOCAL_MACHINE
  - distributed_type: NO
  - mixed_precision: fp16
  - use_cpu: False
  - debug: False
  - num_processes: 1
  - machine_rank: 0
  - num_machines: 1
  - gpu_ids: 0
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - enable_cpu_affinity: False
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
- Datasets version: 3.6.0
- HF Hub version: 0.32.2
- bitsandbytes version: not installed
- DeepSpeed version: 0.17.1
- Diffusers version: not installed
- Liger-Kernel version: 0.5.10
- LLM-Blender version: not installed
- OpenAI version: 1.82.0
- PEFT version: 0.15.2
- vLLM version: 0.9.0

@kswhitecross (Contributor, Author)

Thanks @shirinyamani! This is my first PR.

I tried to run `make precommit` on my own machine, but got an error including `make: pre-commit: No such file or directory`. Is `pre-commit` missing from the `[dev]` requirements?

@kashif (Collaborator)

kashif commented Jul 2, 2025

@kswhitecross yes, do `pip install pre-commit`

@shirinyamani (Member)

shirinyamani commented Jul 2, 2025

Thanks @shirinyamani! This is my first PR.

I tried to run make precommit on my own machine, but got an error including make: pre-commit: No such file or directory. Is pre-commit missing from the [dev] requirements?

Have you run `pip install pre-commit` and installed `ruff`?
This happened to me before. There are actually two spellings in play: the Make target is `make precommit`, without a hyphen, but the formatting and style checks it runs require the `pre-commit` package, with a hyphen. I'm personally not 100% sure why it's set up this way, but reading the Makefile and installing the packages mentioned above helped me figure out the issue in the past!

@shirinyamani (Member)

@kswhitecross Anyway, thanks for the catch. For now, I ran it and the CI failure is not relevant to this PR, so we are good on this PR for now!

@shirinyamani shirinyamani self-requested a review July 2, 2025 18:11
@shirinyamani shirinyamani merged commit b520378 into huggingface:main Jul 2, 2025
9 of 10 checks passed
@sonalexle

Could you please do the same for assistant_masks? Thanks!

marcandrelarochelle pushed a commit to marcandrelarochelle/trl that referenced this pull request Jul 29, 2025
Enable completion-only loss in SFTTrainer when using Liger Kernel (huggingface#3674)

Co-authored-by: kwhitecross <kwhitecross@cs.umass.edu>
Co-authored-by: shirinyamani <75791599+shirinyamani@users.noreply.github.com>
Successfully merging this pull request may close these issues.

Completions Only Loss is incompatible with use_liger_kernel set as true
5 participants