
Conversation

mirceamironenco
Contributor

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

I've set the PR as a draft and added only one test so that I can first get some feedback on this approach. In this case the PackedDataset acts more as a chunking utility. Note that since seq_len > max_seq_len, to avoid out-of-bounds errors I've changed input_pos to be generated by:

current_pack["input_pos"] += [x % self.max_seq_len for x in range(seq_len)]

This might be unexpected for some users (?).
Also, I assume some of the testing utilities in PackedDataset were not written with this case in mind (e.g. _get_expected_seq_lens_and_input_pos), so before I make further changes I'd like to know whether this direction is fine.
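To illustrate the new input_pos behavior, here is a minimal standalone sketch (not code from this PR; the numbers are made up):

max_seq_len = 4
seq_len = 10  # a single sample longer than max_seq_len
input_pos = [x % max_seq_len for x in range(seq_len)]
print(input_pos)  # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1] -- positions wrap at each max_seq_len boundary

So the part of a sample that spills past max_seq_len restarts its positions at 0, which is what might surprise users expecting monotonically increasing positions within a single sample.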

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Sep 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1697

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0e14d7c with merge base 3fddc56:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 26, 2024
@codecov-commenter

codecov-commenter commented Sep 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 67.56%. Comparing base (6bc143f) to head (9742982).
Report is 15 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1697      +/-   ##
==========================================
- Coverage   70.67%   67.56%   -3.11%     
==========================================
  Files         299      304       +5     
  Lines       15251    15615     +364     
==========================================
- Hits        10778    10551     -227     
- Misses       4473     5064     +591     
Flag Coverage Δ
67.56% <100.00%> (-3.11%) ⬇️


☔ View full report in Codecov by Sentry.

@@ -136,12 +136,15 @@ def _pack(self) -> None:
             # Update the current pack
             current_pack["tokens"] += tokens
             current_pack["labels"] += labels
-            current_pack["input_pos"] += list(range(seq_len))
+            current_pack["input_pos"] += [x % self.max_seq_len for x in range(seq_len)]
Contributor


So this is the correct thing to do in terms of fixing the bug; my main question is whether it's the correct thing to do from a model training perspective. I think it is? Basically we are treating the 2nd half of a long sample as a new sample starting from position 0, right? And as long as we are using RoPE (or some kind of positional encoding that only cares about relative position) it doesn't matter that the $(\text{max\_seq\_len}+1)^{\text{th}}$ element now has input_pos 0. I guess with absolute positional encoding we should not do this though.
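For reference, the relative-position property being relied on here is the standard RoPE identity (not something specific to this PR): with the rotation $R(m)$ applied at position $m$, attention scores satisfy $\langle R(m)\,q,\; R(n)\,k \rangle = \langle q,\; R(n-m)\,k \rangle$, so they depend only on the offset $n - m$. Restarting input_pos at 0 for the spilled-over chunk therefore leaves attention within that chunk unchanged, whereas a learned absolute positional embedding would see genuinely different position indices.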

Contributor Author


I agree with your analysis, and it's unclear to me what choice would be best in all scenarios; some ideas:

  1. Disallow any samples that have seq_len > max_seq_len (my assumption was that such samples are allowed due to the message of the ValueError on line 131).
  2. Truncate samples that are too long and throw away the rest of each sample.
  3. Chunk the samples, as done in this PR.
  4. Allow either 2 or 3, controlled by a new flag added to the PackedDataset class, e.g. truncate_samples (possibly with truncation as the default behavior). A warning could also be emitted whenever a sample is truncated. (A short sketch contrasting (2) and (3) follows after this list.)
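A rough sketch of the difference between (2) and (3) for a single over-long sample (purely illustrative; none of these names come from the PR):

tokens = list(range(10))
max_seq_len = 4

# Option 2: truncate and discard the remainder
truncated = tokens[:max_seq_len]  # [0, 1, 2, 3]

# Option 3: chunk into pieces of at most max_seq_len (what split_across_pack effectively does)
chunks = [tokens[i : i + max_seq_len] for i in range(0, len(tokens), max_seq_len)]
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]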

Contributor


Personally I think (3) is most in line with the spirit of split_across_pack=True anyway: we don't care too much whether the split is caused by a single really long sample or just awkward spacing of a sample starting right near the max_seq_len boundary of the pack. (2) is also reasonable, but I actually wouldn't add another config option (I don't want to make things more complicated than they have to be). One thing we could consider is modifying the behavior of split_across_pack=False so that we put any sample whose length exceeds max_seq_len in its own pack, truncate it, and raise a warning. I don't have a strong preference between this and the current state of things, though.
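As a very rough sketch of that alternative behavior for split_across_pack=False (hypothetical helper, not part of this PR or of torchtune):

import warnings

def maybe_truncate(tokens, max_seq_len):
    # Hypothetical helper: with split_across_pack=False, an over-long sample would be
    # truncated to fit in its own pack, with a warning instead of a hard error.
    if len(tokens) > max_seq_len:
        warnings.warn(
            f"Sample of length {len(tokens)} exceeds max_seq_len={max_seq_len}; truncating."
        )
        tokens = tokens[:max_seq_len]
    return tokens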

Contributor

@ebsmothers ebsmothers left a comment


Very nice find, very nice fix, and great unit test. One concern though: I think we will also need to change the seq_lens values in the pack to not exceed max_seq_len, right? It may not error when packing the dataset as we saw with input_pos, but I suspect it will break in flex attention (see here)

I was going to suggest just changing L140 to current_pack["seq_lens"] += [min(seq_len, self.max_seq_len)], but I think you may need to use the sequence lengths of the entire pack to avoid a case like max_seq_len=50, seq_lens=[2, 50] breaking the sum(seq_lens) == max_seq_len assumption. (Hopefully that makes sense, lmk if not)
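To spell out the concern with a per-sample min() (hypothetical numbers, not from the PR):

max_seq_len = 50
seq_lens = [2, 50]  # second sample already clipped to max_seq_len by min()
clipped = [min(s, max_seq_len) for s in seq_lens]
print(sum(clipped))  # 52 -- still violates the sum(seq_lens) == max_seq_len assumption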

@mirceamironenco
Contributor Author

Thanks for taking a look!

> Very nice find, very nice fix, and great unit test. One concern though: I think we will also need to change the seq_lens values in the pack to not exceed max_seq_len, right? It may not error when packing the dataset as we saw with input_pos, but I suspect it will break in flex attention (see here)
>
> I was going to suggest just changing L140 to current_pack["seq_lens"] += [min(seq_len, self.max_seq_len)], but I think you may need to use the sequence lengths of the entire pack to avoid a case like max_seq_len=50, seq_lens=[2, 50] breaking the sum(seq_lens) == max_seq_len assumption. (Hopefully that makes sense, lmk if not)

I think I understand what you mean, but unless I've misunderstood how splitting happens this might already be fine:

if self.split_across_pack:
    boundary = self.max_seq_len
    # The last elem in ``seq_lens`` ensures that ``sum(seq_lens) == self.max_seq_len``
    leftover_seq_len = self.max_seq_len - sum(current_pack["seq_lens"][:-1])
    seq_len_padding = [leftover_seq_len] if leftover_seq_len > 0 else []

Here, sum(current_pack["seq_lens"][:-1]) should always be <= self.max_seq_len, and in the case where all seq_lens are greater than max_seq_len it will always be 0. If you have a counter-example I can address it. I've also modified the unit test to explicitly check that the sum(seq_lens) == max_seq_len assumption holds (see the latest commit at the time of this comment).
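To make that concrete, a worked trace of the snippet above with hypothetical numbers, assuming the split pack's final seq_lens is assembled as current_pack["seq_lens"][:-1] plus the padding entry (my reading of the surrounding code):

max_seq_len = 50
current_seq_lens = [2, 50]  # second sample is over-long and gets split across packs
leftover_seq_len = max_seq_len - sum(current_seq_lens[:-1])  # 50 - 2 = 48
seq_len_padding = [leftover_seq_len] if leftover_seq_len > 0 else []
split_seq_lens = current_seq_lens[:-1] + seq_len_padding  # [2, 48]
assert sum(split_seq_lens) == max_seq_len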

Note that I've played around with this a bit more, and the current version of the code (on torchtune/main) will error out even for a simpler case. It is not necessary to have seq_len > 2 * max_seq_len; even a single example where seq_len > max_seq_len can cause the same error:

from torchtune.datasets import PackedDataset

max_seq_len = 60
sample_size = [max_seq_len // 2] * 10 + [max_seq_len + 1]
dataset = [dict(tokens=list(range(size)), labels=list(range(size))) for size in sample_size]
x = PackedDataset(dataset, max_seq_len=max_seq_len, split_across_pack=True)
# RuntimeError: upper bound and larger bound inconsistent with step sign.
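And, assuming the fix in this PR, the same construction should go through; a hypothetical check along the lines of the modified unit test (dict keys and tensor types of the packed samples assumed):

# With the fix, the construction above should succeed, and each resulting pack
# should satisfy the seq_lens invariant discussed above:
for i in range(len(x)):
    assert x[i]["seq_lens"].sum().item() == max_seq_len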

@ebsmothers
Contributor


Oh, you're completely correct on this... somehow when I was looking at the code I missed the [:-1] slice on seq_lens, which got me confused. Then I agree: I think we will always satisfy sum(seq_lens) == max_seq_len, even when dealing with samples whose length exceeds max_seq_len. So there are no major concerns from my side on this PR; we can settle the truncation-vs-chunking discussion separately, but I don't think it's a blocker here.

@mirceamironenco mirceamironenco marked this pull request as ready for review September 28, 2024 20:59
Contributor

@ebsmothers ebsmothers left a comment


Thank you for the fix!

@ebsmothers ebsmothers merged commit ded8958 into pytorch:main Sep 30, 2024
17 checks passed
@mirceamironenco mirceamironenco deleted the fix-packedds-seqlen branch September 30, 2024 18:12
Successfully merging this pull request may close these issues.

PackedDataset cannot handle long sequence whose length is larger than 2*max_seq_len when using split_across_pack=True