
FSDP2+TP2 demo script does not work #621

@xxman-google

Description


Describe the bug

Running SFT with the config examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml fails with an error.

Steps/Code to reproduce bug

On the main branch, run:

uv run examples/run_sft.py --config examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml
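
For reference, the failure does not appear to be specific to NeMo-RL: the traceback below shows torch.distributed.broadcast being called on a DTensor. A minimal standalone sketch that should exercise the same dispatch path (untested; assumes 2 processes and the gloo backend for a CPU-only run):

# repro.py -- run with: torchrun --nproc_per_node=2 repro.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

dist.init_process_group("gloo")  # use "nccl" on GPU
mesh = init_device_mesh("cpu", (2,))  # 1-D mesh over both ranks

# Wrap a plain tensor into a replicated DTensor, then pass it to the
# raw c10d collective, as dtensor_policy_worker.py line 266 does.
buf = distribute_tensor(torch.ones(4), mesh, [Replicate()])
dist.broadcast(buf, src=0)  # AssertionError: found no DeviceMesh ...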

Expected behavior

SFT training should start and run without errors.

Actual behavior

Encountered the following error:

[Rank 1] Loading state dict from rank 0... [repeated 2x across cluster]
(DTensorPolicyWorker pid=343603) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::lm_policy-0-2:DTensorPolicyWorker.__init__() (pid=343603, ip=10.182.0.80, actor_id=87eda088ff87e9a825b334fc01000000, repr=DTensorPolicyWorker[rank=2]) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 12x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/nemo-rl/nemo_rl/models/policy/dtensor_policy_worker.py", line 266, in __init__ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     torch.distributed.broadcast(buf, src=0) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     return func(*args, **kwargs) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 2714, in broadcast [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     work = group.broadcast([tensor], opts) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     return disable_fn(*args, **kwargs) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     return fn(*args, **kwargs) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 344, in __torch_dispatch__ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     return DTensor._op_dispatcher.dispatch( [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 167, in dispatch [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     op_info = self.unwrap_to_op_info(op_call, args, kwargs) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 393, in unwrap_to_op_info [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     assert compute_mesh is not None, ( [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603) AssertionError: found no DeviceMesh from dtensor args for c10d.broadcast_.default! [repeated 6x across cluster]

Environment overview (please complete the following information)

Environment details

If an NVIDIA Docker image is used, you don't need to specify these.
Otherwise, please provide:

  • OS version
  • PyTorch version
  • Python version

Additional context
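
From the traceback, the broadcast at nemo_rl/models/policy/dtensor_policy_worker.py line 266 receives a DTensor once tensor parallelism is enabled (parameters and buffers live on a device mesh), and the raw c10d broadcast_ op cannot resolve a DeviceMesh from it. A hedged sketch of one possible guard; broadcast_buffer is a hypothetical helper, not NeMo-RL API, and it is only valid when the buffer is replicated (not sharded) across the mesh:

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor

def broadcast_buffer(buf: torch.Tensor, src: int = 0) -> None:
    # Hypothetical guard: c10d collectives cannot dispatch on a DTensor
    # wrapper, so broadcast the underlying local tensor in-place instead.
    if isinstance(buf, DTensor):
        # Only equivalent to broadcasting the full tensor when the buffer
        # is replicated on every participating rank; sharded layouts would
        # need a different fix.
        dist.broadcast(buf.to_local(), src=src)
    else:
        dist.broadcast(buf, src=src)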

Metadata

    Labels

    bug (Something isn't working)
