
[RFC] Low-level speed optimizations for PowerSGD #65813

@tvogels

🚀 Feature

Communication compression

PyTorch features several DDP Communication Hooks that compress messages exchanged between workers in distributed optimization. If communication time is a bottleneck, these hooks can speed up distributed training. Of course, compression is only beneficial if the time required to compress the messages is significantly smaller than the time spent communicating. Speeding up compression times can open up communication savings to a wider range of models and hardware.
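For reference, this is roughly how the existing PowerSGD communication hook is attached to a DistributedDataParallel model, using the documented `register_comm_hook` API. The model size, device placement, and `torchrun`-style launch are assumptions for illustration only:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# Assumes the script is launched with torchrun, so the default process
# group can be initialized from environment variables.
dist.init_process_group("nccl")

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model
ddp_model = DDP(model, device_ids=[torch.cuda.current_device()])

# matrix_approximation_rank controls the accuracy/speed trade-off;
# start_powerSGD_iter delays compression until training has stabilized.
state = powerSGD.PowerSGDState(
    process_group=None,  # use the default process group
    matrix_approximation_rank=1,
    start_powerSGD_iter=1_000,
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)
```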

PowerSGD

In training problems with a strong communication bottleneck, the current PowerSGD hook in PyTorch already improves training times (@SciPioneer), but a recent paper argues that in many settings, the communication savings do not yet outweigh the added compression time.

Proposed optimizations

The recent DALL-E paper uses PowerSGD for large-scale distributed training, and the paper's appendix contains many recommendations on how to implement the algorithm efficiently. The most actionable recommendation is

  • the creation of a specialized CUDA kernel for the orthogonalization of matrices with many rows and few columns.

Orthogonalization is the most expensive step in PowerSGD compression, and based on timing results from the DALL-E authors, a speedup of up to 100x is possible for this operation.
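To make the bottleneck concrete, here is a minimal sketch of a column-wise Gram-Schmidt orthogonalization in the style of the current hook, assuming a tall-and-skinny CUDA matrix. Each column triggers several tiny kernel launches, which is why a fused, specialized kernel could pay off:

```python
import torch

def orthogonalize(matrix: torch.Tensor, eps: float = 1e-8) -> None:
    """In-place column-wise Gram-Schmidt on a tall-and-skinny matrix.

    For a matrix with many rows and `rank` columns, the Python loop is
    short, but every iteration launches several small CUDA kernels
    (norm, divide, matmul, subtract), so launch overhead dominates.
    """
    for i in range(matrix.shape[1]):
        col = matrix[:, i : i + 1]
        col /= torch.norm(col) + eps
        # Remove this column's component from all remaining columns.
        matrix[:, i + 1 :] -= (col.t() @ matrix[:, i + 1 :]) * col
```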

Benefits

  • With faster compression times, compression will yield speedups for more models and on faster communication hardware.
  • A faster orthogonalization operation will allow PowerSGD to use higher values of the `rank` parameter (more accurate compression). Currently, such large ranks make compression too slow. With more accurate compression, we can avoid drops in model accuracy.
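To illustrate the rank/speed trade-off, a hypothetical micro-benchmark along these lines could time the Gram-Schmidt sketch above for increasing ranks; the tensor size and rank values are illustrative, not measurements from the paper:

```python
import time
import torch

def orthogonalize(matrix: torch.Tensor, eps: float = 1e-8) -> None:
    # Same column-wise Gram-Schmidt as in the sketch above.
    for i in range(matrix.shape[1]):
        col = matrix[:, i : i + 1]
        col /= torch.norm(col) + eps
        matrix[:, i + 1 :] -= (col.t() @ matrix[:, i + 1 :]) * col

n = 4_000_000  # rows, i.e. flattened gradient size of one bucket (illustrative)
for rank in (1, 4, 16, 32):
    q = torch.randn(n, rank, device="cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    orthogonalize(q)
    torch.cuda.synchronize()
    print(f"rank={rank:3d}: {time.perf_counter() - start:.4f} s")
```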

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang
