Faster position_ids computation for FFD packing #3649


Merged
kashif merged 5 commits into huggingface:main from mariosasko:faster-position_ids-computation on Jul 3, 2025

Conversation

mariosasko
Contributor

What does this PR do?

This PR removes the slow conversion from PyArrow to Python that was needed to compute position_ids for FFD packing, introduced in #3526. Instead, it stores seq_lengths, which take less disk space and can be efficiently converted to position_ids in a vectorized manner in the SFT collator.
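For illustration, here is a minimal sketch (not code from this PR) of the idea: a packed row built from sequences of lengths 2 and 3 stores only seq_lengths = [2, 3], and the collator recovers the per-sequence positions with a single vectorized cumsum.

import torch

# Hypothetical packed row made of two sequences, of lengths 2 and 3
seq_lengths = torch.tensor([2, 3])

# Start from ones, then plant a negative jump at each sequence boundary so
# that a running cumsum restarts the position count at 0
position_ids = torch.ones(int(seq_lengths.sum()), dtype=torch.long)
position_ids[0] = 0
position_ids[seq_lengths[:-1].cumsum(0)] = -(seq_lengths[:-1] - 1)
position_ids = position_ids.cumsum(0)

print(position_ids)  # tensor([0, 1, 0, 1, 2])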

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif
Collaborator

kashif commented Jul 1, 2025

@mariosasko this PR might be relevant: #3673

@shirinyamani
Member

Hi @mariosasko, thanks for your work.
One question that came up while reading it: if two completely different packed examples have sequence-length lists of the same size, e.g. [2,3,4] and [5,6,7] (both of length 3), and we store only the lengths rather than the position_ids, how do we know that these are two different sequences?

@LeonEricsson
Collaborator

LGTM.

Ran some tests to assert consistency with the previous approach.

import numpy as np
import torch
from datasets import Dataset
from trl.data_utils import pack_dataset

def _convert_seq_lengths_to_position_ids(batch_seq_lengths: list[list[int]]) -> list[torch.Tensor]:
    # Total length of each packed example, used to split the flat result back per example
    example_lengths = [sum(seq_lengths) for seq_lengths in batch_seq_lengths]
    # Flatten all sequence lengths into a single tensor
    batch_seq_lengths = torch.tensor([seq_length for seq_lengths in batch_seq_lengths for seq_length in seq_lengths])
    position_ids = torch.ones(sum(example_lengths), dtype=batch_seq_lengths.dtype)
    position_ids[0] = 0
    # Plant a negative jump at each sequence boundary so the cumsum restarts at 0
    position_ids[batch_seq_lengths[:-1].cumsum(0)] = -(batch_seq_lengths[:-1] - 1)
    position_ids = position_ids.cumsum(0)
    return list(position_ids.split(example_lengths))


# Create a larger dataset with sequence lengths following a gamma distribution
num_samples = 100_000

# Generate sequence lengths following a gamma distribution
seq_lengths = np.random.gamma(shape=5, scale=20, size=num_samples)  # mean will be 100
seq_lengths = np.clip(seq_lengths, 10, None).astype(int)  # Clip to [10, inf)

# Generate input sequences with random lengths based on the gamma distribution
examples = {
    "input_ids": [list(range(length)) for length in seq_lengths],
    "attention_mask": [[1] * length for length in seq_lengths],
}

dataset = Dataset.from_dict(examples)
max_length = 256  # Set a fixed packing length

output_pr = pack_dataset(dataset, max_length, strategy="ffd")
# "ffd_main" is assumed to be the pre-PR packing implementation, kept locally under this name for comparison
output_main = pack_dataset(dataset, max_length, strategy="ffd_main")

output_position_ids = _convert_seq_lengths_to_position_ids(output_pr["seq_lengths"])

# Compare
for pr_ids, main_ids in zip(output_position_ids, output_main["position_ids"]):
    assert torch.equal(pr_ids, torch.tensor(main_ids)), "Position IDs do not match!"

Got any benchmarks for the speedup?

@mariosasko
Contributor Author

@LeonEricsson Thanks for testing the PR!

I ran a benchmark from #3521 to check the speedup. These are the results:

Benchmark

For seq_lengths = np.random.gamma(shape=5, scale=20, size=num_samples) ($\mu = 100, \sigma \approx 44.72$)
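(For a gamma distribution, $\mu = k\theta$ and $\sigma = \sqrt{k}\,\theta$, so shape $k = 5$ and scale $\theta = 20$ give $\mu = 5 \cdot 20 = 100$ and $\sigma = \sqrt{5} \cdot 20 \approx 44.72$.)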

main

Map: 100%|████████████████████████████████| 100000/100000 [00:02<00:00, 40686.82 examples/s]
Map: 100%|████████████████████████████████| 100000/100000 [00:02<00:00, 41073.94 examples/s]
Map: 100%|████████████████████████████████| 100000/100000 [00:02<00:00, 41129.72 examples/s]
Map: 100%|████████████████████████████████| 100000/100000 [00:02<00:00, 40616.59 examples/s]
Map: 100%|████████████████████████████████| 100000/100000 [00:02<00:00, 41340.03 examples/s]
Time for 100k rows with FFD strategy: 12.2312 seconds

PR

Map: 100%|███████████████████████████████| 100000/100000 [00:00<00:00, 255628.98 examples/s]
Map: 100%|███████████████████████████████| 100000/100000 [00:00<00:00, 269839.31 examples/s]
Map: 100%|███████████████████████████████| 100000/100000 [00:00<00:00, 268628.70 examples/s]
Map: 100%|███████████████████████████████| 100000/100000 [00:00<00:00, 268544.42 examples/s]
Map: 100%|███████████████████████████████| 100000/100000 [00:00<00:00, 240098.32 examples/s]
Time for 100k rows with FFD strategy: 1.9597 seconds

For seq_lengths = np.random.gamma(shape=50, scale=20, size=num_samples) ($\mu = 1000, \sigma \approx 141.42$)

main

Map: 100%|█████████████████████████████████| 100000/100000 [00:21<00:00, 4728.09 examples/s]
Map: 100%|█████████████████████████████████| 100000/100000 [00:21<00:00, 4727.46 examples/s]
Map: 100%|█████████████████████████████████| 100000/100000 [00:21<00:00, 4662.81 examples/s]
Map: 100%|█████████████████████████████████| 100000/100000 [00:21<00:00, 4630.90 examples/s]
Map: 100%|█████████████████████████████████| 100000/100000 [00:21<00:00, 4705.43 examples/s]
Time for 100k rows with FFD strategy: 106.6944 seconds

PR

Map: 100%|███████████████████████████████| 100000/100000 [00:00<00:00, 313898.62 examples/s]
Map: 100%|███████████████████████████████| 100000/100000 [00:00<00:00, 322974.41 examples/s]
Map: 100%|███████████████████████████████| 100000/100000 [00:00<00:00, 319968.51 examples/s]
Map: 100%|███████████████████████████████| 100000/100000 [00:00<00:00, 297004.75 examples/s]
Map: 100%|███████████████████████████████| 100000/100000 [00:00<00:00, 321820.60 examples/s]
Time for 100k rows with FFD strategy: 1.6738 seconds

So the solution on main doesn't scale with sequence length at all: its throughput drops by almost an order of magnitude (about 41k to 4.7k examples/s) as the mean length grows from 100 to 1000, while this PR's throughput stays roughly constant.

One interesting side effect is that this PR also makes the conversion to PyTorch (in the data collator) faster: it's quicker to generate position_ids from a short seq_lengths list (of lists) than to create a tensor directly from a long list of position_ids.
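A rough micro-benchmark sketch of that effect (illustrative only; the workload below is hypothetical and timings are machine-dependent):

import time

import torch

# Hypothetical workload: 4096 sequences of length 256 (~1M positions in total)
seq_lengths = [256] * 4096
position_ids_list = [i for length in seq_lengths for i in range(length)]

# Old path: build a tensor directly from the long Python list of position_ids
t0 = time.perf_counter()
old = torch.tensor(position_ids_list)
t1 = time.perf_counter()

# New path: expand the short seq_lengths list with a vectorized cumsum
lengths = torch.tensor(seq_lengths)
new = torch.ones(int(lengths.sum()), dtype=torch.long)
new[0] = 0
new[lengths[:-1].cumsum(0)] = -(lengths[:-1] - 1)
new = new.cumsum(0)
t2 = time.perf_counter()

assert torch.equal(old, new)
print(f"from long list: {t1 - t0:.4f}s, from seq_lengths: {t2 - t1:.4f}s")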

@kashif I don't think the linked PR is relevant.

@shirinyamani I'm not sure I understand the question. What do you mean by "len=3"? In your example, we would store [2,3,4] and [5,6,7] as seq_lengths when packing and later convert these sequence lengths into position_ids in the data collator.
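For instance (a small illustrative sketch), seq_lengths = [2, 3, 4] for one packed example expands to position_ids that restart at every boundary, so the packed sequences remain distinguishable:

seq_lengths = [2, 3, 4]
position_ids = [i for length in seq_lengths for i in range(length)]
print(position_ids)  # [0, 1, 0, 1, 2, 0, 1, 2, 3]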

@shirinyamani
Member

shirinyamani commented Jul 2, 2025

@mariosasko Yes, got it. Let me test it one more time and get back to you soon!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@shirinyamani
Member

@mariosasko very nice speed boost! LGTM!

@shirinyamani shirinyamani self-requested a review July 2, 2025 23:50
@shirinyamani shirinyamani requested a review from kashif July 2, 2025 23:51
@kashif kashif merged commit 4ccc5ca into huggingface:main Jul 3, 2025
@mariosasko mariosasko deleted the faster-position_ids-computation branch July 4, 2025 03:30
marcandrelarochelle pushed a commit to marcandrelarochelle/trl that referenced this pull request Jul 29, 2025
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>