
constructing DTensor on a 2D device mesh SIGTERMs #151858

@wangkuiyi

🐛 Describe the bug

I am working with @wanchaol on drafting a tutorial about DTensor: https://wkyi.quip.com/YVhXArYw2a5c/PyTorch-DTensor-From-Zero-To-Hero

You can run all examples in the above Quip doc using the same command line:

OMP_NUM_THREADS=1 torchrun --nproc_per_node=4 a.py

All of the examples run fine on a Runpod.ai pod with four GPUs except for the last one, whose complete source code follows:

# OMP_NUM_THREADS=1 torchrun --nproc_per_node=4 a.py
import os
import torch
import torch.distributed as dist
import contextlib


@contextlib.contextmanager
def distributed_context():
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
        local_device = torch.device("cuda", local_rank)
        dist.init_process_group(backend="nccl", device_id=local_device)
        yield local_device
    finally:
        dist.barrier()
        dist.destroy_process_group()
        print(f"Rank {local_rank} finished")


def main(local_device):
    import torch.distributed.tensor.debug
    import torch.distributed.tensor as dt

    local_rank = local_device.index

    # Arrange the 4 ranks as a 2 x 2 mesh ("dp" x "tp") and shard the
    # 4 x 4 tensor along dim 0 over "dp" and along dim 1 over "tp".
    mesh = dist.init_device_mesh("cuda", (2, 2), mesh_dim_names=["dp", "tp"])
    placements = [dt.Shard(dim=0), dt.Shard(dim=1)]
    dtensor = dt.full((4, 4), 1.23, device_mesh=mesh, placements=placements)
    print(f"Rank {local_rank} created \n{dtensor}")
    dt.debug.visualize_sharding(dtensor)


if __name__ == "__main__":
    with distributed_context() as local_device:
        main(local_device)
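
For context, with this (2, 2) mesh and the [Shard(dim=0), Shard(dim=1)] placements, each of the four ranks should end up holding a 2 x 2 local shard of the 4 x 4 global tensor. A minimal sanity check that could be appended to main() (my own illustrative sketch, not part of the script above) would be:

    # Hypothetical addition at the end of main(): inspect the local shard
    # held by this rank after the 2D sharding above.
    local = dtensor.to_local()  # a plain torch.Tensor with only this rank's shard
    print(f"Rank {local_rank} local shard shape: {tuple(local.shape)}")  # expect (2, 2)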

Running the script above gave me the following errors:

root@42cc9eef3ad3:/w# OMP_NUM_THREADS=1 torchrun --nproc_per_node=4 a.py
Rank 0 created
DTensor(local_tensor=tensor([[1.2300, 1.2300],
        [1.2300, 1.2300]], device='cuda:0'), device_mesh=DeviceMesh('cuda', [[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')), placements=(Shard(dim=0), Shard(dim=1)))

         Col 0-1    Col 2-3
-------  ---------  ---------
Row 0-1  cuda:0     cuda:1
Row 2-3  cuda:2     cuda:3
Rank 1 created
DTensor(local_tensor=tensor([[1.2300, 1.2300],
        [1.2300, 1.2300]], device='cuda:0'), device_mesh=DeviceMesh('cuda', [[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')), placements=(Shard(dim=0), Shard(dim=1)))
Rank 2 created
DTensor(local_tensor=tensor([[1.2300, 1.2300],
        [1.2300, 1.2300]], device='cuda:0'), device_mesh=DeviceMesh('cuda', [[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')), placements=(Shard(dim=0), Shard(dim=1)))
Rank 3 created
DTensor(local_tensor=tensor([[1.2300, 1.2300],
        [1.2300, 1.2300]], device='cuda:0'), device_mesh=DeviceMesh('cuda', [[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')), placements=(Shard(dim=0), Shard(dim=1)))
W0422 00:38:25.921000 3596 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 3599 closing signal SIGTERM
W0422 00:38:25.922000 3596 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 3600 closing signal SIGTERM
E0422 00:38:26.086000 3596 torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: -11) local_rank: 0 (pid: 3598) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 892, in main
    run(args)
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 883, in run
    elastic_launch(
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 139, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 270, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
======================================================
a.py FAILED
------------------------------------------------------
Failures:
[1]:
  time      : 2025-04-22_00:38:25
  host      : 42cc9eef3ad3
  rank      : 3 (local_rank: 3)
  exitcode  : -11 (pid: 3601)
  error_file: <N/A>
  traceback : Signal 11 (SIGSEGV) received by PID 3601
------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-04-22_00:38:25
  host      : 42cc9eef3ad3
  rank      : 0 (local_rank: 0)
  exitcode  : -11 (pid: 3598)
  error_file: <N/A>
  traceback : Signal 11 (SIGSEGV) received by PID 3598
======================================================
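
For reference, a reduced variant that keeps the same 2D mesh and placements but drops visualize_sharding (a hypothetical sketch, not something included in the run above) could help tell whether the SIGSEGV comes from constructing the DTensor on the 2D mesh or from the visualization/teardown path:

# b.py -- hypothetical reduced repro, assuming the same 4-GPU pod;
# identical to a.py except that visualize_sharding() is removed.
# OMP_NUM_THREADS=1 torchrun --nproc_per_node=4 b.py
import os

import torch
import torch.distributed as dist
import torch.distributed.tensor as dt

local_rank = int(os.environ["LOCAL_RANK"])
local_device = torch.device("cuda", local_rank)
dist.init_process_group(backend="nccl", device_id=local_device)

mesh = dist.init_device_mesh("cuda", (2, 2), mesh_dim_names=["dp", "tp"])
dtensor = dt.full(
    (4, 4), 1.23,
    device_mesh=mesh,
    placements=[dt.Shard(dim=0), dt.Shard(dim=1)],
)
print(f"Rank {local_rank} local shard: {tuple(dtensor.to_local().shape)}")

dist.barrier()
dist.destroy_process_group()
print(f"Rank {local_rank} finished")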

Versions

PyTorch 2.8.0

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @tianyu-l @XilunWu @chauhang @penguinwu

Labels

module: dtensor, oncall: distributed, triaged
