CP+SP (sequence parallel) failed in DTensor worker #659

@yuki-97

Description

Qwen2.5-32B with tp8 + cp2 + sp (sequence parallel) fails in the DTensor worker.
tp8 + sp (without cp) runs on 32 nodes but OOMs on 16 nodes.

Repro:

RUN_COMMAND="uv run python examples/run_grpo_math.py \
    --config examples/configs/grpo_math_8B.yaml \
    policy.model_name="Qwen/Qwen2.5-32B" \
    policy.generation.vllm_cfg.tensor_parallel_size=4 \
    policy.max_total_sequence_length=16384 \
    policy.dtensor_cfg.enabled=true \
    policy.dtensor_cfg.tensor_parallel_size=8 \
    policy.dtensor_cfg.context_parallel_size=2 \
    policy.dtensor_cfg.sequence_parallel=true \
    policy.dtensor_cfg.activation_checkpointing=true \
    policy.dynamic_batching.enabled=true \
    policy.dynamic_batching.train_mb_tokens=16384 \
    policy.dynamic_batching.logprob_mb_tokens=32768 \
    checkpointing.enabled=false \
    logger.wandb_enabled=true \
    logger.tensorboard_enabled=false \
    logger.monitor_gpus=true \
    logger.wandb.project=${PROJECT_NAME} \
    logger.wandb.name=${EXP_NAME} \
    cluster.num_nodes=16 \
    cluster.gpus_per_node=8"
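
For reference, a minimal sketch of the training mesh this config implies, assuming the DTensor worker lays ranks out over (dp, cp, tp); the dimension names and ordering here are illustrative, not necessarily what parallelize.py builds. Note that the sequence-parallel (sp) sharding reuses the tp dimension, which is why the error below talks about TP ranks.

# Run under torchrun with 128 processes (16 nodes x 8 GPUs); illustrative only.
import torch
from torch.distributed.device_mesh import init_device_mesh

world_size = 16 * 8              # cluster.num_nodes * cluster.gpus_per_node = 128
tp, cp = 8, 2                    # dtensor_cfg.tensor_parallel_size, context_parallel_size
dp = world_size // (tp * cp)     # 128 // 16 = 8 data-parallel replicas

# Hypothetical 3-D mesh for this run; sp shards activations over the "tp" dim.
mesh = init_device_mesh("cuda", (dp, cp, tp), mesh_dim_names=("dp", "cp", "tp"))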

Error:
https://wandb.ai/nvidia/grpo-dev-yukih-dtensor/runs/n96qd72r/logs

ray::DTensorPolicyWorker.train() (pid=2937548, ip=10.65.25.21, actor_id=75a6681698e862c3c537dc5501000000, repr=DTensorPolicyWorker[rank=117])
  File "/lustre/fs1/portfolios/coreai/users/yukih/NeMo-RL/code_snapshots_Qwen2.5-32B-tp8cp2-233cfca4-0710/nemo_rl/models/policy/dtensor_policy_worker.py", line 544, in train
    outputs = self.model(
              ^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
    return inner()
           ^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1805, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/transformers/utils/generic.py", line 969, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 703, in forward
    outputs: BaseModelOutputWithPast = self.model(
                                       ^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/transformers/utils/generic.py", line 969, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 426, in forward
    position_embeddings = self.rotary_emb(hidden_states, position_ids)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
    return inner()
           ^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1794, in inner
    args_result = hook(self, args)
                  ^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 937, in <lambda>
    lambda mod, inputs: input_fn(mod, inputs, device_mesh)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/fs1/portfolios/coreai/users/yukih/NeMo-RL/code_snapshots_Qwen2.5-32B-tp8cp2-233cfca4-0710/nemo_rl/models/dtensor/parallelize.py", line 65, in _prepare_input_fn
    raise ValueError(
ValueError: Failed to shard tensor for sequence parallelism. Local Shape is (torch.Size([1, 1003, 5120])) at rank 117. Different TP ranks must have the same shape. Original error: Inconsistent tensor metadata (including shape and stride) across ranks.
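
A minimal sketch of the failing condition, assuming the _prepare_input_fn hook in nemo_rl/models/dtensor/parallelize.py wraps each rank's local activations into a DTensor sharded on the sequence dimension over the TP mesh (the helper below is illustrative, not the actual implementation):

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Shard

def shard_for_sequence_parallel(hidden: torch.Tensor, tp_mesh) -> DTensor:
    # tp_mesh is assumed to be the 1-D TP sub-mesh. Sequence parallelism
    # shards activations on dim 1 (sequence) across the TP group, so every
    # TP rank must present identical (batch, seq, hidden) metadata. Gather
    # and compare shapes first, mirroring the check that raised the
    # ValueError above.
    shapes = [None] * tp_mesh.size()
    dist.all_gather_object(shapes, tuple(hidden.shape), group=tp_mesh.get_group())
    if len(set(shapes)) != 1:
        raise ValueError(
            f"Failed to shard tensor for sequence parallelism: "
            f"TP ranks have different local shapes {shapes}"
        )
    # Interpret each rank's tensor as one shard along the sequence dimension.
    return DTensor.from_local(hidden, tp_mesh, [Shard(1)])

The failing local shape (1, 1003, 5120) suggests that with cp2 plus dynamic batching enabled, ranks in the same TP group can receive unpadded, per-rank-variable sequence lengths, so the cross-rank metadata check fails as soon as two TP ranks disagree on the local sequence length.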
