Conversation

@mori360 (Contributor) commented Oct 14, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Enable Optimizer-in-the-backward for full_finetune_distributed

Changelog

  • Update full_finetune_distributed for enabling Optimizer-in-the-backward
  • Update test_full_finetune_distributed with _optimizer_in_bwd config
  • Updated test_distributed to test running with/without optimizer_in_the_backward, and the behavior after saving and loading the state_dict

Test plan

  • Test running with optimizer_in_the_backward: tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=False max_steps_per_epoch=2 optimizer_in_bwd=True
  • Test running optimizer_in_the_backward with resume_from_checkpoint: tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=False max_steps_per_epoch=2 epochs=10 optimizer_in_bwd=True resume_from_checkpoint=True checkpointer.recipe_checkpoint=/tmp/Llama-2-7b-hf/recipe_state.pt checkpointer.checkpoint_files=[hf_model_0001_1.pt,hf_model_0002_1.pt]
  • Verify that running with optimizer-in-the-backward yields the same loss, model_state_dict, and optimizer_state_dict, and that the model still matches after saving and loading: pytest tests/torchtune/training/test_distributed.py -k test_optimizer_in_backward

Memory cost analysis:
With each layer's gradients costing 193 MB of memory, the original (left) case reaches its peak memory at the 31st layer, having accumulated roughly 30 × 193 MB of gradients.
The right case, with optimizer-in-the-backward, frees this memory during the backward pass and therefore reaches a lower peak.
[Figure: peak memory comparison, baseline (left) vs. optimizer-in-the-backward (right)]
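To make the mechanism concrete, below is a minimal, hedged sketch of the optimizer-in-the-backward pattern using PyTorch's per-parameter post-accumulate-grad hooks. It illustrates the general technique rather than the exact torchtune recipe code; the model, learning rate, and the optim_dict / optimizer_hook names are illustrative.

```python
import torch

# Toy model standing in for the fine-tuned LLM.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 1024)
)

# One single-parameter optimizer per parameter.
optim_dict = {p: torch.optim.AdamW([p], lr=2e-5) for p in model.parameters()}


def optimizer_hook(param: torch.nn.Parameter) -> None:
    # Step this parameter's optimizer and immediately free its gradient.
    optim_dict[param].step()
    optim_dict[param].zero_grad(set_to_none=True)


for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# loss.backward() now applies optimizer updates parameter by parameter, so no
# separate optimizer.step() / optimizer.zero_grad() calls are needed, and the
# full set of gradients never coexists in memory at once.
```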

Training time and loss analysis:
[Figure: training time and loss comparison]

@pytorch-bot (bot) commented Oct 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1833

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit ede3641 with merge base b02825a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Oct 14, 2024. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
self._optimizer.zero_grad(set_to_none=True)
if self._optimizer_in_bwd:
    raise NotImplementedError(
        "Gradient clipping is not supported after optimizer-in-the-backward."
    )
@mori360 (Contributor, Author) commented:

optimizer_in_backward frees the gradients during loss.backward(), so the correct grad_norm cannot be computed afterwards.
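As an illustration, continuing the hedged sketch from the PR description above (same model, optim_dict, and hooks), the gradients have already been consumed and freed by the time backward returns:

```python
x = torch.randn(8, 1024)
loss = model(x).sum()
loss.backward()

# Every hook has already stepped its optimizer and set .grad back to None, so
# a global gradient norm can no longer be computed at this point.
assert all(p.grad is None for p in model.parameters())
# torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) would return 0.0
# here rather than the true pre-clipping norm.
```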

@@ -681,7 +735,12 @@ def train(self) -> None:
     time_per_step = time.perf_counter() - t0
     log_dict = {
         "loss": loss_to_log,
-        "lr": self._optimizer.param_groups[0]["lr"],
+        "lr": get_lr(
@mori360 (Contributor, Author) commented:

Combined get_lr into a shared utility used by both the distributed and single_device recipes; it validates that all LRs are the same and returns that value if so.
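For context, here is a minimal sketch of what such a shared helper might look like; the exact signature, error type, and handling of the optimizer-in-backward wrapper in torchtune.training.lr_schedulers may differ, and this version only covers a plain optimizer:

```python
import torch


def get_lr(optimizer: torch.optim.Optimizer) -> float:
    """Return the single LR shared by all param groups, erroring if they differ."""
    lrs = {group["lr"] for group in optimizer.param_groups}
    if len(lrs) != 1:
        raise RuntimeError(
            f"Expected all param groups to share one LR, found {len(lrs)} distinct values."
        )
    return lrs.pop()
```

In the optimizer-in-backward case the same check would run across the per-parameter optimizers held by the wrapper instead of one optimizer's param groups.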

@@ -29,7 +29,10 @@


class TestFullFinetuneDistributedRecipe:
    def _get_test_config_overrides(self):
@mori360 (Contributor, Author) commented:

Both "optimizer_in_bwd=True" and "clip_grad_norm=100" could cause the wrong grad_norm, separate them here to avoid, loss_value would not be affected by either "optimizer_in_bwd=True" or "clip_grad_norm=100"

@@ -60,9 +63,17 @@ def _fetch_expected_loss_values(self, model_type):
("llama3/8B_full", "llama3", "tune", "NO_SHARD"),
],
)
@pytest.mark.parametrize("optim_in_bwd", [True, False])
@mori360 (Contributor, Author) commented:

Currently I add one more parameter, optim_in_bwd, to get a separate test case. Shall we structure the test another way? @ebsmothers

A Contributor replied:

I think this way is OK

@codecov-commenter commented Oct 14, 2024

Codecov Report

Attention: Patch coverage is 10.52632% with 51 lines in your changes missing coverage. Please review.

Project coverage is 25.68%. Comparing base (c70ad29) to head (56bfafa).
Report is 8 commits behind head on main.

Files with missing lines                          Patch %   Lines
recipes/full_finetune_distributed.py               0.00%    35 Missing ⚠️
torchtune/training/lr_schedulers.py               20.00%    12 Missing ⚠️
tests/recipes/test_full_finetune_distributed.py   40.00%     3 Missing ⚠️
recipes/full_finetune_single_device.py             0.00%     1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (c70ad29) and HEAD (56bfafa). Click for more details.

HEAD has 2 uploads less than BASE:
Flag    BASE (c70ad29)    HEAD (56bfafa)
              3                 1
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1833       +/-   ##
===========================================
- Coverage   67.30%   25.68%   -41.62%     
===========================================
  Files         304      305        +1     
  Lines       16000    16082       +82     
===========================================
- Hits        10768     4131     -6637     
- Misses       5232    11951     +6719     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines 59 to 65
def get_lr(optimizer_in_bwd, vanilla_optimizer) -> str:
    """
    Full_finetune_distributed and full_finetune_single_deivce assume all optimizers have
    the same LR, here to validate whether all the LR are the same and return if True.
    Bsed on optimizer_in_bwd, the second input here could be optimizer or optim_wrapper,
    name it as vanilla_optimizer to be more general.
    """
A Contributor (reviewer) commented:

Given this API is used in our recipes, we should

a) expose this as a public API here
b) add it to the API docs here
c) make sure the docstring's format matches those of our other public APIs (for example).

Also do you have pre-commit hooks installed? I think pydoclint should be complaining about this since you have raises that aren't documented in the docstring.

"""
Full_finetune_distributed and full_finetune_single_deivce assume all optimizers have
the same LR, here to validate whether all the LR are the same and return if True.
Bsed on optimizer_in_bwd, the second input here could be optimizer or optim_wrapper,

Suggested change:
- Bsed on optimizer_in_bwd, the second input here could be optimizer or optim_wrapper,
+ Based on optimizer_in_bwd, the second input here could be optimizer or optim_wrapper,

@mori360 requested a review from ebsmothers October 15, 2024 20:25
@mori360 marked this pull request as ready for review October 17, 2024 21:35
@ebsmothers (Contributor) left a comment:

One more small comment on the versioning question. After that this should be good to go

@mori360 marked this pull request as draft October 22, 2024 00:11
@mori360 marked this pull request as ready for review October 23, 2024 21:19
@mori360 merged commit dc0591c into pytorch:main Oct 23, 2024
17 checks passed
@gameofdimension commented:

Great work. I am wondering whether it can be used with clip_grad_norm_.

@awgu commented Oct 24, 2024

Optimizer in backward and global gradient norm clipping do not algorithmically make sense together 🤔

@gameofdimension replied:

So if clip_grad_norm_ is required, then we cannot use "Optimizer in backward"?

@awgu commented Oct 24, 2024

I think you would need to do something different mathematically, e.g. use previous iteration's total norm or clip each gradient separately.
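For illustration, here is a hedged sketch of the second option (clipping each gradient separately), built on the hook-based sketch from the PR description above and reusing its optim_dict. Note this is a different algorithm from global clip_grad_norm_, not a drop-in replacement, and the max-norm value is purely illustrative:

```python
import torch

MAX_PER_PARAM_NORM = 1.0  # illustrative value, not a recipe default


def clipping_optimizer_hook(param: torch.nn.Parameter) -> None:
    # Clip only this parameter's gradient, then step and free it.
    torch.nn.utils.clip_grad_norm_([param], MAX_PER_PARAM_NORM)
    optim_dict[param].step()
    optim_dict[param].zero_grad(set_to_none=True)
```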

Labels: CLA Signed
7 participants