
Variable._execution_engine.run_backward error during fine-tuning #25

@youliangtan

Issue

I get this error when running the fine-tuning script, specifically with quantization enabled (`--use_quantization true`).

 torchrun --standalone --nnodes 1 --nproc-per-node 2 vla-scripts/finetune.py --batch_size 4 --shuffle_buffer_size 1000 --lora_rank --use_quantization true ... # custom dataset etc....
  • Setup: single node with 2 V100 GPUs (DGX)

This fails with the following error:

Traceback (most recent call last):
  File "/home/youliang/openvla/vla-scripts/finetune.py", line 326, in <module>
    finetune()
  File "/home/youliang/anaconda3/envs/vla/lib/python3.10/site-packages/draccus/argparsing.py", line 203, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/home/youliang/openvla/vla-scripts/finetune.py", line 247, in finetune
    normalized_loss.backward()
  File "/home/youliang/anaconda3/envs/vla/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/youliang/anaconda3/envs/vla/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/youliang/anaconda3/envs/vla/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/youliang/anaconda3/envs/vla/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 319, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/youliang/anaconda3/envs/vla/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 875 with name base_model.model.language_model.model.layers.31.mlp.down_proj.lora_B.default.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.

Potential solution

To work around this, I added a call to `vla._set_static_graph()` right after wrapping the model in DDP:

    # Wrap VLA in PyTorch DDP Wrapper for Multi-GPU Training
    vla = DDP(vla, device_ids=[device_id], find_unused_parameters=True, gradient_as_bucket_view=True)
    vla._set_static_graph()   # <---- ADD THIS LINE
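For reference, PyTorch exposes the same workaround through the public `static_graph=True` argument of `DistributedDataParallel`; `_set_static_graph()` is a private method, and the DDP documentation lists reentrant backward passes and activation checkpointing among the cases static graph mode supports. The sketch below is a minimal single-process CPU illustration, not the OpenVLA script: `TinyCheckpointedModel` and `train_with_static_graph` are hypothetical names, and the `gloo` backend with a file-based store stands in for the real multi-GPU `torchrun` launch. It assumes the model's graph really is static (the same parameters are used on every step), which is the condition under which static graph mode is valid.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint


class TinyCheckpointedModel(nn.Module):
    """Hypothetical stand-in for the VLA: one block runs under reentrant checkpointing."""

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Linear(8, 8)  # runs normally, so its output requires grad
        self.block = nn.Linear(8, 8)  # recomputed during backward (reentrant path)
        self.head = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.embed(x)
        # use_reentrant=True is the legacy checkpointing path seen in the traceback
        h = checkpoint(self.block, h, use_reentrant=True)
        return self.head(h)


def train_with_static_graph(steps: int = 3) -> bool:
    """Run a few CPU steps under DDP(static_graph=True); True if every param got a grad."""
    # Single-process "cluster" so the sketch runs without torchrun or GPUs.
    store = os.path.join(tempfile.mkdtemp(), "ddp_store")
    dist.init_process_group("gloo", init_method=f"file://{store}", rank=0, world_size=1)
    try:
        torch.manual_seed(0)
        # static_graph=True is the public counterpart of ddp._set_static_graph().
        # find_unused_parameters stays at its default: static graph mode records
        # which parameters are used during the first iteration by itself.
        model = DDP(TinyCheckpointedModel(), gradient_as_bucket_view=True, static_graph=True)
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        for _ in range(steps):
            opt.zero_grad()
            loss = model(torch.randn(4, 8)).pow(2).mean()
            loss.backward()
            opt.step()
        return all(p.grad is not None for p in model.parameters())
    finally:
        dist.destroy_process_group()
```

In `finetune.py`, the analogous change would be adding `static_graph=True` to the existing `DDP(...)` call instead of calling the private method afterwards. With static graph enabled, PyTorch may warn that `find_unused_parameters=True` is no longer needed.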

I'm not sure whether this is the right way to resolve it.
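Another option, which targets the cause rather than asking DDP to tolerate it, is non-reentrant activation checkpointing. The traceback passes through `backward` in `torch/utils/checkpoint.py`, i.e. the reentrant checkpoint path, which runs a nested `torch.autograd.backward` call. A likely explanation is that, with `find_unused_parameters=True`, DDP does not find the checkpointed parameters in the forward graph, marks them ready as unused, and then sees their gradient hooks fire again inside the nested pass. With `use_reentrant=False` those parameters stay in the ordinary autograd graph, so the original DDP arguments can be kept. The sketch below shows this in plain PyTorch on a single CPU process; `TinyNonReentrantModel` and `train_with_non_reentrant_checkpointing` are hypothetical names. How `finetune.py` turns checkpointing on is an assumption here: if it comes from Hugging Face (for example peft's `prepare_model_for_kbit_training` on the quantized path, which might explain why the error appears only with quantization), recent `transformers` releases accept `gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})`, and recent `peft` releases accept the same `gradient_checkpointing_kwargs` argument in `prepare_model_for_kbit_training`.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint


class TinyNonReentrantModel(nn.Module):
    """Hypothetical stand-in for the VLA with one non-reentrant checkpointed block."""

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Linear(8, 8)
        self.block = nn.Linear(8, 8)
        self.head = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.embed(x)
        # use_reentrant=False keeps the block's parameters in the normal autograd
        # graph, so each DDP gradient hook fires once per backward pass.
        h = checkpoint(self.block, h, use_reentrant=False)
        return self.head(h)


def train_with_non_reentrant_checkpointing(steps: int = 3) -> bool:
    """Keep the issue's DDP arguments; True if every param received a grad."""
    # Single-process gloo group so the sketch runs on CPU without torchrun.
    store = os.path.join(tempfile.mkdtemp(), "ddp_store")
    dist.init_process_group("gloo", init_method=f"file://{store}", rank=0, world_size=1)
    try:
        torch.manual_seed(0)
        # Same flags as in the issue (minus device_ids, since this runs on CPU).
        model = DDP(
            TinyNonReentrantModel(),
            find_unused_parameters=True,
            gradient_as_bucket_view=True,
        )
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        for _ in range(steps):
            opt.zero_grad()
            loss = model(torch.randn(4, 8)).pow(2).mean()
            loss.backward()
            opt.step()
        return all(p.grad is not None for p in model.parameters())
    finally:
        dist.destroy_process_group()
```

Because every parameter is used here, DDP may warn that `find_unused_parameters=True` found nothing unused; that warning is harmless for this sketch.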
