
make it easier to instantiate a model #362

@NickleDave

A key need for users is to get a quick sense of how a model is performing.
To do this, they need to load a model and inspect its outputs.

Due to the current design of the API, it's both complicated and unclear how to make an instance of a model, and then load a checkpoint into that model.

Intuitively, you'd think you would be able to just say, e.g.

```python
model = TweetyNetModel()
model.load(checkpoint_path)
```

However this fails with:

```
TypeError: __init__() missing 4 required positional arguments: 'network', 'loss', 'optimizer', and 'metrics'
```

The "correct" way is to use the from_config class method, but even knowing this, it's still annoying to use.
That's because this method does not just accept, e.g., a path to a config file; instead it requires a config dictionary that maps a set of keys, {'network', 'optimizer', 'loss'}, to dictionaries that are themselves key-value pairs mapping hyperparameter names to their values.
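
Concretely, the dict it wants is shaped something like the following; the top-level keys are the ones named above, but the hyperparameter names and values here are invented just to show the structure:

```python
# Hypothetical example of the dict that from_config expects;
# the hyperparameters shown are made up for illustration.
model_config = {
    'network': {'hidden_size': 256, 'num_layers': 2},
    'optimizer': {'lr': 0.001},
    'loss': {},
}
model = TweetyNetModel.from_config(model_config)
```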

How do I get such a dictionary, you might ask?

Even more counterintuitively, you cannot just pass the whole config.toml file to vak.parse.config and get back a model config dict. Although the functions for parsing model-specific tables in a .toml config file live within the config module, they are not called directly by vak.parse.config. Instead you have to call those functions yourself, passing in both the path to the config.toml file and a list of model names.
Accordingly, all the CLI functions (eval, predict, train, etc.) do something like this to get the model config:

```python
model_config_map = vak.config.models.map_from_path(toml_path, config.learncurve.models)
models_map = vak.models.from_model_config_map(
    model_config_map,
    num_classes=len(labelmap),
    input_shape=val_dataset.shape[1:],
)
```

This is both verbose and non-obvious.

A relatively quick fix would be to add another class method, from_toml_path, that does all of this work for a user.
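
For illustration, here's roughly what that classmethod would need to do, written as a free function. The name model_from_toml_path and its arguments are made up, and the sketch assumes map_from_path and from_model_config_map keep the signatures shown above:

```python
import vak

def model_from_toml_path(toml_path, model_name, num_classes, input_shape):
    """Sketch of the convenience a from_toml_path classmethod would provide.

    Assumes from_model_config_map returns a dict mapping model names to
    instantiated models, as the variable name models_map above suggests.
    """
    model_config_map = vak.config.models.map_from_path(toml_path, [model_name])
    models_map = vak.models.from_model_config_map(
        model_config_map,
        num_classes=num_classes,
        input_shape=input_shape,
    )
    return models_map[model_name]
```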

It's still annoying that you can't just write `model = ModelClass()`, though.

Fixing that would require a breaking change, moving to something like a pytorch-lightning style model, where the person designing the model specifies default hyperparameters, e.g. through methods like get_default_loss, and those defaults can optionally be overridden by passing in configurations as kwargs dicts. For my taste it might be cleaner if defaults were just specified as class-level variables, instead of requiring model authors to write methods that end up obscuring the default hyperparameters. In other words, with pytorch-lightning you can't easily "introspect" the model to find out that the default learning rate is 0.001, because it's a magic number assigned inside a method.
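
To make the class-level-defaults idea concrete, here is a minimal sketch; SketchModel and everything in it is invented for illustration and is not vak's actual Model API:

```python
import torch

class SketchModel:
    # Defaults live on the class, so they can be read without instantiating
    # the model or digging through a get_default_* method.
    default_optimizer_config = {'lr': 0.001}
    default_loss_config = {}

    def __init__(self, network=None, optimizer_config=None, loss_config=None):
        # Start from the class-level defaults, then apply any overrides.
        optimizer_config = {**self.default_optimizer_config, **(optimizer_config or {})}
        loss_config = {**self.default_loss_config, **(loss_config or {})}
        self.network = network if network is not None else torch.nn.Linear(10, 2)
        self.loss = torch.nn.CrossEntropyLoss(**loss_config)
        self.optimizer = torch.optim.Adam(self.network.parameters(), **optimizer_config)

print(SketchModel.default_optimizer_config['lr'])   # 0.001, visible on the class itself
model = SketchModel(optimizer_config={'lr': 3e-4})  # defaults can still be overridden
```

With defaults as class attributes, `model = SketchModel()` works out of the box, and the default hyperparameters are discoverable with nothing more than attribute access.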
