
[ENH] Class Design for BaseDeepNetwork #3190

@AurumnPegasus

Is your feature request related to a problem? Please describe.
As discussed in the Deep Learning mentoring meeting of 05/08/2022:

Current Implementation:
BaseDeepNetwork is a base class for creating neural networks. Each specific neural network is a child of BaseDeepNetwork. For example, CNNNetwork, CNTCNetwork, LSTMNetwork, etc. are classes inheriting from BaseDeepNetwork, each with a single method called build_model (which builds and returns the created Keras network).
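A minimal sketch of this current layout, for illustration only. To keep it runnable without TensorFlow, the "network" returned by build_model is a list of hypothetical (layer_name, argument) tuples standing in for Keras layers, and the constructor parameters (n_conv_layers, kernel_size, units) are made up; this is not the actual sktime code.

```python
from abc import ABC, abstractmethod


class BaseDeepNetwork(ABC):
    """Base class for network builders (simplified: no Keras dependency)."""

    @abstractmethod
    def build_model(self, input_shape):
        """Build and return the network for the given input shape."""


class CNNNetwork(BaseDeepNetwork):
    """Toy CNN builder; (name, arg) tuples stand in for Keras layers."""

    def __init__(self, n_conv_layers=2, kernel_size=7):
        self.n_conv_layers = n_conv_layers
        self.kernel_size = kernel_size

    def build_model(self, input_shape):
        layers = [("input", input_shape)]
        for _ in range(self.n_conv_layers):
            layers.append(("conv1d", self.kernel_size))
        layers.append(("flatten", None))
        return layers


class LSTMNetwork(BaseDeepNetwork):
    """Toy LSTM builder: a second child exposing the same single method."""

    def __init__(self, units=64):
        self.units = units

    def build_model(self, input_shape):
        return [("input", input_shape), ("lstm", self.units)]
```

Every network class only has to implement build_model, so the base class stays a thin interface.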
For estimators, there are dedicated BaseDeepClassifier and BaseDeepRegressor classes, inheriting from BaseClassifier and BaseRegressor respectively. Specific estimators like CNNClassifier inherit from BaseDeepClassifier and, in their __init__ method, create an instance of CNNNetwork. Then, in the fit method, they obtain the corresponding Keras network by calling the build_model method of the CNNNetwork object. The advantage of this design is that when creating a CNNRegressor, we do not need to rewrite the code for the main CNN; instead we reuse CNNNetwork, just as CNNClassifier does.
A point to note: CNNNetwork returns the Keras network with all layers except the output layer, which is added by the specific estimator such as CNNClassifier.
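The composition described above can be sketched roughly as follows, again with plain-Python stand-ins: layers are hypothetical (name, argument) tuples instead of Keras objects, and BaseDeepClassifier is an empty placeholder for the real base class. The point it shows is that the shared CNN body comes from CNNNetwork, while each estimator appends only its own task-specific output layer.

```python
class CNNNetwork:
    """Stand-in builder: returns every layer except the output layer."""

    def __init__(self, n_conv_layers=2):
        self.n_conv_layers = n_conv_layers

    def build_model(self, input_shape):
        return [("input", input_shape)] + [("conv1d", 7)] * self.n_conv_layers


class BaseDeepClassifier:
    """Placeholder for the real base class (which inherits BaseClassifier)."""


class CNNClassifier(BaseDeepClassifier):
    def __init__(self, n_conv_layers=2):
        self.n_conv_layers = n_conv_layers
        # the network object is created once, in __init__
        self._network = CNNNetwork(n_conv_layers=n_conv_layers)

    def fit(self, X, y):
        n_classes = len(set(y))
        body = self._network.build_model(input_shape=(len(X[0]), 1))
        # the output layer is task specific, so the estimator adds it
        self.model_ = body + [("dense_softmax", n_classes)]
        return self


class CNNRegressor:
    """Reuses CNNNetwork unchanged; only the output layer differs."""

    def __init__(self, n_conv_layers=2):
        self._network = CNNNetwork(n_conv_layers=n_conv_layers)

    def fit(self, X, y):
        body = self._network.build_model(input_shape=(len(X[0]), 1))
        self.model_ = body + [("dense_linear", 1)]
        return self
```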

New Proposals (from discussions with @fkiraly and @ltsaprounis today, and with @GuzalBulatova previously)
Make BaseDeepClassifier inherit from both BaseClassifier and BaseDeepNetwork. BaseDeepClassifier would then have the same methods and structure as other classifiers via BaseClassifier, plus methods specific to deep learning models via BaseDeepNetwork. This would reduce the amount of redundant code across all DL models and make it easier for users to implement custom networks (once we move towards supporting that feature).
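A rough sketch of what the proposed multiple inheritance could look like. BaseClassifier here is a tiny stand-in (not the real sktime class), summary() is a hypothetical example of a DL-specific helper, and layers are again (name, argument) tuples rather than Keras layers.

```python
from abc import ABC, abstractmethod


class BaseClassifier:
    """Stand-in for sktime's BaseClassifier: generic fit logic only."""

    def fit(self, X, y):
        self.classes_ = sorted(set(y))
        self._fit(X, y)
        self._is_fitted = True
        return self


class BaseDeepNetwork(ABC):
    """Home for functionality shared by all deep learning estimators."""

    @abstractmethod
    def build_model(self, input_shape, n_classes):
        """Return the full network, output layer included."""

    def summary(self):
        # hypothetical DL-specific helper, written once for every DL estimator
        return [name for name, _ in self.model_]


class BaseDeepClassifier(BaseClassifier, BaseDeepNetwork):
    """Classifier interface from BaseClassifier, DL helpers from BaseDeepNetwork."""

    def _fit(self, X, y):
        self.model_ = self.build_model((len(X[0]), 1), len(self.classes_))


class CNNClassifier(BaseDeepClassifier):
    def build_model(self, input_shape, n_classes):
        return [("input", input_shape), ("conv1d", 7), ("dense_softmax", n_classes)]
```

With this layout a concrete estimator such as CNNClassifier only implements build_model, while fit comes from the classifier side and summary from the deep learning side.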

An example of redundant code occurring across all DL classes:
See PR #3128, where I implemented save-model and load-model functionality for DL models separately (since pickling Keras networks is troublesome and not easily implementable). There, the same save and load functions need to be implemented in BaseDeepClassifier, BaseDeepRegressor and BaseDeepForecaster, which could easily have been avoided by having a common DL class that all DL networks inherit from. I am sure we will find other such redundancies as I continue migrating DL models from sktime-dl to sktime.
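A minimal sketch of how a common parent removes the duplication: save and load are written once and inherited by all three DL base classes. The method names save/load are hypothetical, and JSON of a list of (name, argument) tuples is only a stand-in for a real Keras model; the actual code in PR #3128 is not reproduced here and would rely on Keras' own saving rather than pickle.

```python
import json
import os
import tempfile


class BaseDeepNetwork:
    """Hypothetical common parent of all DL estimators; save/load live here once."""

    def save(self, path):
        # stand-in serialisation: model_ is a list of (name, int) tuples
        with open(path, "w") as f:
            json.dump(self.model_, f)

    def load(self, path):
        # JSON turns tuples into lists, so convert them back
        with open(path) as f:
            self.model_ = [tuple(layer) for layer in json.load(f)]
        return self


class BaseDeepClassifier(BaseDeepNetwork):
    pass


class BaseDeepRegressor(BaseDeepNetwork):
    pass


class BaseDeepForecaster(BaseDeepNetwork):
    pass


# Round trip through a temporary file, with no save/load code in the subclasses.
clf = BaseDeepClassifier()
clf.model_ = [("conv1d", 7), ("dense_softmax", 3)]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.json")
    clf.save(path)
    restored = BaseDeepClassifier().load(path)
```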

Input and discussion are currently needed to better design BaseDeepNetwork, so as to minimize code redundancy and make the structure more intuitive.

Metadata

Labels

API design (API design & software architecture), enhancement (Adding new functionality), implementing framework (Implementing or improving framework for learning tasks, e.g., base class functionality)