Qwen2.5-32B with tp8+cp2+sp (sequence parallel) fails during training.
For comparison, tp8+sp (no context parallel) runs on 32 nodes, but OOMs on 16 nodes.
Repro:
RUN_COMMAND="uv run python examples/run_grpo_math.py \
--config examples/configs/grpo_math_8B.yaml \
policy.model_name="Qwen/Qwen2.5-32B" \
policy.generation.vllm_cfg.tensor_parallel_size=4 \
policy.max_total_sequence_length=16384 \
policy.dtensor_cfg.enabled=true \
policy.dtensor_cfg.tensor_parallel_size=8 \
policy.dtensor_cfg.context_parallel_size=2 \
policy.dtensor_cfg.sequence_parallel=true \
policy.dtensor_cfg.activation_checkpointing=true \
policy.dynamic_batching.enabled=true \
policy.dynamic_batching.train_mb_tokens=16384 \
policy.dynamic_batching.logprob_mb_tokens=32768 \
checkpointing.enabled=false \
logger.wandb_enabled=true \
logger.tensorboard_enabled=false \
logger.monitor_gpus=true \
logger.wandb.project=${PROJECT_NAME} \
logger.wandb.name=${EXP_NAME} \
cluster.num_nodes=16 \
cluster.gpus_per_node=8"
Error:
https://wandb.ai/nvidia/grpo-dev-yukih-dtensor/runs/n96qd72r/logs
ray::DTensorPolicyWorker.train() (pid=2937548, ip=10.65.25.21, actor_id=75a6681698e862c3c537dc5501000000, repr=DTensorPolicyWorker[rank=117])
File "/lustre/fs1/portfolios/coreai/users/yukih/NeMo-RL/code_snapshots_Qwen2.5-32B-tp8cp2-233cfca4-0710/nemo_rl/models/policy/dtensor_policy_worker.py", line 544, in train
outputs = self.model(
^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
return inner()
^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1805, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/transformers/utils/generic.py", line 969, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 703, in forward
outputs: BaseModelOutputWithPast = self.model(
^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/transformers/utils/generic.py", line 969, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 426, in forward
position_embeddings = self.rotary_emb(hidden_states, position_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
return inner()
^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1794, in inner
args_result = hook(self, args)
^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 937, in <lambda>
lambda mod, inputs: input_fn(mod, inputs, device_mesh)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lustre/fs1/portfolios/coreai/users/yukih/NeMo-RL/code_snapshots_Qwen2.5-32B-tp8cp2-233cfca4-0710/nemo_rl/models/dtensor/parallelize.py", line 65, in _prepare_input_fn
raise ValueError(
ValueError: Failed to shard tensor for sequence parallelism. Local Shape is (torch.Size([1, 1003, 5120])) at rank 117. Different TP ranks must have the same shape. Original error: Inconsistent tensor metadata (including shape and stride) across ranks.
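
For what it's worth, the reported local length of 1003 is odd and not divisible by tp=8, so sharding the sequence dim over 8 TP ranks would produce uneven chunks, consistent with the "different TP ranks must have the same shape" check. A quick sketch of the chunking arithmetic (hypothetical helper mirroring torch.chunk-style Shard(1) splits; not code from parallelize.py):

def shard_sizes(seq_len: int, world_size: int) -> list[int]:
    # torch-style chunking: leading ranks get ceil(seq_len / world_size),
    # the last rank gets whatever remains, so uneven lengths break SP.
    full = -(-seq_len // world_size)  # ceil division
    sizes, remaining = [], seq_len
    for _ in range(world_size):
        take = min(full, max(remaining, 0))
        sizes.append(take)
        remaining -= take
    return sizes

print(shard_sizes(1003, 8))  # [126, 126, 126, 126, 126, 126, 126, 121]

One hedged guess: with dynamic batching enabled, microbatch sequence lengths may not be padded to a multiple of cp * tp, which would explain why tp8+sp alone runs but tp8+cp2+sp trips this check.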