Contributor

@janeyx99 janeyx99 commented Nov 20, 2023

Introduce OptimizerInfos + use them to refactor out the error testing.

Why OptimizerInfos?

  • cleaner, easier way to test all configs of optimizers
  • would plug in well with devicetype to auto-enable tests for devices like MPS, meta
  • would allow for more granular testing. Currently, a lot of functionality is tested in _test_basic_cases, and some of it should be broken down further.
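
To make the motivation concrete, here is a rough sketch of what an info-style registry could look like. It uses toy stand-ins rather than the real torch.optim classes: `ToySGD`, `ToyAdam`, `ToyOptimizerInfo`, `toy_optim_db`, and `all_configs` are hypothetical names for illustration only, and the actual OptimizerInfo in this PR may be shaped differently.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


class ToySGD:
    """Hypothetical stand-in for an optimizer; it only records its kwargs."""

    def __init__(self, params, lr=1e-3, momentum=0.0, foreach=None):
        self.params = list(params)
        self.defaults = {"lr": lr, "momentum": momentum, "foreach": foreach}


class ToyAdam:
    """Second hypothetical stand-in so the registry has more than one entry."""

    def __init__(self, params, lr=1e-3, amsgrad=False, foreach=None):
        self.params = list(params)
        self.defaults = {"lr": lr, "amsgrad": amsgrad, "foreach": foreach}


@dataclass
class ToyOptimizerInfo:
    optim_cls: type
    # Returns the constructor kwargs worth testing for this optimizer.
    optim_inputs_func: Callable[[], List[Dict[str, Any]]]
    # Assumed default, mirroring the names quoted later in this review.
    supported_impls: Tuple[str, ...] = ("foreach", "differentiable")


toy_optim_db = [
    ToyOptimizerInfo(ToySGD, lambda: [{}, {"lr": 0.1}, {"momentum": 0.9}]),
    ToyOptimizerInfo(ToyAdam, lambda: [{}, {"amsgrad": True}]),
]


def all_configs(db):
    """Yield (optimizer class, kwargs) for every registered config."""
    for info in db:
        for kwargs in info.optim_inputs_func():
            yield info.optim_cls, kwargs


# One generic loop covers every optimizer and every config in the registry.
constructed = [cls([0.0], **kwargs) for cls, kwargs in all_configs(toy_optim_db)]
```

The point of the design is that a test body is written once and the registry decides which optimizers and configs it runs against.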

What did I do for error testing?

  • I moved out some error cases from _test_basic_cases into a new test_errors parametrized test.
  • The new test has to live in TestOptimRenewed (bikeshedding welcome) because the parametrized tests need to take in device and dtype and hook correctly, and not all tests in TestOptim do that.
  • TestOptimRenewed is also moving to the top-level test/test_optim.py, because importing TestOptimRenewed does not work: test instantiation replaces TestOptimRenewed with TestOptimRenewedDevice for CPU, CUDA, and any other device.
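
A minimal sketch of how a data-driven error test might work, assuming each error case records the inputs, the stage at which the error is expected, the exception type, and a regex for the message. Everything here (`ErrorOn`, `ToyRMSprop`, `ToyErrorInput`, `check_error_input`) is a hypothetical stand-in written without torch; the step-time stage in particular is an assumption, since only the construction-time default appears in this PR's diff.

```python
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict


class ErrorOn(Enum):
    CONSTRUCTION = 0
    STEP = 1  # assumed second stage, for illustration only


class ToyRMSprop:
    """Hypothetical optimizer stand-in that validates its arguments."""

    def __init__(self, params, lr=1e-2, momentum=0.0):
        if isinstance(params, float):
            raise TypeError("params should be an iterable, got a single value")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        self.params = list(params)

    def step(self):
        pass


@dataclass
class ToyErrorInput:
    params: Any
    kwargs: Dict[str, Any] = field(default_factory=dict)
    error_on: ErrorOn = ErrorOn.CONSTRUCTION
    error_type: type = RuntimeError
    error_regex: str = ""


def toy_rmsprop_error_inputs():
    return [
        ToyErrorInput([0.0], {"momentum": -1.0}, error_type=ValueError,
                      error_regex="Invalid momentum value: -1.0"),
        # A single value instead of a container of them.
        ToyErrorInput(0.0, error_type=TypeError, error_regex="iterable"),
    ]


def check_error_input(optim_cls, case):
    """Return True if the expected error is raised at the expected stage."""
    try:
        optim = optim_cls(case.params, **case.kwargs)
    except case.error_type as exc:
        return (case.error_on is ErrorOn.CONSTRUCTION
                and re.search(case.error_regex, str(exc)) is not None)
    if case.error_on is not ErrorOn.STEP:
        return False  # construction was supposed to fail but did not
    try:
        optim.step()
    except case.error_type as exc:
        return re.search(case.error_regex, str(exc)) is not None
    return False


results = [check_error_input(ToyRMSprop, c) for c in toy_rmsprop_error_inputs()]
```

Keeping the cases as data means a new error case is one more list entry rather than one more `assertRaisesRegex` block inside a large test method.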

Is there any change in test coverage?

  • INCREASE: The error case where a single Parameter (rather than a container of them) is passed in is now tested for all optimizers instead of only LBFGS.
  • DECREASE: Not much. Two error cases are no longer tested for both foreach=True and foreach=False, which I think was redundant. (Highlighted in comments.)

Possible but not urgent next step: test ALL possible error cases by going through all the constructors.

Stack from ghstack (oldest at bottom):

[ghstack-poisoned]
pytorch-bot bot commented Nov 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114178

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e6a3a73 with merge base 7f49603:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

janeyx99 added a commit that referenced this pull request Nov 20, 2023
ghstack-source-id: b5742f3
Pull Request resolved: #114178
janeyx99 added a commit that referenced this pull request Nov 23, 2023
ghstack-source-id: 0ae75da
Pull Request resolved: #114178
@janeyx99 janeyx99 changed the title [init] Add OptimizerInfos Introduce OptimizerInfos + add a test_errors Nov 23, 2023
@@ -1715,8 +1666,7 @@ def test_rmsprop(self):
    self._test_complex_optimizer(
        lambda param: RMSprop([param], maximize=True, foreach=foreach)
    )
    with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
        RMSprop(None, lr=1e-2, momentum=-1.0, foreach=foreach)

Contributor Author

This is one place where coverage "decreased", since I no longer test this error case for both foreach=True and foreach=False. I believe testing both was unnecessary anyway.

Contributor

@jbschlosser jbschlosser Nov 27, 2023

maybe this is unnecessary, but it's easy enough to do this:

class TestOptimRenewed:
    ...
    @onlyCPU
    @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
    @parametrize("foreach", [False, True])
    def test_errors(self, device, dtype, optim_info, foreach):
        ...

if there are error cases that are only expected to fail for e.g. foreach=True, then I'd consider allowing for this flexibility in ErrorOptimizerInput
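
The flexibility suggested here could be sketched roughly as follows: cross every error case with every `foreach` value, but let a case opt out of one setting. This is a torch-free illustration; `ToyErrorCase`, its `only_for_foreach` field, and `expand` are hypothetical names and not part of ErrorOptimizerInput as written in this PR.

```python
from dataclasses import dataclass
from itertools import product
from typing import Any, Dict, Optional


@dataclass
class ToyErrorCase:
    desc: str
    kwargs: Dict[str, Any]
    # None means the error is expected for both foreach settings (hypothetical field).
    only_for_foreach: Optional[bool] = None


cases = [
    ToyErrorCase("negative momentum", {"momentum": -1.0}),
    ToyErrorCase("foreach-only failure", {"lr": 0.1}, only_for_foreach=True),
]


def expand(cases, foreach_values=(False, True)):
    """Cross every error case with every foreach value, honoring restrictions."""
    expanded = []
    for case, foreach in product(cases, foreach_values):
        if case.only_for_foreach is not None and case.only_for_foreach != foreach:
            continue  # this case is not expected to fail for this setting
        expanded.append((case.desc, {**case.kwargs, "foreach": foreach}))
    return expanded


runs = expand(cases)
```

With this shape, the unrestricted case runs twice and the restricted one runs once, so the extra parametrization costs nothing for cases that do not depend on `foreach`.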

Contributor Author

Yea, thanks for the code pointer. I don't believe it is necessary though.

@@ -1805,8 +1754,7 @@ def test_rprop(self):
            [param], lr=0.001, maximize=True, foreach=foreach
        )
    )
    with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"):
        Rprop(None, lr=1e-2, etas=(1.0, 0.5), foreach=foreach)

Contributor Author

And this is the other place where coverage "decreased", since I no longer test this error case for both foreach=True and foreach=False. I believe testing both was unnecessary anyway.

janeyx99 added a commit that referenced this pull request Nov 28, 2023
ghstack-source-id: 2186040
Pull Request resolved: #114178
Collaborator

@albanD albanD left a comment

Sounds good!

    self,
    optimizer_error_input,
    *,
    error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,

Collaborator

Do you expect these defaults to actually be used? If not, we can just remove the default values here.

Contributor Author

Yes, most of them error on construction.

Collaborator

Is that also true for the error type and regex below?

    optim_inputs_func,
    # Implementation specific kwargs the optimizer supports, e.g., fused, foreach, differentiable
    # We consider capturable to be a base constructor flag since it is implemented across the board.
    supported_impls: Tuple[str] = ("foreach", "differentiable"),

Collaborator

It would be great for the comment above to mention every valid value for this flag.

Contributor Author

wait it does do that

Collaborator

"kwargs the optimizer supports, e.g., fused, foreach, differentiable"
Is not the same as
"kwargs the optimizer supports, it must be a subset of fused, foreach, differentiable"
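
The distinction matters because "e.g." leaves the valid set open, while "must be a subset of" can be checked. A small sketch of such a check, assuming the full set is exactly the three names quoted in the comment; `VALID_IMPLS` and `validate_supported_impls` are hypothetical and not part of the PR:

```python
from typing import Iterable, Tuple

# Assumed full set of implementation flags, taken from the names in the comment
# under review; the real list of valid values may differ.
VALID_IMPLS = frozenset({"fused", "foreach", "differentiable"})


def validate_supported_impls(impls: Iterable[str]) -> Tuple[str, ...]:
    """Return impls as a tuple, or raise if any entry is not a known flag."""
    impls = tuple(impls)
    unknown = sorted(set(impls) - VALID_IMPLS)
    if unknown:
        raise ValueError(
            f"Unknown impls {unknown}; expected a subset of {sorted(VALID_IMPLS)}"
        )
    return impls
```

Under this assumption, passing `("foreach", "capturable")` would be rejected, which matches the neighboring comment that treats capturable as a base constructor flag rather than an implementation-specific one.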

Contributor Author

janeyx99 commented Dec 5, 2023

@pytorchbot merge

let's see if the 3 days ago is recent enough

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 5, 2023
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@janeyx99 janeyx99 added the topic: not user facing topic category label Dec 5, 2023
Contributor Author

janeyx99 commented Dec 5, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

@facebook-github-bot facebook-github-bot deleted the gh/janeyx99/108/head branch December 9, 2023 15:27
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
Pull Request resolved: pytorch#114178
Approved by: https://github.com/albanD
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged topic: not user facing topic category
4 participants