
✂️ [DPO] Fix truncation keep_end leading to zero'd out samples #3398


Merged
15 commits merged into huggingface:main on May 27, 2025

Conversation

@LeonEricsson (Collaborator) commented May 1, 2025

What does this PR do?

Fixes issue #3382. To recap: with truncation_mode='keep_end', performing a left flush followed by left truncation removes real tokens before masked ones. In the worst case, an entire sample may be masked out, leading to NaNs and potentially corrupting the loss calculation (unconfirmed).

trl/trainer/dpo_trainer.py, lines 1118 to 1134 in 999acd5:

attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)

# Truncate right
if self.max_length is not None:
    if self.truncation_mode == "keep_end":
        input_ids = input_ids[:, -self.max_length :]
        attention_mask = attention_mask[:, -self.max_length :]
        loss_mask = loss_mask[:, -self.max_length :]
    elif self.truncation_mode == "keep_start":
        input_ids = input_ids[:, : self.max_length]
        attention_mask = attention_mask[:, : self.max_length]
        loss_mask = loss_mask[:, : self.max_length]
    else:
        raise ValueError(
            f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
            "'keep_start']."
        )
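
A toy reproduction of the failure mode (hypothetical values, for illustration only):

import torch

# After flush_left, padding sits on the right; slicing the last max_length
# columns (keep_end) then keeps only padding for rows shorter than max_length.
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0]])  # 3 real tokens, flushed left
max_length = 3
print(attention_mask[:, -max_length:])  # tensor([[0, 0, 0]]) -> row fully masked out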

To address this we now:

  • Flush-left each sequence as before, then
  • For keep_end, compute each row's actual flushed length Lᵢ and gather the slice [max(0, Lᵢ − max_length), Lᵢ) individually, so even very short completions still keep their real tokens;
  • For keep_start, simply take [:, :max_length] on the already left-aligned data.

Additionally, implement flush_right and choose the flush direction based on the truncation mode (while ensuring the final result is left-flushed).

This per-row windowing ensures every example contributes at least its own completion tokens, preventing any fully zeroed loss_mask rows (see the sketch below).
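
A minimal sketch of the per-row windowing described above (hypothetical helper name, not the merged code; the merged version uses the flush_right formulation discussed later in this thread):

import torch

def truncate_keep_end_per_row(input_ids, attention_mask, loss_mask, max_length):
    # Assumes inputs are already flushed left. Nothing to do if the batch
    # is already within the length budget.
    if input_ids.size(1) <= max_length:
        return input_ids, attention_mask, loss_mask
    seq_lengths = attention_mask.sum(dim=1)            # L_i: real tokens per row
    starts = (seq_lengths - max_length).clamp(min=0)   # max(0, L_i - max_length)
    offsets = torch.arange(max_length, device=input_ids.device)
    idx = starts.unsqueeze(1) + offsets                # per-row window [start_i, start_i + max_length)
    # Rows shorter than max_length spill into their right padding, so the
    # result stays left flushed and keeps all of their real tokens.
    return input_ids.gather(1, idx), attention_mask.gather(1, idx), loss_mask.gather(1, idx)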

Also, while tinkering I ended up with an optimized version of flush_left(), which should be significantly faster.
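
For reference, a minimal sketch of one way to vectorize flush_left (not necessarily the version merged in this PR), assuming the flush_left(mask, *tensors) signature used in the snippets above:

import torch

def flush_left(mask, *tensors):
    # A stable descending sort on the mask pushes padding (0s) to the right
    # while preserving the relative order of the real (1) positions.
    order = torch.argsort(mask.int(), dim=1, descending=True, stable=True)
    mask = mask.gather(1, order)
    tensors = [t.gather(1, order) for t in tensors]
    # Drop trailing columns that are padding in every row.
    keep = int(mask.sum(dim=1).max().item())
    return mask[:, :keep], *(t[:, :keep] for t in tensors)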

Benchmarks

Some runs on trl-lib/ultrafeedback_binarized with

max_length=512, 
max_prompt_length=256,
truncation_mode='keep_end'

show that we no longer have any masked-out samples, and that our token utilization has more than doubled (note that the rate of improvement depends heavily on the above parameters and the dataset).

[Screenshot (2025-05-03): benchmark results, including the ratio of trainable tokens]

The ratio of trainable tokens is calculated by dividing the total number of unmasked tokens in the loss_mask by the total number of tokens.
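
For reference, one way such a ratio could be computed per batch (a sketch; the exact definition used for the plot above is the author's):

# Fraction of real tokens that actually contribute to the loss.
trainable_ratio = loss_mask.sum().item() / attention_mask.sum().item()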

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@LeonEricsson LeonEricsson marked this pull request as ready for review May 3, 2025 15:10
@LeonEricsson LeonEricsson force-pushed the dpo_length_truncation_fix branch from 10b1605 to 0a2f8bf Compare May 3, 2025 15:16
@qgallouedec (Member) commented:
That's a good point @LeonEricsson!
Another (maybe more intuitive?) way to think of it:

if self.truncation_mode == "keep_end": 
    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
    input_ids = input_ids[:, -self.max_length :]
    attention_mask = attention_mask[:, -self.max_length :]
    loss_mask = loss_mask[:, -self.max_length :]
elif self.truncation_mode == "keep_start":
    attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
    input_ids = input_ids[:, : self.max_length]
    attention_mask = attention_mask[:, : self.max_length]
    loss_mask = loss_mask[:, : self.max_length]
    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)

Is this roughly equivalent to what you do?
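
For context, flush_right (which this PR adds) can be sketched as the mirror image of flush_left, assuming the flush_left(mask, *tensors) signature sketched earlier (an illustration, not the merged implementation):

def flush_right(mask, *tensors):
    # Flip along the sequence dimension, flush left, then flip back so the
    # real tokens end up right-aligned.
    flushed = flush_left(mask.flip(1), *(t.flip(1) for t in tensors))
    return tuple(t.flip(1) for t in flushed)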


@LeonEricsson (Collaborator, Author) commented May 5, 2025

That's a good point @LeonEricsson! Another (maybe more intuitive?) way to think of it:

if self.truncation_mode == "keep_end": 
    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
    input_ids = input_ids[:, -self.max_length :]
    attention_mask = attention_mask[:, -self.max_length :]
    loss_mask = loss_mask[:, -self.max_length :]
elif self.truncation_mode == "keep_start":
    attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
    input_ids = input_ids[:, : self.max_length]
    attention_mask = attention_mask[:, : self.max_length]
    loss_mask = loss_mask[:, : self.max_length]
    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)

Is this roughly equivalent to what you do?

Yep, principally this is the same. I was deliberating between that approach (which is clearer to read and understand) and the implemented one, which is marginally more performant. I'm fine with either; your pick.

@qgallouedec (Member) commented:

Yes, I think the latest is probably clearer.

@LeonEricsson LeonEricsson force-pushed the dpo_length_truncation_fix branch from 366ee7f to e227154 Compare May 27, 2025 12:12
@LeonEricsson (Collaborator, Author) commented:

Yes, I think the latest is probably clearer.

Done

@qgallouedec (Member) left a review comment:

Thanks @LeonEricsson, super clean PR!

@qgallouedec qgallouedec changed the title [DPO] Truncation leading to zero'd out samples ✂️ [DPO] Fix truncation keep_end leading to zero'd out samples May 27, 2025
@qgallouedec qgallouedec merged commit 8e8e62b into huggingface:main May 27, 2025
10 checks passed