Performance optimization: Replace list comprehensions with tensor operations in BCO and KTO trainers #3813
Conversation
…perations in BCO and KTO trainers

## 🎯 Summary

This PR optimizes performance-critical list comprehensions in BCO and KTO trainers by replacing them with equivalent tensor operations. The optimization provides significant performance improvements, especially on GPU with large batch sizes.
Hi @kashif @qgallouedec, I understand that reviewing PRs requires valuable time, and I genuinely appreciate any feedback you can provide. The changes are relatively small but impactful, focusing on three strategic locations where list comprehensions were creating GPU-CPU transfer overhead. Thank you in advance for considering this contribution to TRL!
trl/trainer/bco_trainer.py (Outdated)

labels = torch.tensor(batch["label"], dtype=torch.bool)
chosen_idx = torch.nonzero(labels, as_tuple=True)[0]
rejected_idx = torch.nonzero(~labels, as_tuple=True)[0]
labels = torch.tensor(batch["label"], dtype=torch.bool)
chosen_idx = torch.nonzero(labels, as_tuple=True)[0]
rejected_idx = torch.nonzero(~labels, as_tuple=True)[0]

Suggested change:

labels = torch.tensor(batch["label"], dtype=torch.bool, device=embeddings.device)
chosen_idx = torch.where(labels)[0]
rejected_idx = torch.where(~labels)[0]
How about using torch.where, which is more explicit about boolean operations? And similarly for the other changes? Would be good to also be explicit about the device when creating the labels.
Thank you for the excellent feedback! You're absolutely right about using torch.where for boolean operations - it's much more explicit and readable.
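For reference, single-argument torch.where(cond) is equivalent to torch.nonzero(cond, as_tuple=True), so the suggested change preserves behavior. A minimal sketch (the sample labels tensor is made up for illustration):

```python
import torch

# Hypothetical label batch, just for illustration
labels = torch.tensor([True, False, True, False, True])

chosen_idx = torch.where(labels)[0]      # indices of True entries
rejected_idx = torch.where(~labels)[0]   # indices of False entries

# Single-argument torch.where(cond) matches torch.nonzero(cond, as_tuple=True)
assert torch.equal(chosen_idx, torch.nonzero(labels, as_tuple=True)[0])
assert torch.equal(rejected_idx, torch.nonzero(~labels, as_tuple=True)[0])
```

Both forms return a tuple of index tensors, one per dimension, so `[0]` selects the indices along the batch dimension.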
Thanks @chi2liu, left some suggestions for your consideration.
…add explicit device placement

- Replace torch.nonzero(labels, as_tuple=True)[0] with torch.where(labels)[0] for more explicit boolean operations
- Replace torch.nonzero(~labels, as_tuple=True)[0] with torch.where(~labels)[0] for negated boolean operations
- Add explicit device specification when creating the labels tensor to ensure device consistency with embeddings
- Updated in BCO and KTO trainers for consistent tensor operations
…github.com/chi2liu/trl into optimize-list-comprehensions-to-tensor-ops
you might need to do a make precommit
Hi @kashif, I've already run make precommit in the root directory and all style checks are passing: ✅ All files have the required copyright. The committed changes don't have any style issues. Is there a specific style issue you're seeing that I might have missed?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…rations in BCO and KTO trainers (huggingface#3813) Co-authored-by: chiliu <chiliu@paypal.com>
🎯 Summary
This PR optimizes performance-critical list comprehensions in BCO and KTO trainers by replacing them with equivalent tensor operations. The optimization provides significant performance improvements, especially on GPU with large batch sizes.
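As an illustration of the kind of change described (the batch dict below is a toy example, not the trainers' actual data):

```python
import torch

batch = {"label": [True, False, True, True, False]}  # toy batch for illustration

# Before: Python-level list comprehension (runs element-by-element on the host)
chosen_idx_list = [i for i in range(len(batch["label"])) if batch["label"][i]]

# After: one vectorized tensor operation (stays on the tensor's device)
labels = torch.tensor(batch["label"], dtype=torch.bool)
chosen_idx = torch.where(labels)[0]

# Both approaches produce the same indices
assert chosen_idx.tolist() == chosen_idx_list
```

The tensor version avoids the per-element Python loop, which is where the GPU-CPU transfer overhead mentioned above comes from when the data already lives on the GPU.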
🔍 Technical Details
Why These Optimizations Matter
torch.nonzero() leverages the GPU's parallel processing capabilities.

Correctness Guarantees
📈 Benchmark Results
Detailed benchmarks on NVIDIA A100 80GB:
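A micro-benchmark along these lines can be sketched as follows (function name, batch sizes, and iteration count are illustrative, and the CUDA synchronization calls only matter when a GPU is available):

```python
import time
import torch

def time_index_build(n: int, device: str = "cpu", iters: int = 100) -> float:
    """Average seconds per iteration for building chosen/rejected indices."""
    labels = torch.randint(0, 2, (n,), device=device).bool()
    if device == "cuda":
        torch.cuda.synchronize()  # make sure prior GPU work has finished
    start = time.perf_counter()
    for _ in range(iters):
        chosen_idx = torch.where(labels)[0]
        rejected_idx = torch.where(~labels)[0]
    if device == "cuda":
        torch.cuda.synchronize()  # wait for the timed kernels to complete
    return (time.perf_counter() - start) / iters

# Example: compare small vs large batches on the CPU
for n in (1_024, 65_536):
    print(f"n={n}: {time_index_build(n):.2e} s/iter")
```

Synchronizing before reading the clock is essential on GPU, since CUDA kernels launch asynchronously and unsynchronized timings would only measure launch overhead.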
🎁 Additional Benefits
⚡ Impact on Real Training
For typical training scenarios:
These savings compound significantly over long training runs.
🔒 Safety
What does this PR do?
Fixes # (issue)
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.