Ring distributed backend #1784
Really nice! I haven't read the actual `ring.cpp` yet, but I think the API changes look great.
One high-level question: is there any way to test this backend in CI?
Yeah, 100%. I should have added that, actually. Will do. It doesn't need Thunderbolt connections; any socket will do. I regularly test for correctness using the loopback interface, which will be the CI test as well.
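Since the reply notes that any socket works and that correctness is checked over the loopback interface, here is a minimal pure-Python sketch of that kind of test: a ring sum all-reduce (reduce-scatter, then all-gather) run by one thread per rank over TCP sockets on 127.0.0.1. It illustrates the textbook ring algorithm, not the C++ code in `ring.cpp`; `ring_all_reduce`, `run_loopback_ring` and the other helper names are hypothetical.

```python
import socket
import struct
import threading


def recv_exact(sock, n):
    """Read exactly n bytes from a stream socket."""
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("peer closed the connection early")
        buf.extend(chunk)
    return bytes(buf)


def pack(values):
    return struct.pack(f"<{len(values)}d", *values)  # little-endian float64


def unpack(data):
    return list(struct.unpack(f"<{len(data) // 8}d", data))


def ring_all_reduce(rank, size, left, right, data):
    """Sum-all-reduce `data` over `size` ranks; send to `right`, receive from `left`."""
    data = list(data)
    n = len(data)
    # Split the buffer into `size` contiguous chunks (some may be empty if n < size).
    bounds = [(i * n // size, (i + 1) * n // size) for i in range(size)]

    # Phase 1, reduce-scatter: after size - 1 steps, this rank holds the
    # complete sum of chunk (rank + 1) % size.
    for step in range(size - 1):
        lo, hi = bounds[(rank - step) % size]
        right.sendall(pack(data[lo:hi]))
        lo, hi = bounds[(rank - step - 1) % size]
        for k, v in enumerate(unpack(recv_exact(left, 8 * (hi - lo)))):
            data[lo + k] += v

    # Phase 2, all-gather: pass the completed chunks around the ring.
    for step in range(size - 1):
        lo, hi = bounds[(rank - step + 1) % size]
        right.sendall(pack(data[lo:hi]))
        lo, hi = bounds[(rank - step) % size]
        data[lo:hi] = unpack(recv_exact(left, 8 * (hi - lo)))
    return data


def run_loopback_ring(inputs):
    """Run one all-reduce with one thread per rank, all on 127.0.0.1."""
    size = len(inputs)
    listeners = []
    for _ in range(size):
        srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        srv.bind(("127.0.0.1", 0))  # let the OS pick a free port
        srv.listen(1)
        srv.settimeout(10)
        listeners.append(srv)
    ports = [srv.getsockname()[1] for srv in listeners]
    results, errors = [None] * size, []

    def worker(rank):
        try:
            # Connect to the right neighbour, then accept the left neighbour.
            right = socket.create_connection(("127.0.0.1", ports[(rank + 1) % size]), timeout=10)
            left, _ = listeners[rank].accept()
            left.settimeout(10)
            with left, right:
                results[rank] = ring_all_reduce(rank, size, left, right, inputs[rank])
        except Exception as exc:  # surface worker failures to the caller
            errors.append(exc)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    for srv in listeners:
        srv.close()
    if errors:
        raise errors[0]
    return results
```

Each rank sends 2·(N−1) chunks of roughly B/N bytes, which is why a ring all-reduce puts less than 2B bytes on the wire per node. The sketch relies on messages being small enough to fit in the kernel socket buffers, because every rank calls `sendall` before `recv`; a real backend would overlap sends and receives instead.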
LGTM!! This is good to go, right?
Yeah, it is good to go. I am writing docs for it in a different PR, and there will be more PRs improving the performance and adding features (these should be limited in …
Initial version of a custom ring reduce backend. It can be used with Ethernet as well, but the main focus is on Thunderbolt because:
The comparison is as follows for a 4-way all reduce using 4 M2 Ultras:

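The bandwidth is computed as

$$\text{bandwidth} = \frac{B \cdot 2 \cdot \frac{3}{4}}{t}$$

where $B$ is the total number of bytes and $t$ is the time. The $\frac{3}{4}$ factor is due to the ring reduce actually sending less than $2B$ bytes on the wire: with $N = 4$ nodes, each node sends $2 \cdot \frac{N-1}{N} \cdot B$ bytes. In terms of time per all reduce: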

which means that for a 2 GB all reduce we go from 1.3 s over Ethernet, or 0.45 s with MPI, to 0.27 s with our reduce.
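As a worked check of the bandwidth formula, the sketch below generalizes the 3/4 factor to (N − 1)/N for an N-node ring and applies it to the 2 GB timings quoted in this PR. It assumes "2GB" means 2 × 10^9 bytes; the function names are illustrative.

```python
def ring_bytes_per_node(total_bytes, world_size):
    """Bytes each node sends in a ring all-reduce: 2 * (N - 1) / N * B."""
    return 2 * (world_size - 1) / world_size * total_bytes


def ring_bandwidth(total_bytes, world_size, seconds):
    """Effective bandwidth as defined in the PR: bytes on the wire per second."""
    return ring_bytes_per_node(total_bytes, world_size) / seconds


B = 2e9  # assumption: "2GB" taken as 2 * 10**9 bytes
for name, seconds in [("Ethernet", 1.3), ("MPI", 0.45), ("ring", 0.27)]:
    print(f"{name:<8} {ring_bandwidth(B, 4, seconds) / 1e9:.2f} GB/s")
```

For N = 4 the factor 2 · (N − 1)/N is exactly 2 · 3/4 = 1.5, matching the formula. Under the 2 × 10^9-byte assumption the timings correspond to roughly 2.3 GB/s over Ethernet, 6.7 GB/s with MPI and 11.1 GB/s with the ring backend.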
Things to note:

- `mx.distributed.init()` takes a `backend` argument and initializes the corresponding backend. If no argument is provided, it tries to initialize backends in the order `ring > mpi`. The reason is that ring will fail if `MLX_HOSTFILE` and `MLX_RANK` are not provided, while mpi may succeed and "hide" the rest of the backends.
- `mx.distributed.init()` called a second time without an argument returns the previously initialized backend. This allows the rest of the code to remain unchanged and agnostic of the communication group.

Finally, I am in the process of preparing another PR with changes to the docs and possibly a launcher for this backend, so people don't have to write their own shell scripts to launch with it.
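The initialization rules in the list above can be modeled in a few lines. This is a hypothetical pure-Python stand-in, not MLX's implementation: `init`, `init_ring`, `init_mpi` and `BackendUnavailable` are illustrative names, the "groups" are plain `(backend, rank)` tuples, and `init_mpi` is assumed to always succeed as a single-process group. It shows why `ring` is tried first and how a second `init()` call without an argument reuses the earlier group. It also assumes that an explicit `init(backend)` call replaces the cached group.

```python
import os

_initialized = None  # group returned by the most recent successful init()


class BackendUnavailable(RuntimeError):
    """Raised by a backend factory that cannot initialize in this environment."""


def init_ring():
    # Mirrors the PR's note: ring fails unless MLX_HOSTFILE and MLX_RANK are set.
    if "MLX_HOSTFILE" not in os.environ or "MLX_RANK" not in os.environ:
        raise BackendUnavailable("ring: MLX_HOSTFILE and MLX_RANK are required")
    return ("ring", int(os.environ["MLX_RANK"]))


def init_mpi():
    # Assumption: MPI "succeeds" as a single-process group even when not launched
    # with an MPI launcher, which is why it could hide the ring backend.
    return ("mpi", 0)


BACKENDS = {"ring": init_ring, "mpi": init_mpi}
FALLBACK_ORDER = ("ring", "mpi")  # ring first, so mpi cannot shadow it


def init(backend=None):
    global _initialized
    if backend is not None:
        _initialized = BACKENDS[backend]()  # explicit choice: no fallback
        return _initialized
    if _initialized is not None:
        return _initialized  # later calls reuse the existing group
    for name in FALLBACK_ORDER:
        try:
            _initialized = BACKENDS[name]()
            return _initialized
        except BackendUnavailable:
            continue
    raise BackendUnavailable("no distributed backend could be initialized")
```

Library code can then call `init()` with no argument and receive whatever group the application set up first, which is what keeps it agnostic of the communication group.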