Fix PackedDataset bug for seq_len > 2 * max_seq_len setting. #1697
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1697.
✅ No failures as of commit 0e14d7c with merge base 3fddc56.
Codecov Report: all modified and coverable lines are covered by tests ✅

```
@@            Coverage Diff             @@
##             main    #1697      +/-   ##
==========================================
- Coverage   70.67%   67.56%    -3.11%
==========================================
  Files         299      304        +5
  Lines       15251    15615      +364
==========================================
- Hits        10778    10551      -227
- Misses       4473     5064      +591
```
```diff
@@ -136,12 +136,15 @@ def _pack(self) -> None:
             # Update the current pack
             current_pack["tokens"] += tokens
             current_pack["labels"] += labels
-            current_pack["input_pos"] += list(range(seq_len))
+            current_pack["input_pos"] += [x % self.max_seq_len for x in range(seq_len)]
```
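To make the effect of this one-line change concrete, here is a plain-Python sketch (the helper name `wrapped_input_pos` is made up for illustration, it is not part of `PackedDataset`) of how the modulo keeps every position strictly below `max_seq_len`:

```python
def wrapped_input_pos(seq_len: int, max_seq_len: int) -> list[int]:
    # Positions restart from 0 every max_seq_len tokens, so no entry
    # ever reaches max_seq_len (which would be out of bounds when the
    # pack is later padded or used for positional lookups).
    return [x % max_seq_len for x in range(seq_len)]

# A sample of length 7 packed with max_seq_len=3:
print(wrapped_input_pos(7, 3))  # [0, 1, 2, 0, 1, 2, 0]
```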
So this is the correct thing to do in terms of fixing the bug; my main question is whether it's the correct thing to do from a model training perspective. I think it is? Basically we are treating the 2nd half of a long sample as a new sample starting from position 0, right? And as long as we are using RoPE (or some kind of positional encoding that only cares about relative position) it doesn't matter that the `input_pos` restarts from 0. I guess with absolute positional encoding we should not do this though.
I agree with your analysis, and it's unclear to me what choice would be best in all scenarios; some ideas:

1. Disallow any samples that have `seq_len > max_seq_len` (my assumption was that such samples are allowed due to the message of the ValueError on line 131).
2. Truncate the samples that are too long, throwing away the rest of the sample.
3. Chunk the samples as done in this PR.
4. Allow either 2 or 3, controlled by a new flag to be added to the `PackedDataset` class, e.g. `truncate_samples` (possibly with truncation being the default behavior). One could also have a warning be emitted if any sample is truncated.
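A rough sketch of how the truncation and chunking options differ for a single over-long sample (the function names are made up for illustration and are not part of `PackedDataset`):

```python
def truncate_sample(tokens: list[int], max_seq_len: int) -> list[list[int]]:
    # Truncation: keep the first max_seq_len tokens, discard the rest.
    return [tokens[:max_seq_len]]

def chunk_sample(tokens: list[int], max_seq_len: int) -> list[list[int]]:
    # Chunking (this PR's direction): split the sample into
    # consecutive max_seq_len-sized pieces; nothing is dropped.
    return [tokens[i:i + max_seq_len] for i in range(0, len(tokens), max_seq_len)]

sample = list(range(7))
print(truncate_sample(sample, 3))  # [[0, 1, 2]]
print(chunk_sample(sample, 3))     # [[0, 1, 2], [3, 4, 5], [6]]
```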
Personally I think (3) is most in line with the spirit of `split_across_pack=True` anyways; we don't care too much whether the split is caused by a single really long sample or just awkward spacing of a sample starting right near the `max_seq_len` of the pack. (2) is reasonable, but I actually wouldn't add another config (don't want to make things more complicated than they have to be). One thing we could consider is modifying the behavior of `split_across_pack=False` so that we put any sample whose length exceeds `max_seq_len` in its own pack, truncate it, and raise a warning. I don't have a strong preference between this and the current state of things though.
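A minimal sketch of the suggested `split_across_pack=False` variant (oversize samples truncated into their own pack with a warning); this illustrates the idea only and is not torchtune code:

```python
import warnings

def fit_sample(tokens: list[int], max_seq_len: int) -> list[int]:
    # Hypothetical behavior: an oversize sample would go into its own
    # pack, truncated to max_seq_len, with a warning so the data loss
    # is visible to the user.
    if len(tokens) > max_seq_len:
        warnings.warn(
            f"Sample of length {len(tokens)} truncated to {max_seq_len}"
        )
        return tokens[:max_seq_len]
    return tokens

out = fit_sample(list(range(10)), 4)  # warns, returns [0, 1, 2, 3]
```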
Very nice find, very nice fix, and great unit test. One concern though: I think we will also need to change the `seq_lens` values in the pack to not exceed `max_seq_len`, right? It may not error when packing the dataset as we saw with `input_pos`, but I suspect it will break in flex attention (see here).

I was going to suggest just changing L140 to `current_pack["seq_lens"] += [min(seq_len, self.max_seq_len)]`, but I think you may need to use the sequence lengths of the entire pack to avoid a case like `max_seq_len=50, seq_lens=[2, 50]` breaking the `sum(seq_lens) == max_seq_len` assumption. (Hopefully that makes sense, lmk if not)
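To illustrate the concern, a small sketch of why a per-sample `min(seq_len, max_seq_len)` cap is not sufficient on its own (plain Python, not the actual `_packed.py` code):

```python
def capped_seq_lens(seq_lens: list[int], max_seq_len: int) -> list[int]:
    # Per-entry cap keeps each value in range, but says nothing about
    # how the entries of one pack sum together.
    return [min(s, max_seq_len) for s in seq_lens]

max_seq_len = 50
capped = capped_seq_lens([2, 51], max_seq_len)
print(capped)       # [2, 50]
print(sum(capped))  # 52 -- still violates sum(seq_lens) == max_seq_len
```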
Thanks for taking a look!

I think I understand what you mean, but unless I've misunderstood how splitting happens this might already be fine: see torchtune/torchtune/datasets/_packed.py, lines 168 to 172 in 3fddc56.

Note that I've played around with this a bit more and the current version of the code (on torchtune/main) will error out even for a simpler case; it is not necessary for `seq_len` to exceed `2 * max_seq_len`:

```python
max_seq_len = 60
sample_size = [max_seq_len // 2] * 10 + [max_seq_len + 1]
dataset = [dict(tokens=list(range(size)), labels=list(range(size))) for size in sample_size]
x = PackedDataset(dataset, max_seq_len=max_seq_len, split_across_pack=True)
# RuntimeError: upper bound and larger bound inconsistent with step sign.
```
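For intuition on that error message: `torch.arange` raises exactly this complaint when asked for an increasing range whose start exceeds its end, which would happen if the padding step extends `input_pos` with something like `torch.arange(last_pos + 1, max_seq_len)` (an assumption about the padding code, not a verified reading of `_packed.py`). A torch-free sketch:

```python
max_seq_len = 60
seq_len = 61  # one token longer than the pack

# Before the fix: positions run past max_seq_len, so padding from
# last_pos + 1 up to max_seq_len would be an inverted (start > end) range.
old_pos = list(range(seq_len))
print(old_pos[-1] + 1 > max_seq_len)  # True -> torch.arange(61, 60) would error

# After the fix: positions wrap, so the padding range stays valid.
new_pos = [x % max_seq_len for x in range(seq_len)]
print(new_pos[-1] + 1 <= max_seq_len)  # True
```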
Oh, you're completely correct on this... somehow when I was looking at the code I missed that.
Thank you for the fix!
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
- Samples with `max_seq_len * 2 < seq_len` would throw a runtime error when trying to pad `input_pos`. I've set the PR as a draft and only added 1 test to first get some feedback on this setting. In this case the PackedDataset acts more as a chunking utility. Note that since `seq_len > max_seq_len`, to avoid out-of-bounds errors I've changed `input_pos` to be generated by `current_pack["input_pos"] += [x % self.max_seq_len for x in range(seq_len)]`. This might be unexpected for some users (?).
Also, I assume some of the testing utilities in PackedDataset were not written with this case in mind (e.g. `_get_expected_seq_lens_and_input_pos`), so before I make further changes I'd like to know whether this direction is fine.

Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

- run pre-commit hooks (`pre-commit install`)
- run unit tests (`pytest tests`)
- run integration tests (`pytest tests -m integration_test`)
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.