A key need for users is to get a quick sense of how a model is performing.
To do this, they need to load a model and inspect its outputs.
Due to the current design of the API, it's both complicated and unclear how to make an instance of a model and then load a checkpoint into it.
Intuitively, you'd think you could just say, e.g.:

```python
model = TweetyNetModel()
model.load(checkpoint_path)
```
However, this fails with:

```
TypeError: __init__() missing 4 required positional arguments: 'network', 'loss', 'optimizer', and 'metrics'
```
The "correct" way is to use the from_config
class method, but even knowing this, it's still annoying to use.
That's because this method does not just accept, e.g., a path to a config file, but instead requires a dictionary config
that maps a set of keys, {'network', 'optimizer', 'loss'}
to dictionaries that are themselves key-value pairs mapping hyperparameter names to their values.
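For concreteness, here is a minimal sketch of what such a dictionary might look like; the hyperparameter names and values below are hypothetical, not vak's actual option names:

```python
# Illustrative only: these hyperparameter names/values are made up,
# not vak's actual option names.
model_config = {
    "network": {"hidden_size": 256},
    "optimizer": {"lr": 0.001},
    "loss": {},
}

model = TweetyNetModel.from_config(model_config)
```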
How do I get such a dictionary, you might ask?
Even more counterintuitively, you cannot just pass the whole config.toml file to `vak.parse.config` and get back a model config dict.
Although the functions for parsing model-specific tables in a .toml config file live within the `config` module, they are not called by `vak.parse.config`.
Instead you have to call those functions yourself, passing in both the path to the config.toml file and a list of model names.
Accordingly, all the `cli` functions (`eval`, `predict`, `train`, etc.) do something like this to get the model config:

```python
model_config_map = vak.config.models.map_from_path(config_gr41rd51, config.learncurve.models)
models_map = vak.models.from_model_config_map(
    model_config_map,
    num_classes=len(labelmap),
    input_shape=val_dataset.shape[1:],
)
```
This is both verbose and non-obvious.
A relatively quick fix would be to add another class method, `from_toml_path`, that does all of this work for the user.
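A rough sketch of what that could look like, reusing the helpers shown above (the exact signature, e.g. whether `num_classes` and `input_shape` get passed here, is up for discussion):

```python
@classmethod
def from_toml_path(cls, toml_path, model_name, num_classes, input_shape):
    """Sketch only: make a model instance directly from a .toml config file."""
    # call the same parsing helpers the cli functions use, as shown above
    model_config_map = vak.config.models.map_from_path(toml_path, [model_name])
    models_map = vak.models.from_model_config_map(
        model_config_map,
        num_classes=num_classes,
        input_shape=input_shape,
    )
    return models_map[model_name]
```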
It's still annoying that you can't just say `model = ModelClass()` though.
Fixing that would require a breaking change, to something like a pytorch-lightning model, where the person designing the model specifies default hyperparameters, e.g. through methods like `get_default_loss`, but those defaults can optionally be changed by passing in configurations as kwargs dicts.
For my taste it would possibly be cleaner if defaults were just specified as class-level variables, instead of requiring users to write methods that obscure the default hyperparameters. I.e., with pytorch-lightning you can't easily "introspect" the model to find out that the default learning rate is 0.001, because it's a magic number assigned within a method.
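To make that contrast concrete, here is a minimal sketch of the class-level-defaults idea; the base class, attribute names, and default values are all hypothetical:

```python
# Hypothetical sketch: `Model` is a stand-in base class, not vak's actual
# Model, and the attribute names/values are made up for illustration.
class Model:
    # defaults are plain class attributes, so they can be introspected
    # (e.g. TweetyNetModel.default_lr) instead of hiding inside a method
    default_lr = 0.001
    default_loss = "cross_entropy"

    def __init__(self, lr=None, loss=None):
        # arguments override the class-level defaults when provided
        self.lr = lr if lr is not None else self.default_lr
        self.loss = loss if loss is not None else self.default_loss


class TweetyNetModel(Model):
    # a subclass overrides a default just by reassigning the attribute
    default_lr = 0.003


print(TweetyNetModel.default_lr)  # 0.003, visible without instantiating
model = TweetyNetModel()          # works with no required arguments
```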