
FSDP state dict OOM during model saving #98823

@wanchaol

🐛 Describe the bug

See related user reports in tatsu-lab/stanford_alpaca#81 and lm-sys/FastChat#256.

A workaround that the community is applying is:

Assuming you are using torch==1.13.0, in `python/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:2224`, change `state_dict[fqn] = state_dict[fqn].clone().detach()` to `state_dict[fqn] = state_dict[fqn].cpu().clone().detach()`

This is pretty manual monkey patching, and we should really fix this in PyTorch directly.
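As an aside, a workaround that does not require patching installed sources is to ask FSDP to offload the full state dict to CPU as it is gathered, via `FullStateDictConfig(offload_to_cpu=True)` together with the `FSDP.state_dict_type` context manager. A minimal sketch (the `model` and `save_path` names are placeholders; assumes a recent torch where `FullStateDictConfig` is available):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

def save_full_state_dict(model: FSDP, save_path: str) -> None:
    # Gather the full (unsharded) state dict, moving tensors to CPU as
    # they are assembled instead of materializing everything on GPU.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = model.state_dict()
    # With rank0_only=True, only rank 0 holds the full dict; other
    # ranks receive an empty mapping and skip the save.
    if dist.get_rank() == 0:
        torch.save(state_dict, save_path)
```

This avoids the GPU OOM because each gathered parameter is copied to host memory before the next one is materialized, rather than accumulating the whole full state dict on the device.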

@fegin @awgu @rohan-varma @zhaojuanmao

Versions

This has been happening since PyTorch 1.13, and I don't think it has been fixed so far.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu

Labels: module: fsdp, oncall: distributed, triaged
