Multiple models & optimizers & phases support. Vanilla GAN example on MNIST. #365
self.lr = None
self.momentum = None
single_optimizer = isinstance(optimizer, Optimizer)
self.lr = None if single_optimizer else defaultdict(lambda: None)
Don't use defaultdict; try https://github.com/catalyst-team/safitty from the catalyst ecosystem.
So far I believe defaultdicts are much clearer here, with much less code.
I agree that plain dicts may prevent some errors (but again, without safitty).
Maybe you can give me a masterclass?
self.lr = None
self.momentum = None
if isinstance(optimizer, dict):
    self.lr = dict()
    self.momentum = dict()
    for key in optimizer:
        safitty.set(self.lr, key, value=None)
        safitty.set(self.momentum, key, value=None)
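For comparison, here is a stdlib-only sketch of the plain-dict alternative, without safitty. The function name `init_lr_momentum` and the optimizer names are hypothetical, and `object()` stands in for real optimizers; it only illustrates how the two approaches differ on unknown keys, not catalyst's actual code.

```python
from collections import defaultdict
from typing import Any, Dict, Optional, Tuple, Union


def init_lr_momentum(
    optimizer: Union[Any, Dict[str, Any]],
) -> Tuple[Optional[Dict[str, None]], Optional[Dict[str, None]]]:
    """Hypothetical helper returning (lr, momentum) placeholders.

    A single optimizer gets plain None values; a dict of optimizers gets one
    explicit None entry per optimizer name, so a misspelled key raises KeyError.
    """
    if isinstance(optimizer, dict):
        return dict.fromkeys(optimizer, None), dict.fromkeys(optimizer, None)
    return None, None


# Hypothetical optimizer names; object() stands in for real torch optimizers.
optimizers = {"generator": object(), "discriminator": object()}
lr, momentum = init_lr_momentum(optimizers)

# The defaultdict variant accepts any key, including a typo,
# and silently inserts that key on lookup.
lr_default = defaultdict(lambda: None)
assert lr_default["generatr"] is None and "generatr" in lr_default
```

The trade-off is the one raised above: defaultdict needs less code, while explicit keys surface typos immediately.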
examples/mnist_gans/runner.py
from catalyst.dl import Runner


class GANRunner(Runner):
Maybe it's better to move GANRunner from examples into catalyst/dl/runner?
OK.
Actually, right now, to implement some other GAN based on the current vanilla GANRunner, someone would have to override it almost completely. I suppose we will have to think about a better, more generic, easily extendable implementation that would make a more useful parent class.
But I'm going to leave this important task to other PRs =)
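One possible shape for such a parent class, sketched framework-free: the parent owns the phase schedule (which model is trained on which batch) and subclasses only implement the per-phase step. `PhaseRunnerSketch`, `VanillaGANSketch` and the `(phase_name, num_batches)` schedule format are hypothetical and are not catalyst's API.

```python
from itertools import cycle
from typing import Any, Iterable, Iterator, List, Tuple


class PhaseRunnerSketch:
    """Hypothetical parent class: owns the phase schedule, so subclasses
    only implement what happens in each phase."""

    def __init__(self, phases: List[Tuple[str, int]]):
        # Each entry is (phase_name, number_of_consecutive_batches).
        # Assumes at least one phase has a positive batch count.
        self.phases = phases

    def _phase_iter(self) -> Iterator[str]:
        # Repeat the schedule forever, e.g. D, D, G, D, D, G, ...
        for name, num_batches in cycle(self.phases):
            for _ in range(num_batches):
                yield name

    def run_loader(self, batches: Iterable[Any]) -> List[str]:
        # zip stops as soon as the loader is exhausted.
        return [
            self.run_phase(phase, batch)
            for phase, batch in zip(self._phase_iter(), batches)
        ]

    def run_phase(self, phase: str, batch: Any) -> str:
        raise NotImplementedError


class VanillaGANSketch(PhaseRunnerSketch):
    def run_phase(self, phase: str, batch: Any) -> str:
        # Placeholder: a real runner would do the D or G update here.
        return f"{phase}:{batch}"
```

With this split, another GAN variant would change the schedule or `run_phase` rather than overriding the whole runner.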
if wrapper_params:
    wrapper_params["base_callback"] = callback
    return ConfigExperiment._get_callback(**wrapper_params)
return callback
What do you think about another approach, like the one in lookahead.py (https://github.com/catalyst-team/catalyst/blob/master/catalyst/contrib/optimizers/lookahead.py#L84)?
So in the config it would look like:
callback_name:
  callback: WrapperCallback
  param_1: ...
  param_2: ...
  base_callback_params:
    param_3: ...
    param_4: ...
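A sketch of how a nested config like this could be turned into objects, in the spirit of the Lookahead link: the wrapper builds its own base callback from `base_callback_params`. `build_callback`, `REGISTRY`, `CallbackSketch` and the `callback` key inside `base_callback_params` are assumptions for illustration (the YAML does not say how the base callback's type is chosen); this is not catalyst's actual registry code.

```python
from typing import Any, Dict


class CallbackSketch:
    """Stand-in for a real callback; it only records its params."""

    def __init__(self, **params: Any):
        self.params = params


class WrapperCallback(CallbackSketch):
    """Hypothetical wrapper that constructs its base callback itself."""

    def __init__(self, base_callback_params: Dict[str, Any], **params: Any):
        super().__init__(**params)
        self.base_callback = build_callback(base_callback_params)


# Hypothetical name -> class registry.
REGISTRY = {"CallbackSketch": CallbackSketch, "WrapperCallback": WrapperCallback}


def build_callback(params: Dict[str, Any]) -> CallbackSketch:
    params = dict(params)  # copy so the caller's config is not mutated
    name = params.pop("callback")
    return REGISTRY[name](**params)


config = {
    "callback": "WrapperCallback",
    "param_1": 1,
    "param_2": 2,
    "base_callback_params": {"callback": "CallbackSketch", "param_3": 3, "param_4": 4},
}
wrapper = build_callback(config)
```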
I think my solution needs only 1 line in config.yml, compared with 2 lines here.
@@ -166,11 +169,17 @@ def _run_epoch(self, loaders):
self._run_loader(loader)
self._run_event("loader_end")

def _run_prestage(self, stage: str):
looks like a dirty hack :)
I created this method to store some state in the runner after initialization is done but before the stage begins. It's admittedly a little hacky, but I think it's still useful, i.e. the hack is not dirty =)
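A heavily simplified, hypothetical runner showing where such a hook sits: after per-stage initialization, before the stage loop. Only the name `_run_prestage` comes from the diff; `RunnerSketch`, `_init_stage`, `_run_stage` and `phase_order` are illustrative assumptions, not catalyst's Runner.

```python
from typing import List


class RunnerSketch:
    """Hypothetical runner illustrating the stage lifecycle order."""

    def __init__(self) -> None:
        self.events: List[str] = []
        self.optimizer_names: List[str] = []

    def _init_stage(self, stage: str) -> None:
        # Stand-in for building model, criterion and optimizers for the stage.
        self.optimizer_names = ["generator", "discriminator"]
        self.events.append(f"init:{stage}")

    def _run_prestage(self, stage: str) -> None:
        # Hook: everything is initialized, but the stage has not started yet.
        pass

    def _run_stage(self, stage: str) -> None:
        self.events.append(f"stage:{stage}")

    def run(self, stages: List[str]) -> None:
        for stage in stages:
            self._init_stage(stage)
            self._run_prestage(stage)
            self._run_stage(stage)


class GANRunnerSketch(RunnerSketch):
    def _run_prestage(self, stage: str) -> None:
        # Memorize something that only exists after initialization.
        self.phase_order = list(self.optimizer_names)
        self.events.append(f"prestage:{stage}")
```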
examples/mnist_gans/callbacks.py
@staticmethod
def _get_tensorboard_logger(state: RunnerState) -> SummaryWriter:
    for logger_name, logger in state.loggers.items():
You can simply use state.loggers["tensorboard"].
And what will happen when you rename that logger? =(
What do you think about the new solution? It seems to have become even less readable.
Somehow moved to #407.
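The trade-off in this thread, sketched: looking the logger up by type survives a rename of its key, while `state.loggers["tensorboard"]` is shorter but depends on the name. `find_logger` and `FakeSummaryWriter` are hypothetical (the stand-in avoids a tensorboard dependency); the real `state.loggers` contents may differ.

```python
from typing import Dict, Optional, Type, TypeVar

T = TypeVar("T")


def find_logger(loggers: Dict[str, object], logger_type: Type[T]) -> Optional[T]:
    """Return the first logger that is an instance of `logger_type`,
    regardless of the key it was registered under."""
    for logger in loggers.values():
        if isinstance(logger, logger_type):
            return logger
    return None


class FakeSummaryWriter:
    """Stand-in for torch.utils.tensorboard.SummaryWriter."""


writer = FakeSummaryWriter()
# The logger was registered under a non-default name.
loggers = {"my_tensorboard": writer, "console": object()}
```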
Finally! GANs are supported!
For now only as a standalone example, which may be generalized in the future to any multi-model, multi-optimizer, multi-phase experiment.
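For reference, a minimal PyTorch sketch of one vanilla GAN iteration with two models, two optimizers and two phases (discriminator, then generator). The layer sizes, Adam settings and non-saturating generator loss are common choices assumed here, not taken from this PR's MNIST example; inputs are assumed to be flattened 28x28 images scaled to [-1, 1].

```python
import torch
from torch import nn

latent_dim = 16
# Assumed toy architectures for flattened 28x28 images.
G = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(), nn.Linear(64, 784), nn.Tanh())
D = nn.Sequential(nn.Linear(784, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCEWithLogitsLoss()


def train_step(real: torch.Tensor):
    b = real.size(0)
    ones, zeros = torch.ones(b, 1), torch.zeros(b, 1)

    # Phase 1: discriminator. detach() keeps gradients out of G.
    fake = G(torch.randn(b, latent_dim)).detach()
    loss_d = bce(D(real), ones) + bce(D(fake), zeros)
    opt_d.zero_grad()
    loss_d.backward()
    opt_d.step()

    # Phase 2: generator. Gradients also reach D here, but they are
    # cleared by opt_d.zero_grad() before the next discriminator step.
    loss_g = bce(D(G(torch.randn(b, latent_dim))), ones)
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()
```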