Labels: module: fsdp, oncall: distributed, triaged
Description
🐛 Describe the bug
see related user reporting issues in tatsu-lab/stanford_alpaca#81 and lm-sys/FastChat#256
A workaround that the community is applying is:
Assuming you are using torch==1.13.0, edit `python/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:2224` and change `state_dict[fqn] = state_dict[fqn].clone().detach()` to `state_dict[fqn] = state_dict[fqn].cpu().clone().detach()`.
This is fairly manual monkey patching, and we should really fix this in PyTorch directly.
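For anyone hitting this before a proper fix lands, here is a minimal sketch (not from this issue, and assuming the model is already wrapped in FSDP with a process group initialized) that gathers the full state dict with the existing `offload_to_cpu` option instead of patching the installed file. The helper name and `path` argument are made up for illustration:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

def save_full_state_dict_on_rank0(model: FSDP, path: str) -> None:
    # Gather the full (unsharded) state dict, offloading tensors to CPU as
    # they are materialized so the gathered copies do not accumulate on GPU.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, path)
```

With `rank0_only=True`, only rank 0 materializes the full state dict, which further reduces peak memory on the other ranks; whether that is acceptable depends on how the checkpoint is consumed downstream.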
@fegin @awgu @rohan-varma @zhaojuanmao
Versions
This has been happening since PyTorch 1.13, and I don't think we have fixed it yet.
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu