✂️ [DPO] Fix truncation keep_end leading to zero'd out samples #3398
Conversation
(force-pushed from 10b1605 to 0a2f8bf)
That's a good point @LeonEricsson!

```python
if self.truncation_mode == "keep_end":
    attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
    input_ids = input_ids[:, -self.max_length :]
    attention_mask = attention_mask[:, -self.max_length :]
    loss_mask = loss_mask[:, -self.max_length :]
    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
elif self.truncation_mode == "keep_start":
    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
    input_ids = input_ids[:, : self.max_length]
    attention_mask = attention_mask[:, : self.max_length]
    loss_mask = loss_mask[:, : self.max_length]
```

Is this roughly equivalent to what you do?
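For readers unfamiliar with these helpers: flush_left already lives in trl.trainer.utils, and this PR adds flush_right; both shift each row's real tokens flush against one side of the padded batch. A minimal per-row sketch of flush_left (a simplified illustration, not the library's vectorized implementation):

```python
import torch

def flush_left(mask, *tensors):
    # Shift each row left so its first real (mask == 1) token sits at
    # column 0, then trim trailing columns that are padding in every row.
    # Assumes every row has at least one real token.
    mask = mask.clone()
    tensors = tuple(t.clone() for t in tensors)
    for i in range(mask.size(0)):
        first = int(torch.nonzero(mask[i])[0])  # index of first real token
        mask[i] = torch.roll(mask[i], shifts=-first)
        for t in tensors:
            t[i] = torch.roll(t[i], shifts=-first)
    width = int(mask.sum(dim=1).max())  # longest row after the shift
    mask = mask[:, :width]
    tensors = tuple(t[:, :width] for t in tensors)
    return (mask, *tensors)
```

flush_right mirrors this, aligning each row's real tokens against the right edge instead.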
Yep, principally this is the same. I was deliberating between that approach, which is clearer to read and understand, and the implemented one, which is marginally more performant. I'm fine with either; your pick.
Yes, I think the latest is probably clearer.
(force-pushed from 366ee7f to e227154)
Done
Thanks @LeonEricsson, super clean PR!
What does this PR do?
Fixes issue #3382. To recap, performing a left flush followed by left truncation (when truncation_mode='keep_end') will remove real tokens before masked ones. In the worst case, an entire sample may be masked out, leading to NaNs and potentially corrupting the loss calculation (unconfirmed).

trl/trainer/dpo_trainer.py, lines 1118 to 1134 in 999acd5
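To make the failure concrete, here is a toy illustration of the old behaviour (hypothetical tensors, not the trainer's actual code): after a left flush, a short row's real tokens sit at the start of the batch, so slicing the last max_length columns keeps only its padding.

```python
import torch

max_length = 3

# Two already left-flushed rows: row 0 fills the batch width, row 1 is short.
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [1, 1, 0, 0, 0, 0]])
loss_mask = torch.tensor([[0, 0, 0, 1, 1, 1],
                          [0, 1, 0, 0, 0, 0]])

# Old keep_end truncation: keep the last max_length columns of the batch.
print(attention_mask[:, -max_length:])  # tensor([[1, 1, 1], [0, 0, 0]])
print(loss_mask[:, -max_length:])       # tensor([[1, 1, 1], [0, 0, 0]])
# Row 1 lost all of its real tokens: a fully zeroed sample.
```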
To address this we now:

- Flush-left each sequence as before, then
- For keep_end, compute each row's actual flushed length Lᵢ and gather the slice [max(0, Lᵢ - max_length), Lᵢ) individually (see the sketch below), so even very short completions still include their real tokens;
- For keep_start, simply take [:, :max_length] on the already left-aligned data;
- Implement flush_right and apply the flush direction depending on the truncation mode (while making sure the final outcome is left-flushed).

This per-row windowing ensures every example contributes at least its own completion tokens, preventing any fully zeroed loss_mask rows. Also, while tinkering I ended up with an optimized version of flush_left(); it should be significantly faster.
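As a rough sketch of the per-row windowing described above (a hypothetical simplification with illustrative names, not the PR's exact code):

```python
import torch

def keep_end_truncate(mask, input_ids, loss_mask, max_length):
    # Assumes the batch is already left-flushed: row i's real tokens
    # occupy columns [0, L_i). Keep each row's window
    # [max(0, L_i - max_length), L_i), re-aligned at column 0.
    lengths = mask.sum(dim=1)                     # L_i for every row
    starts = (lengths - max_length).clamp(min=0)  # per-row window start
    cols = starts.unsqueeze(1) + torch.arange(max_length, device=mask.device)
    cols = cols.clamp(max=mask.size(1) - 1)       # keep indices in bounds
    out_mask = torch.gather(mask, 1, cols)
    out_ids = torch.gather(input_ids, 1, cols)
    out_loss = torch.gather(loss_mask, 1, cols)
    # Re-zero positions past each row's real window; out-of-bounds ids
    # gathered via the clamp are masked out here as well.
    valid = torch.arange(max_length, device=mask.device) < (lengths - starts).unsqueeze(1)
    return out_mask * valid, out_ids, out_loss * valid
```

Multiplying by valid guarantees that even a row shorter than max_length keeps exactly its own Lᵢ real tokens and nothing else.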
Benchmarks

Some runs on trl-lib/ultrafeedback_binarized show that we no longer have any masked-out samples, and that our token utilization has more than doubled (note that the rate of improvement depends heavily on the run parameters and the dataset).
The ratio of trainable tokens is calculated by dividing the total number of unmasked tokens in the loss_mask by the total number of tokens.
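For instance, a hypothetical one-liner for that metric (illustrative, not from the PR):

```python
def trainable_token_ratio(loss_mask, attention_mask):
    # Fraction of real tokens that actually contribute to the loss.
    return loss_mask.sum().item() / attention_mask.sum().item()
```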
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.