
Add alltoall collective communication support to torch.distributed module #32345

@ddkalamk

Description


🚀 Feature

I would like to request adding an alltoall / alltoallv communication primitive to the torch.distributed module, along with corresponding implementations for the MPI, Gloo, and other distributed backends.

Motivation

When using both data and model parallelism in modern workloads (such as DLRM), the communication pattern required for switching from one type of parallelism to the other maps very well to the all-to-all communication pattern that is well studied in HPC workloads. MPI provides the MPI_Alltoall (and the more flexible MPI_Alltoallv) APIs to perform this type of communication. In the absence of alltoall support, we need to use a sequence of gathers or scatters to perform the same communication. Having an alltoall API reduces the number of calls made to the communication backend and avoids redundant data copies.

The DLRM research paper describes this communication pattern as a butterfly shuffle in its Parallelism section.

[Figure: butterfly-shuffle all-to-all communication pattern from the DLRM paper]

Pitch

Suppose we have 4 processes, and each one has an input tensor with the following data:

input:
[ 0  1  2  3]  # Proc 0
[ 4  5  6  7]  # Proc 1
[ 8  9 10 11]  # Proc 2
[12 13 14 15]  # Proc 3

After the communication, we want the following output tensor on each process:

output:
[ 0  4  8 12]  # Proc 0
[ 1  5  9 13]  # Proc 1
[ 2  6 10 14]  # Proc 2
[ 3  7 11 15]  # Proc 3

So, we simply call:

import torch
import torch.distributed as dist

rank = dist.get_rank()
data = [rank * 4 + i for i in range(4)]
input = torch.tensor(data)
output = torch.empty_like(input)
dist.alltoall(output, input)

This call splits the input tensor into equal chunks and scatters them to all processes; correspondingly, the output tensor is split into equal chunks, each of which receives the chunk gathered from the respective process.
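For reference, this maps directly onto MPI_Alltoall. A minimal sketch of the same 4-process example using mpi4py (used here purely to illustrate the semantics; it is not part of the proposal) would be:

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

sendbuf = np.arange(rank * 4, rank * 4 + 4, dtype=np.int64)  # [4r, 4r+1, 4r+2, 4r+3]
recvbuf = np.empty(4, dtype=np.int64)
comm.Alltoall(sendbuf, recvbuf)  # recvbuf == [rank, rank + 4, rank + 8, rank + 12]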

Alternatively, if we have a list of tensors to scatter and a list of tensors to receive into, or we explicitly want to specify which portion goes where (particularly for multi-dimensional tensors), we can pass lists of tensors as scatter_list and gather_list:

scatter_list = list(input.split(1))
gather_list = list(output.split(1))
dist.alltoall(gather_list, scatter_list)
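
The list form would also naturally cover the alltoallv case with uneven splits. A hypothetical sketch (the sizes below are made up for illustration; each rank's receive sizes must match what its peers send to it):

send_sizes = [1, 2, 0, 1]   # elements this rank sends to ranks 0..3; must sum to input.numel()
recv_sizes = [2, 1, 1, 0]   # elements this rank receives from ranks 0..3
scatter_list = list(input.split(send_sizes))
gather_list = [torch.empty(n, dtype=input.dtype) for n in recv_sizes]
dist.alltoall(gather_list, scatter_list)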

Alternatives

Today, doing it with scatter would require this:

scatter_list = list(input.split(1))
gather_list = list(output.split(1))
for i in range(4):
    # Rank i scatters its input chunks; every other rank passes an empty list.
    dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)
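
The gather-based formulation mentioned in the Motivation section is the mirror image of the scatter loop above (a sketch reusing the same scatter_list and gather_list):

for i in range(4):
    # Rank i collects chunk i of every rank's input into its own output chunks.
    dist.gather(scatter_list[i], gather_list if i == rank else [], dst=i)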

@mnaumovfb @JianpingChen066 @dmudiger @srinivas212 @Jianhui-Li @mshiryaev

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar

EDIT: Fixed the order of input and output in the alltoall call, as per the prototype code.


Labels: oncall: distributed, triaged
