
qwen 32b on 4 nodes can train, but OOMs on checkpointing #263

@terrykong

Description

========================= Step 19/20 =========================
▶ Preparing batch...
▶ Taking a training step...

📊 Training Results:
  • Loss: 0.1625

⏱️  Timing:
  • Total step time: 10.78s
  • data_processing: 0.00s (0.0%)

========================= Step 20/20 =========================
▶ Preparing batch...
▶ Taking a training step...
▶ Starting validation at step 20...

📊 Validation Results:
    • Validation loss: 0.2010

  ⏱️  Validation Timing:
    • Total validation time: 26.28s
Saving checkpoint for step 20...
Traceback (most recent call last):
  File "/terryk/rewrite-aligner/reinforcer/code_snapshots/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt/examples/run_sft.py", line 211, in <module>
    main()
  File "/terryk/rewrite-aligner/reinforcer/code_snapshots/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt/examples/run_sft.py", line 196, in main
    sft_train(
  File "/terryk/rewrite-aligner/reinforcer/code_snapshots/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt/nemo_reinforcer/algorithms/sft.py", line 468, in sft_train
    policy.save_checkpoint(
  File "/terryk/rewrite-aligner/reinforcer/code_snapshots/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt/nemo_reinforcer/models/policy/hf_policy.py", line 321, in save_checkpoint
    ray.get(futures)
  File "/opt/reinforcer_venv/lib/python3.12/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/reinforcer_venv/lib/python3.12/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/reinforcer_venv/lib/python3.12/site-packages/ray/_private/worker.py", line 2771, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/reinforcer_venv/lib/python3.12/site-packages/ray/_private/worker.py", line 919, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(OutOfMemoryError): ray::DTensorPolicyWorker.save_checkpoint() (pid=3243433, ip=10.65.18.141, actor_id=fc6120dfe4f855a194b168e901000000, repr=DTensorPolicyWorker[rank=3])
  File "/terryk/rewrite-aligner/reinforcer/code_snapshots/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt/nemo_reinforcer/models/policy/dtensor_policy_worker.py", line 726, in save_checkpoint
    save_checkpoint(
  File "/tmp/ray/session_2025-04-24_13-41-52_097162_3613001/runtime_resources/working_dir_files/_ray_pkg_4b8ca8b50abb2e2f/nemo_reinforcer/utils/native_checkpoint.py", line 161, in save_checkpoint
    k: v.full_tensor()
       ^^^^^^^^^^^^^^^
  File "/terryk/rewrite-aligner/reinforcer/code_snapshots/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt/venvs/nemo_reinforcer.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 572, in full_tensor
    redist_res = self.redistribute(
                 ^^^^^^^^^^^^^^^^^^
  File "/terryk/rewrite-aligner/reinforcer/code_snapshots/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt/venvs/nemo_reinforcer.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 544, in redistribute
    return Redistribute.apply(self, device_mesh, placements, async_op)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/terryk/rewrite-aligner/reinforcer/code_snapshots/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt/venvs/nemo_reinforcer.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/terryk/rewrite-aligner/reinforcer/code_snapshots/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt/venvs/nemo_reinforcer.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py", line 306, in forward
    output = redistribute_local_tensor(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/terryk/rewrite-aligner/reinforcer/code_snapshots/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt/venvs/nemo_reinforcer.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py", line 213, in redistribute_local_tensor
    new_local_tensor = current_placement._to_replicate_tensor(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/terryk/rewrite-aligner/reinforcer/code_snapshots/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt/venvs/nemo_reinforcer.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/placement_types.py", line 524, in _to_replicate_tensor
    return torch.cat(new_tensor_shard_list, dim=self.dim).contiguous()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 540.00 MiB. GPU 0 has a total capacity of 79.11 GiB of which 536.88 MiB is free. Including non-PyTorch memory, this process has 78.58 GiB memory in use. Of the allocated memory 64.40 GiB is allocated by PyTorch, and 8.59 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
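
The failing call is the v.full_tensor() at native_checkpoint.py line 161: each full_tensor() all-gathers a DTensor's shards and concatenates them into a fully replicated copy on the GPU, which needs headroom that is no longer there after training (only ~537 MiB free, 8.59 GiB reserved but unallocated). A minimal sketch of one possible mitigation follows; this is not the repo's actual code, the helper name is hypothetical, and it assumes the public torch.distributed.tensor API visible in the traceback. The idea is to gather one tensor at a time and offload each full copy to CPU (releasing the device copy) before the next gather, instead of holding replicated tensors on GPU while the whole state dict is assembled:

from torch.distributed.tensor import DTensor

def consolidated_cpu_state_dict(sharded_state_dict):
    # Gather DTensors one at a time and keep the full copies on CPU, so at most
    # one fully replicated tensor occupies GPU memory at any moment.
    out = {}
    for k, v in sharded_state_dict.items():
        if isinstance(v, DTensor):
            full = v.full_tensor()  # all-gather across the device mesh
            out[k] = full.cpu()     # move the replicated copy off the GPU
            del full                # drop the device reference before the next gather
        else:
            out[k] = v.cpu() if v.is_cuda else v
    return out

A different direction would be a sharded save through torch.distributed.checkpoint (dcp.save should be able to write a DTensor state dict without ever materializing full tensors), but that would change the on-disk checkpoint layout.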

repro:

#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
cd "$SCRIPT_DIR"

HF_HOME=... \
HF_DATASETS_CACHE=... \
COMMAND="apt install -y jq && uv run recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.sh logger.wandb.project=nemo-rl-release logger.wandb.name=sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt-7ec0d2d" \
CONTAINER=... \
MOUNTS=... \
sbatch \
    --nodes=4 \
    --account=... \
    --job-name=... \
    --partition=batch \
    --time=0:30:0 \
    --gres=gpu:8 \
    --output=slurm-250424-121152-%j-sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt-1.1.out \
    ray.sub
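
As a quick experiment, the allocator's own suggestion can be tried by prepending PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to the uv run invocation inside COMMAND (assuming environment variables set there actually reach the Ray worker processes). With 8.59 GiB reserved but unallocated, fragmentation is at least part of the picture, though the full-tensor gathers still need real headroom, so this may reduce rather than eliminate the OOM.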

Labels: bug