Conversation

@thepowerfuldeez commented Jun 2, 2025

What does this PR do?

Adds the ability to pass a "sequence_length" key to the data collator, achieving true absence of cross-document contamination when using packing.
This approach allows setting contiguous position_ids, which are then processed by the flash attention kernel inside the attention layer, enabling block-diagonal attention masks and improving both speed and quality.
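For illustration, a minimal sketch (not code from this PR; the helper name is made up) of what contiguous per-document position_ids look like for a packed row:

```python
# Build position_ids from per-document lengths so that each packed document
# restarts its positions at 0 (illustrative helper, not the PR's implementation).
def build_position_ids(sequence_lengths):
    position_ids = []
    for length in sequence_lengths:
        position_ids += list(range(length))
    return position_ids

# Two documents of lengths 4 and 3 packed into a single row:
print(build_position_ids([4, 3]))  # [0, 1, 2, 3, 0, 1, 2]
# FlashAttention-style varlen kernels can recover document boundaries wherever
# the position id drops back to 0, yielding block-diagonal attention.
```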

See huggingface/trl#3526 for loss plots

@ArthurZucker @muellerzr @qgallouedec

@Rocketknight1 (Member) left a comment

This looks good, but we might need a test to verify it's working okay and to ensure the behaviour doesn't regress in future updates!

Also, this is a nit, but we probably don't need the is_sequence_length_provided bool when it's only used once; you can just test the condition directly.
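For illustration, a minimal sketch of the kind of regression test being asked for, written against DataCollatorWithFlattening's current position_ids behaviour (a real test for this PR would also need to exercise the new "sequence_length" path):

```python
from transformers import DataCollatorWithFlattening

def test_position_ids_restart_at_document_boundaries():
    collator = DataCollatorWithFlattening(return_tensors="np")
    features = [
        {"input_ids": [10, 11, 12, 13]},
        {"input_ids": [20, 21, 22]},
    ]
    batch = collator(features)
    # Each packed document should restart its positions at 0 so the flash
    # attention kernel can rebuild a block-diagonal mask from position_ids.
    assert batch["position_ids"].tolist() == [[0, 1, 2, 3, 0, 1, 2]]
```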

@@ -2028,7 +2038,14 @@ def __call__(self, features, return_tensors=None, separator_id=None):
            else:
                batch["labels"] += [separator_id] + input_ids[1:]
            if self.return_position_ids:
                batch["position_ids"] += list(range(len(input_ids)))
                if is_sequence_length_provided:
A Member left an inline comment on this line:

Suggested change:
-                if is_sequence_length_provided:
+                if "sequence_length" in features[0]:

@thepowerfuldeez (Author) replied:

The same logic applies to the is_labels_provided variable; I find the current implementation cleaner.

@qgallouedec (Member) commented:

Please keep this PR on hold. We are still unsure whether it's actually needed.

@thepowerfuldeez (Author) commented:

Looks like this might not be needed; @qgallouedec made it simpler by creating position_ids directly during packing and reusing the existing functionality of DataCollatorWithFlattening.
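For reference, a minimal sketch of that existing DataCollatorWithFlattening behaviour (toy token ids, not code from this thread): it flattens the examples into a single row and emits position_ids that restart at every document boundary, which the flash attention path turns into a block-diagonal mask.

```python
from transformers import DataCollatorWithFlattening

# Two toy "documents" that have already been tokenized.
features = [
    {"input_ids": [1, 2, 3]},
    {"input_ids": [4, 5]},
]

collator = DataCollatorWithFlattening(return_tensors="np")
batch = collator(features)

print(batch["input_ids"])     # [[1 2 3 4 5]]        all documents flattened into one row
print(batch["position_ids"])  # [[0 1 2 0 1]]        positions restart per document
print(batch["labels"])        # [[-100 2 3 -100 5]]  first token of each document masked
```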
