🚀 Feature
I would like to request adding an alltoall / alltoallv communication primitive to the torch.distributed
module, with corresponding implementations in the MPI, Gloo, and other distributed backends.
Motivation
When using both data and model parallelism in modern workloads (such as DLRM), the communication pattern required to switch from one type of parallelism to the other maps very well to the all-to-all communication pattern that is well studied in HPC workloads. MPI provides MPI_Alltoall (and the more flexible MPI_Alltoallv) APIs to perform this type of communication. In the absence of alltoall support, we need to use a sequence of gathers or scatters to perform the same communication. Having an alltoall API reduces the number of calls made to the communication backend and avoids redundant data copies.
The DLRM research paper describes this communication pattern as a butterfly shuffle in its Parallelism section.
Pitch
Suppose we have 4 processes, each holding one input tensor with the following data:
input:
[ 0 1 2 3] # Proc 0
[ 4 5 6 7] # Proc 1
[ 8 9 10 11] # Proc 2
[12 13 14 15] # Proc 3
After the communication, we want the following output tensor on each process:
output:
[ 0 4 8 12] # Proc 0
[ 1 5 9 13] # Proc 1
[ 2 6 10 14] # Proc 2
[ 3 7 11 15] # Proc 3
So, we simply call:
import torch
import torch.distributed as dist
rank = dist.get_rank()  # assumes the process group has already been initialized
data = [rank * 4 + i for i in range(4)]
input = torch.tensor(data)
output = torch.empty_like(input)
dist.alltoall(output, input)
This splits the input tensor into equal chunks and scatters them to all processes. Conversely, the output tensor is split into equal chunks, each of which receives the chunk gathered from the corresponding process.
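To make the intended semantics concrete, here is a single-process sketch (plain torch, no distributed calls, illustrative only) showing that for equal splits the exchange is just a transpose of the rank-by-chunk layout:
import torch
world_size = 4
# inputs[r] is what rank r would hold locally, as in the table above
inputs = [torch.arange(r * world_size, (r + 1) * world_size) for r in range(world_size)]
stacked = torch.stack(inputs)             # shape: (world_size, world_size)
outputs = list(stacked.t().contiguous())  # outputs[r] is what rank r would receive
assert outputs[1].tolist() == [1, 5, 9, 13]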
Alternatively, if we already have a list of tensors to scatter and a list of tensors for receiving, or we want to specify explicitly which portion goes where (particularly for multi-dimensional tensors), we can pass lists of tensors as scatter_list and gather_list:
scatter_list = list(input.split(1))  # one chunk per destination rank
gather_list = list(output.split(1))  # one chunk per source rank
dist.alltoall(gather_list, scatter_list)
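The list form would also naturally cover the alltoallv case, where the chunk exchanged with each peer has a different size. A minimal sketch, assuming the proposed list-based dist.alltoall accepted tensors of unequal length (the sizes below are purely illustrative):
import torch
import torch.distributed as dist
rank = dist.get_rank()
world_size = dist.get_world_size()
# rank r sends (i + 1) elements to rank i and receives (r + 1) elements from each rank
input = torch.arange(sum(range(1, world_size + 1)), dtype=torch.float32)
scatter_list = list(input.split([i + 1 for i in range(world_size)]))
gather_list = [torch.empty(rank + 1) for _ in range(world_size)]
dist.alltoall(gather_list, scatter_list)  # proposed primitive, not yet in torch.distributed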
Alternatives
Today, doing this with scatter requires something like the following:
scatter_list = list(input.split(1))
gather_list = list(output.split(1))
for i in range(4):
    # rank i in turn scatters its chunks; every rank receives its piece into gather_list[i]
    dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)
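For reference, here is the same 4-process exchange expressed directly with the MPI primitive via mpi4py (a sketch, assuming mpi4py and an MPI launcher are available; shown only to illustrate the semantics the proposal maps onto):
from mpi4py import MPI
import numpy as np
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
sendbuf = np.arange(rank * size, (rank + 1) * size, dtype=np.int64)
recvbuf = np.empty(size, dtype=np.int64)
comm.Alltoall(sendbuf, recvbuf)  # after this, recvbuf[i] == i * size + rank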
@mnaumovfb @JianpingChen066 @dmudiger @srinivas212 @Jianhui-Li @mshiryaev
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar
EDIT: Fixed order of input and output in alltoall call as per prototype code