THD refactoring #7434

@pietern

Description

This is a master issue for tracking THD refactoring work that is going on.

To avoid breaking current THD users, this will happen in a parallel directory tree until we're at feature parity and can make a compatibility layer.

Issues/features that this refactor will address:

  • Currently, THD exposes a process group (data channel) as a global variable. To subdivide the group and have collectives apply to subgroups, you have to include a group argument on the C++ side. Different groups may have different sets of connections, state, and algorithm choices, so this will be implemented with separate process group instances. Allowing multiple process group instances also gives each instance an independent lifecycle, which makes testing easier, allows for some form of failure recovery, etc.
  • The TCP/env init methods require knowledge of a master address. If processes are dynamically scheduled, this is not known up front, yet there is still a need to share state between processes. To allow dynamic scheduling without upfront knowledge of network address information, I propose using a set/get/wait key/value interface like the one used in Caffe2 and Gloo. This allows socket information to be exchanged either through a daemon thread hosted on the master process, or through a separate key/value server running outside the process group. Dynamically scheduled processes then only need to know the address of the key/value server.
  • Collectives currently execute synchronously in place, so parallelism requires threading on the Python side. Instead, collectives should execute asynchronously and return some kind of future object. For the common case where allreduce runs for all layers (or batches of layers) during a backward pass, we can issue all collectives and wait on all of them at the end of the pass. A thread pool associated with the process group can then be responsible for executing this work. To still allow chaining these async allreduce operations with weight updates for CUDA tensors, we can execute the work on a separate CUDA stream and wait on a completion event in the primary stream. All work queued after this (e.g. the weight updates) will then execute upon completion of the async operation.
  • Currently, Gloo collectives are cached by signature (collective type, input devices, input size, etc.), which does not allow multiple instances with the same signature. I propose a mechanism for tuning the number of algorithm instances per signature, so we can increase the level of parallelism even for collectives with the same signature. This has lower priority than the other items, though.
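The set/get/wait rendezvous interface from the second item could be sketched as a minimal in-memory store. This is illustrative only: the `Store` class below is a hypothetical sketch, not the actual Caffe2/Gloo API, and a real store would be reachable over the network from all processes.

```python
import threading


class Store:
    """Minimal in-memory key/value store with set/get/wait semantics.

    Sketch only: a real rendezvous store would be hosted on the master
    process (daemon thread) or on an external server, reachable over TCP.
    """

    def __init__(self):
        self._data = {}
        self._cond = threading.Condition()

    def set(self, key, value):
        # Publish a value and wake up any waiters.
        with self._cond:
            self._data[key] = value
            self._cond.notify_all()

    def wait(self, keys, timeout=None):
        # Block until every key in `keys` has been set by some process.
        with self._cond:
            ok = self._cond.wait_for(
                lambda: all(k in self._data for k in keys), timeout)
            if not ok:
                raise TimeoutError(f"timed out waiting for {keys}")

    def get(self, key):
        with self._cond:
            return self._data[key]
```

With this shape, each process publishes its own socket address under a well-known key and waits for the addresses of its peers, so only the store's address must be known up front.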
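The issue-everything-then-wait pattern from the third item might look roughly like the following. The `ProcessGroup` class, its `allreduce` signature, and the summation stand-in are all hypothetical; a real implementation would dispatch to e.g. a Gloo collective rather than summing Python lists.

```python
from concurrent.futures import ThreadPoolExecutor


class ProcessGroup:
    """Sketch of a process group whose collectives run on an associated
    thread pool and return futures instead of blocking the caller."""

    def __init__(self, workers=4):
        self._pool = ThreadPoolExecutor(max_workers=workers)

    def allreduce(self, tensors):
        # `tensors` stands in for the per-rank gradient buffers; a real
        # implementation would run a collective, not an element-wise sum.
        def work():
            return [sum(vals) for vals in zip(*tensors)]
        return self._pool.submit(work)


# Per-layer gradient buffers (two ranks each), purely illustrative data.
per_layer_grads = [
    [[1.0, 2.0], [3.0, 4.0]],  # layer 0
    [[0.5], [0.5]],            # layer 1
]

pg = ProcessGroup()
# Issue one allreduce per layer during the backward pass...
futures = [pg.allreduce(grads) for grads in per_layer_grads]
# ...and wait on all of them at the end of the pass.
results = [f.result() for f in futures]
```

For CUDA tensors, the analogous chaining would happen on a separate stream, with the primary stream waiting on a completion event instead of calling `result()`.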
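The tunable per-signature instance count from the last item could be sketched as a round-robin pool keyed by signature. The `AlgorithmCache` name, the `factory` callback, and the tuple signature are assumptions for illustration, not the existing Gloo cache.

```python
import itertools


class AlgorithmCache:
    """Caches algorithm instances by signature and hands them out
    round-robin, so identical collectives can overlap in parallel."""

    def __init__(self, factory, instances_per_signature=2):
        self._factory = factory          # builds an instance for a signature
        self._n = instances_per_signature
        self._pools = {}

    def get(self, signature):
        # signature: e.g. (collective type, input devices, input size)
        if signature not in self._pools:
            pool = [self._factory(signature) for _ in range(self._n)]
            self._pools[signature] = itertools.cycle(pool)
        return next(self._pools[signature])
```

Raising `instances_per_signature` above 1 is what lifts the current one-instance-per-signature limitation.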

cc @teng-li @ailzhang @apaszke

Metadata

Labels

oncall: distributed
