Faster position_ids computation for FFD packing #3649
Conversation
@mariosasko this PR might be relevant: #3673
Hi @mariosasko, thanks for your work.
LGTM. Ran some tests to assert consistency with the previous approach:

```python
import numpy as np
import torch
from datasets import Dataset
from trl.data_utils import pack_dataset


def _convert_seq_lengths_to_position_ids(batch_seq_lengths: list[list[int]]) -> list[torch.Tensor]:
    example_lengths = [sum(seq_lengths) for seq_lengths in batch_seq_lengths]
    batch_seq_lengths = torch.tensor([seq_length for seq_lengths in batch_seq_lengths for seq_length in seq_lengths])
    position_ids = torch.ones(sum(example_lengths), dtype=batch_seq_lengths.dtype)
    position_ids[0] = 0
    position_ids[batch_seq_lengths[:-1].cumsum(0)] = -(batch_seq_lengths[:-1] - 1)
    position_ids = position_ids.cumsum(0)
    return list(position_ids.split(example_lengths))


# Create a larger dataset with sequence lengths following a gamma distribution
num_samples = 100_000
seq_lengths = np.random.gamma(shape=5, scale=20, size=num_samples)  # mean will be 100
seq_lengths = np.clip(seq_lengths, 10, None).astype(int)  # Clip to [10, inf)

# Generate input sequences with random lengths based on the gamma distribution
examples = {
    "input_ids": [list(range(length)) for length in seq_lengths],
    "attention_mask": [[1] * length for length in seq_lengths],
}
dataset = Dataset.from_dict(examples)

max_length = 256  # Set a fixed packing length
output_pr = pack_dataset(dataset, max_length, strategy="ffd")
output_main = pack_dataset(dataset, max_length, strategy="ffd_main")
output_position_ids = _convert_seq_lengths_to_position_ids(output_pr["seq_lengths"])

# Compare
for pr_ids, main_ids in zip(output_position_ids, output_main["position_ids"]):
    assert torch.equal(pr_ids, torch.tensor(main_ids)), "Position IDs do not match!"
```

Got any benchmarks for the speedup?
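To make the seq_lengths-to-position_ids mapping concrete, here is a tiny worked example, independent of trl, using the same cumsum trick as the helper above (the function name here is just for illustration):

```python
import torch


def seq_lengths_to_position_ids(batch_seq_lengths):
    # Total token count per packed example.
    example_lengths = [sum(s) for s in batch_seq_lengths]
    # Flatten all sequence lengths across examples.
    flat = torch.tensor([n for s in batch_seq_lengths for n in s])
    # Start from all ones, then plant a negative jump at each sequence
    # boundary so the running cumsum resets to 0 for every new sequence.
    pos = torch.ones(sum(example_lengths), dtype=flat.dtype)
    pos[0] = 0
    pos[flat[:-1].cumsum(0)] = -(flat[:-1] - 1)
    return list(pos.cumsum(0).split(example_lengths))


ids = seq_lengths_to_position_ids([[3, 2], [4]])
print(ids[0].tolist())  # [0, 1, 2, 0, 1]
print(ids[1].tolist())  # [0, 1, 2, 3]
```

Each packed example gets one flat tensor whose positions restart at 0 at every sequence boundary, which is exactly what FlashAttention-style packed attention expects.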
@LeonEricsson Thanks for testing the PR! I ran the benchmark from #3521 to check the speedup. These are the results:

[benchmark results table not recovered in this copy]

One interesting thing is that this PR also makes the conversion to PyTorch (in the data collator) faster.

@kashif I don't think the linked PR is relevant.

@shirinyamani I'm not sure I understand the question.
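Independently of the numbers above, a hypothetical micro-benchmark (sizes and names are illustrative, this is not the #3521 benchmark) shows why the vectorized construction of position_ids beats a Python loop:

```python
import timeit

import torch

# One packed example containing many short sequences.
seq_lengths = [7, 3, 6] * 10_000


def loop_version():
    # Naive: build the position ids with a Python comprehension.
    return torch.tensor([i for n in seq_lengths for i in range(n)])


def vectorized_version():
    # Vectorized: ones everywhere, negative jumps at sequence
    # boundaries, then a single cumsum.
    flat = torch.tensor(seq_lengths)
    pos = torch.ones(int(flat.sum()), dtype=flat.dtype)
    pos[0] = 0
    pos[flat[:-1].cumsum(0)] = -(flat[:-1] - 1)
    return pos.cumsum(0)


assert torch.equal(loop_version(), vectorized_version())
print("loop:      ", timeit.timeit(loop_version, number=10))
print("vectorized:", timeit.timeit(vectorized_version, number=10))
```

The vectorized path does a constant number of tensor ops regardless of how many sequences are packed, while the loop's cost grows with the total token count in Python-land.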
@mariosasko yes, got it. Lemme test it one more time, getting back to you soon!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@mariosasko very nice speed boost! LGTM!
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
What does this PR do?

This PR removes the slow conversion from PyArrow to Python needed to compute position_ids for FFD packing, which was introduced in #3526. To do this, it stores seq_lengths instead, which save disk space and can be efficiently converted to position_ids in a vectorized manner in the SFT collator.

Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
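For context, FFD ("first-fit decreasing") packing sorts sequences by length and drops each one into the first bin that still has room. A minimal sketch of the idea (not trl's implementation; `ffd_pack` is a hypothetical name) showing how it produces the per-example seq_lengths this PR stores:

```python
def ffd_pack(lengths, max_length):
    # First-fit decreasing: consider sequences longest-first, placing
    # each into the first bin with enough remaining capacity.
    bins = []   # per-bin list of packed sequence lengths
    space = []  # remaining capacity of each bin
    for n in sorted(lengths, reverse=True):
        for i, free in enumerate(space):
            if n <= free:
                bins[i].append(n)
                space[i] -= n
                break
        else:
            bins.append([n])
            space.append(max_length - n)
    return bins


print(ffd_pack([5, 7, 3, 6, 2], max_length=10))  # [[7, 3], [6, 2], [5]]
```

Each inner list is the seq_lengths of one packed example; feeding it through the vectorized cumsum conversion in the collator yields the per-example position_ids.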
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.