Conversation

@awgu awgu commented Apr 22, 2024

Stack from ghstack (oldest at bottom):

This PR includes two things:

1. Changes to support `load_state_dict(assign=True)`
    - These changes are not ideal, but until we have `DTensor` padding the local tensor and general `swap_tensors` adoption, we may need to make do.
2. Example of how to convert a full state dict on rank 0 to sharded state dict on all ranks via broadcast (a rough sketch follows after this list)
    - ~~To-do: check for `recordStream` from the funcol broadcast; if being called, remediate either via `async_op=False` c10d broadcast or use `TORCH_NCCL_AVOID_RECORD_STREAMS=1`~~ switched to using c10d `async_op=False` broadcast
    - To-do: check for broadcast latency since not using any coalescing
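
The following is a minimal sketch of item (2), assuming rank 0 holds a regular full state dict and every rank has the FSDP2 model's meta-initialized sharded state dict and its device mesh. The helper name and dict variables (`full_to_sharded`, `full_sd`, `meta_sharded_sd`) are illustrative, not code from this PR; the `distribute_tensor(...)` chunking mirrors the snippet quoted later in this review.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import distribute_tensor


def full_to_sharded(full_sd, meta_sharded_sd, mesh):
    sharded_sd = {}
    for param_name, sharded_meta_param in meta_sharded_sd.items():
        # Materialize a buffer for the full parameter on every rank, then
        # broadcast rank 0's values into it (c10d broadcast, async_op=False).
        if dist.get_rank() == 0:
            full_param = full_sd[param_name].detach().cuda()
        else:
            full_param = torch.empty(
                sharded_meta_param.shape,
                dtype=sharded_meta_param.dtype,
                device="cuda",
            )
        dist.broadcast(full_param, src=0)
        # Chunk the full tensor into the DTensor layout the model expects.
        sharded_tensor = distribute_tensor(
            full_param, mesh, sharded_meta_param.placements
        )
        sharded_sd[param_name] = nn.Parameter(sharded_tensor)
    return sharded_sd
```

The resulting `sharded_sd` can then be loaded with `load_state_dict(assign=True)`, which is what item (1) enables.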

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

pytorch-bot bot commented Apr 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124651

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1f5d795 with merge base c82fcb7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ci-td-distributed, oncall: distributed, and topic: not user facing labels Apr 22, 2024
Comment on lines +164 to +170
# TODO: Remove this padding logic once DTensor pads the local tensor:
# https://github.com/pytorch/pytorch/issues/113045
self._post_load_hook_handle = (
module_info.module.register_load_state_dict_post_hook(
lambda *args, **kwargs: self.reset_sharded_param()
)
)
@awgu awgu (Collaborator Author) commented:

This is needed to support load_state_dict(assign=True) today.

In the future, we want (1) DTensor to pad its local tensor and (2) for users to use the swap_tensors path for load_state_dict() with FSDP2. Together these should allow us to remove this hook.
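
A hedged usage sketch of what this enables, assuming `model` is wrapped with `fully_shard` and `sharded_sd` contains DTensors on the same mesh (e.g. produced by the broadcast example in this PR); this is not the PR's test code.

```python
# Swap the loaded (DTensor) parameters into the module instead of copying
# into the existing ones.
model.load_state_dict(sharded_sd, assign=True)
# With assign=True, the incoming tensors replace the module's parameters; the
# post-load hook registered above then runs reset_sharded_param() so FSDP2
# re-establishes (and re-pads) its internal sharded parameter state.
```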

@awgu awgu marked this pull request as ready for review April 22, 2024 22:33
awgu pushed a commit that referenced this pull request Apr 23, 2024
@awgu awgu requested a review from weifengpy April 23, 2024 01:00
@awgu awgu added the release notes: distributed (fsdp2) label Apr 23, 2024
@awgu awgu requested review from fegin and wz337 April 23, 2024 14:34
@@ -148,9 +148,14 @@ def _test_state_dict_save_load(self, model: nn.Module):
param.to_local(),
torch.ones_like(param.to_local()) * new_fill_value,
)
self.assertEqual(
@awgu awgu (Collaborator Author) commented:

The assertion was too strict before, and our implementation was actually incorrect since we did not re-pad the local tensor.
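
For intuition, here is a toy illustration of the re-padding being referred to; the numbers are made up and this is not the FSDP2 implementation. With uneven dim-0 sharding, the last rank's local shard is padded up to the common chunk size so collective sizes stay uniform.

```python
import torch

full = torch.arange(10.0)      # global tensor with 10 elements, 3 ranks
chunk_size = 4                 # ranks hold 4 / 4 / 2 elements
local_rank2 = full[8:10]       # last rank's unpadded shard (2 elements)
padded = torch.zeros(chunk_size)
padded[: local_rank2.numel()].copy_(local_rank2)  # re-padded local shard
```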

@wanchaol wanchaol (Collaborator) left a comment:

lgtm!

sharded_tensor = distribute_tensor(
full_param, mesh, sharded_meta_param.placements
)
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
Contributor commented:

Curious if buffers are sometimes included in the state dict? For Llama 2, the buffer is `RotaryPositionalEmbeddings.theta`, but it's not included in the HF state dict.
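
For reference (a small side illustration, not part of this PR): whether a buffer shows up in `state_dict()` depends on whether it was registered as persistent.

```python
import torch
import torch.nn as nn


class Rotary(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("theta", torch.ones(4))                     # in state_dict()
        self.register_buffer("cache", torch.zeros(4), persistent=False)  # excluded


print(list(Rotary().state_dict().keys()))  # ['theta']
```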

pytorchmergebot pushed a commit that referenced this pull request Apr 24, 2024
This PR adds a unit test to show how we can convert FSDP2 GPU sharded state dicts to a CPU full state dict on rank 0.

Pull Request resolved: #124741
Approved by: https://github.com/wanchaol, https://github.com/wz337
ghstack dependencies: #124651
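
A rough sketch of the idea behind that test follows; the helper name and flow are illustrative, not the test's actual code. Each DTensor is all-gathered into a full tensor, and a CPU copy is kept only on rank 0.

```python
import torch.distributed as dist


def sharded_to_full_cpu_on_rank0(sharded_sd):
    full_sd = {}
    for name, tensor in sharded_sd.items():
        # full_tensor() is a collective, so every rank must call it.
        full = tensor.full_tensor()
        if dist.get_rank() == 0:
            full_sd[name] = full.cpu()
    return full_sd
```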
pytorchmergebot pushed a commit that referenced this pull request Apr 24, 2024
This PR makes sure to construct the `DeviceMesh`'s `mesh` tensor on CPU device in `init_device_mesh()`. This means that we can call `init_device_mesh()` under meta-device context and still construct the correct `mesh` tensor.

Pull Request resolved: #124767
Approved by: https://github.com/wz337
ghstack dependencies: #124651, #124741
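
A small sketch of the behavior described above, assuming an 8-GPU job launched with torchrun and a torch version containing this fix:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

# Constructing the mesh under a meta-device context is fine because the
# DeviceMesh's `mesh` tensor is created on CPU, not on the ambient device.
with torch.device("meta"):
    mesh = init_device_mesh("cuda", (8,))
assert mesh.mesh.device.type == "cpu"
```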
pytorchmergebot pushed a commit that referenced this pull request Apr 24, 2024
Pull Request resolved: #124768
Approved by: https://github.com/wz337
ghstack dependencies: #124651, #124741, #124767
awgu pushed a commit to awgu/pytorch that referenced this pull request Apr 24, 2024
pytorchmergebot pushed a commit that referenced this pull request Apr 24, 2024
This PR adds a `DeviceMesh.from_group()` static method to convert an existing process group to a device mesh.

Motivation: We need `DeviceMesh.from_group()` to allow FSDP2 to interoperate with distributed libraries that do not use `DeviceMesh` for all parallelisms.

Pull Request resolved: #124787
Approved by: https://github.com/wanchaol
ghstack dependencies: #124651, #124741, #124767, #124768, #124780
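
A hedged sketch of how `DeviceMesh.from_group()` might be used; the arguments are inferred from the description above, so check the API docs for the exact signature.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

# Assumes the default process group is already initialized (e.g. via torchrun).
pg = dist.new_group(ranks=list(range(dist.get_world_size())))
mesh = DeviceMesh.from_group(pg, device_type="cuda")
# `mesh` can now be passed to FSDP2's fully_shard(...) like any other DeviceMesh.
```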
pytorchmergebot pushed a commit that referenced this pull request Apr 29, 2024
This PR renames the `FSDP` class to `FSDPModule`. This is a BC breaking change. The rationale is that `FSDPModule` is more descriptive since `fully_shard` is a module-level API (applied to a `module` arg), so the `FSDP` class will always correspond to a module.

Also, users commonly import `FullyShardedDataParallel` as `FSDP`, so this can help avoid some name conflict in some cases.

Pull Request resolved: #124955
Approved by: https://github.com/wanchaol, https://github.com/wconstab
ghstack dependencies: #124651, #124741, #124767, #124768, #124780, #124787
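
To illustrate the rename's effect (the import path is assumed for the FSDP2 prototype as of this PR stack and may differ in later torch versions):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import FSDPModule, fully_shard

# Assumes a process group / device mesh is already set up (e.g. via torchrun).
model = nn.Linear(8, 8)
fully_shard(model)                    # module-level API: applies FSDP2 in place
assert isinstance(model, FSDPModule)  # the wrapped module is an FSDPModule now
```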
@github-actions github-actions bot deleted the gh/awgu/566/head branch June 2, 2024 02:05
Labels: ci-td-distributed, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (fsdp2), topic: not user facing