

@RunDevelopment commented Feb 20, 2024

This PR adds an internal API for hyperparameters. The basic idea is that model classes now carry a `@store_hyperparameters` decorator, which changes the class so that it automatically stores all hyperparameters it is given in a `hyperparameters` field. Example:

```python
@store_hyperparameters()
class CodeFormer(VQAutoEncoder):
    hyperparameters = {}
```

Unfortunately, classes still have to declare the `hyperparameters` class variable so that pyright picks up on the field.
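To illustrate the mechanism, here is a minimal sketch of how such a decorator could work. It is not the implementation from this PR: the decorator body, the `Model` class, and its `in_nc`/`num_feat` parameters are hypothetical. It assumes "given hyperparameters" means the arguments passed to `__init__`, with defaults filled in via `inspect.signature`.

```python
import functools
import inspect
from typing import Any, TypeVar

T = TypeVar("T")


def store_hyperparameters():
    """Hypothetical sketch: wrap a class's __init__ so that every argument
    passed to it (including defaulted ones) is recorded in self.hyperparameters."""

    def decorator(cls: type[T]) -> type[T]:
        original_init = cls.__init__
        signature = inspect.signature(original_init)

        @functools.wraps(original_init)
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            # Bind the call against the original signature and fill in defaults,
            # so positional, keyword, and omitted arguments all get recorded.
            bound = signature.bind(self, *args, **kwargs)
            bound.apply_defaults()
            params = dict(bound.arguments)
            params.pop("self")
            original_init(self, *args, **kwargs)
            # Set after the original __init__ so the class's own setup runs first.
            self.hyperparameters = params

        setattr(cls, "__init__", __init__)
        return cls

    return decorator


@store_hyperparameters()
class Model:
    # Declared on the class so a type checker such as pyright knows the field exists.
    hyperparameters: dict[str, Any] = {}

    def __init__(self, in_nc: int = 3, num_feat: int = 64) -> None:
        self.in_nc = in_nc
        self.num_feat = num_feat
```

Because the wrapper binds each call against the original signature, `Model(num_feat=32).hyperparameters` is `{"in_nc": 3, "num_feat": 32}`: the explicitly passed value and the default end up in the same dict.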

While all models now have a `hyperparameters` field, this field is not part of the public API (for now). In this PR, I only add the implementation and use it in tests.

Speaking of tests: `assert_loads_correctly` now uses the new hyperparameters to check whether a model was loaded correctly. This is much stricter than the old opt-in system via `condition=...`, and it has already found a few minor discrepancies in detected hyperparameters.
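A minimal sketch of the kind of comparison such a check could perform, assuming it compares the hyperparameters a reference model was constructed with against those of the model the loader produced. The helpers `hyperparameter_diff` and `assert_same_hyperparameters` are hypothetical and are not the actual body of `assert_loads_correctly`. Unlike an opt-in `condition`, every stored hyperparameter is compared, which is what makes this style of check stricter.

```python
from typing import Any

# Sentinel distinguishing "key absent" from a stored value of None.
_MISSING = object()


def _fmt(value: Any) -> str:
    return "<missing>" if value is _MISSING else repr(value)


def hyperparameter_diff(expected: dict[str, Any], detected: dict[str, Any]) -> list[str]:
    """Hypothetical helper: list every hyperparameter whose detected value
    differs from the expected one, including keys present on only one side."""
    lines = []
    for key in sorted(expected.keys() | detected.keys()):
        a = expected.get(key, _MISSING)
        b = detected.get(key, _MISSING)
        if a != b:
            lines.append(f"{key}: expected {_fmt(a)}, detected {_fmt(b)}")
    return lines


def assert_same_hyperparameters(expected: dict[str, Any], detected: dict[str, Any]) -> None:
    """Hypothetical helper: fail with a readable message on any mismatch."""
    diff = hyperparameter_diff(expected, detected)
    assert not diff, "Hyperparameter mismatch:\n" + "\n".join(diff)
```

For example, `hyperparameter_diff({"in_nc": 3, "num_feat": 64}, {"in_nc": 3, "num_feat": 32})` returns `["num_feat: expected 64, detected 32"]`, so a single misdetected value is reported by name rather than passing silently.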

A future PR will deal with defining the public API for hyperparameters.

@RunDevelopment changed the title from "Add hyperparameter API" to "Add internal hyperparameter API" on Feb 21, 2024
@RunDevelopment marked this pull request as ready for review on February 21, 2024 14:02
@joeyballentine merged commit b3e3424 into main on Feb 21, 2024
@joeyballentine deleted the hyperparameters branch on February 21, 2024 14:20