[Retiarii]: Add info required by nn-meter to graph ir #3910
Conversation
@@ -309,7 +309,7 @@ def add_node(self, name: str, operation: Operation) -> 'Node': ...
     @overload
     def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ...

-    def add_node(self, name, operation_or_type, parameters={}):
+    def add_node(self, name, operation_or_type, parameters=None):
Why change the behavior of this API?
To my knowledge, it is not good practice to use {} as a default value; here is a reference: https://florimond.dev/en/posts/2018/08/python-mutable-defaults-are-the-source-of-all-evil/
Not sure whether there is any other special reason, but I believe it is worth a refactor to change all default {} and [] values to None.
Why change the behavior of this API?
This also causes the parameters of all nodes to share the same dict object (the initial {}). When the parameters of one node are changed, all the others change at the same time, which results in some strange bugs.
I'm okay with this change if you insist. But you should put an if parameters is None: parameters = {} in the body of this function.
If parameters is immutable in the function body, actually putting parameters={} in the arguments works, except that the IDE will complain.
I'm okay with this change if you insist. But you should put an if parameters is None: parameters = {} in the body of this function.
Yes. I have added that to Operation::__init__() and Operation::__new__().
If parameters is immutable in the function body, actually putting parameters={} in the arguments works, except that IDE will complain.
But if you assign it to an attribute of the object, then the next time you modify that attribute, the other objects will also be affected. (Actually this is why it caused some strange bugs, and it cost me a night to debug 😂)
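A minimal, self-contained sketch of the shared-default bug discussed in this thread, with the None-sentinel fix (the function names here are hypothetical, not the actual add_node):

```python
# Bug: a mutable default argument is created once, at function definition
# time, so every call without an explicit `parameters` shares one dict.
def add_node_bad(name, parameters={}):  # shared dict across all calls
    parameters[name] = True
    return parameters

# Fix: use None as a sentinel and create a fresh dict inside the body.
def add_node_good(name, parameters=None):
    if parameters is None:
        parameters = {}  # a new dict on every call
    parameters[name] = True
    return parameters
```

Calling add_node_bad twice returns the same dict object both times, so the second node's "parameters" silently contain the first node's entries, which is exactly the shared-state bug described above.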
nni/retiarii/nn/pytorch/api.py
Outdated
elif isinstance(candidates, list):
    for i, module in enumerate(candidates):
        self.add_module(str(i), module)
        self.names.append(str(i))
if not self.chosen:
Why do we need chosen? This seems like a fix mode and should be done in __new__.
I need to call the first candidate. But as candidates is a list/dict instead of a ModuleList, it can't be directly accessed in forward (e.g., candidates[0]). I don't know if there is a better way to achieve that.
In that case, you can write: self._modules[self.names[0]](x)
In that case, you can write: self._modules[self.names[0]](x)

Thanks, I got it.
In that case, you can write:
self._modules[self.names[0]](x)
It seems I can't access self._modules in forward.
In that case, you can write self._first_module = self._modules[self.names[0]] in __init__. That might be clearer than self.chosen.
My previous concern was that some one-shot algorithms use layer choice directly and might treat self.chosen as another module due to a wrong implementation. But I found that actually self.names is used in __iter__. So never mind...
nni/retiarii/converter/graph_gen.py
Outdated
 cand_type = '__torch__.' + get_importable_name(cand.__class__)
-graph.add_node(cand_name, cand_type, get_init_parameters_or_fail(cand))
+graph.add_node(cand_name, cand_type, get_init_parameters_or_fail(cand, silently=True))
 self._convert_module(script_cand, cand, cand_name, ir_model)
I think choices in LayerChoice should stop parsing. Please test with examples in https://github.com/microsoft/nni/tree/master/test/retiarii_test/darts to see if that works.
But nn-meter needs the subgraphs of LayerChoice's candidates. If there really are errors, users just need to wrap the candidate with serialize to stop parsing.
Actually, serialize was required under the former implementation, so there is no difference when serialize is provided. What I implemented is: when the candidate is not wrapped with serialize, I parse it recursively.
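The recursion rule described above can be sketched as follows. This is a hedged illustration, not the actual converter: Serialized is a hypothetical marker standing in for retiarii's serialize() wrapper, and Module is a toy container.

```python
class Serialized:
    """Hypothetical stand-in for a serialize()-wrapped candidate."""
    def __init__(self, module):
        self.module = module

class Module:
    """Toy module with named children."""
    def __init__(self, **children):
        self.children = children

def convert(module, name, nodes):
    nodes.append(name)
    if isinstance(module, Serialized):
        return  # stop parsing: the wrapped module becomes one opaque node
    # otherwise recurse into children to build the subgraph
    for child_name, child in module.children.items():
        convert(child, f'{name}.{child_name}', nodes)
```

Running convert on a model whose branch b is wrapped records b as a single node while fully expanding the unwrapped branch a.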
Makes sense to me. If you have already tested the example, I'm okay with this change.
Suggest adding a unit test for the non-serialize case.
class HardwareAwareGraphConverter(GraphConverter):

    def convert_module(self, script_module, module, module_name, ir_model, example_inputs):
Suggest adding docstring and unittests for this module.
I rebased and force-pushed to include the multi-trial SPOS example (#3876), as later development is based on it.
nni/retiarii/strategy/filter.py
Outdated
from nn_meter import get_default_config, load_latency_predictors  # pylint: disable=import-error


class LatencyFilter:
Suggest putting this filter into examples.
@@ -86,15 +86,28 @@ class Random(BaseStrategy):
     Do not try the same configuration twice. When variational is true, deduplication is not supported. Default: true.
     """

-    def __init__(self, variational=False, dedup=True):
+    def __init__(self, variational=False, dedup=True, model_filter=None):
Please update the docstring correspondingly.
nni/retiarii/experiment/pytorch.py
Outdated
# TODO: this logic might need to be refactored into execution engine
if full_ir:
    try:
        script_module = torch.jit.script(base_model)
    except Exception as e:
        _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
        raise e
    base_model_ir = convert_to_graph(script_module, base_model)
    if parse_shape:
This looks like another execution engine. I suppose you can merge parse_shape and example_inputs with full_ir, and rename full_ir to something like ir_format.
@@ -171,7 +180,8 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT

     def _start_strategy(self):
         base_model_ir, self.applied_mutators = preprocess_model(
-            self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py')
+            self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py',
Please add another value to execution_engine.
exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
example_inputs = torch.randn(1, 3, 32, 32)

base_model.eval()
It is a little strange to configure the dummy input and call eval here; let's discuss this in the meeting.
Agree. I think this input should be obtained from the dataloader.
        the built graph ir from module; ``None`` means do not further parse the module
    dict
        the input arguments of this module
    """
Why wrap _convert_module with exactly the same input arguments? Also, it seems convert_module does not return anything, but there are returns in the docstring.
The tests failed because I moved
nni/retiarii/strategy/utils.py
Outdated
from typing import Dict, Any, List
from nn_meter import get_default_config, load_latency_predictors  # pylint: disable=import-error
Please move the filter into your example. This import will make nn_meter a "required" dependency of NNI, but I don't think it should be required.
ok
nni/retiarii/experiment/pytorch.py
Outdated
@@ -154,7 +160,8 @@ def debug_mutated_model(base_model, trainer, applied_mutators):

 class RetiariiExperiment(Experiment):
     def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotTrainer],
-                 applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None):
+                 applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None,
+                 parse_shape: bool = False, example_inputs = None):
Any chance we can put them into config?
Tested on #3876 ShuffleNetV2