Description
Is your feature request related to a problem? Please describe.
As discussed in the Deep Learning mentoring meeting of 05/08/2022:
Current Implementation:
BaseDeepNetwork is a base class for creating neural networks. Each specific neural network is a child of BaseDeepNetwork. For example, CNNNetwork, CNTCNetwork, LSTMNetwork, etc. would be classes inheriting from BaseDeepNetwork, each having a single function called build_model (which builds and returns the created keras network).
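A minimal sketch of the current structure, for illustration only. To keep it runnable without TensorFlow, layers are represented as hypothetical (name, config) tuples instead of keras layers, and the hyperparameters (n_filters, kernel_size, units) are made up; only the class names and build_model come from the description above.

```python
from abc import ABC, abstractmethod


class BaseDeepNetwork(ABC):
    """Hypothetical sketch of the base class for network builders."""

    @abstractmethod
    def build_model(self, input_shape):
        """Build and return the network for the given input shape."""


class CNNNetwork(BaseDeepNetwork):
    def __init__(self, n_filters=6, kernel_size=7):
        # Hypothetical hyperparameters, for illustration only.
        self.n_filters = n_filters
        self.kernel_size = kernel_size

    def build_model(self, input_shape):
        # Placeholder (name, config) tuples stand in for keras layers.
        return [
            ("input", {"shape": input_shape}),
            ("conv1d", {"filters": self.n_filters, "kernel_size": self.kernel_size}),
            ("flatten", {}),
        ]


class LSTMNetwork(BaseDeepNetwork):
    def __init__(self, units=32):
        self.units = units

    def build_model(self, input_shape):
        return [("input", {"shape": input_shape}), ("lstm", {"units": self.units})]
```

Since build_model is abstract, BaseDeepNetwork itself cannot be instantiated; every concrete network has to provide its own build_model.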
For estimators, there exist specific BaseDeepClassifier and BaseDeepRegressor classes, inheriting from BaseClassifier and BaseRegressor respectively. Specific estimators like CNNClassifier inherit from BaseDeepClassifier and, within the __init__ method, create an object of the class CNNNetwork. Then, in the fit method, they get the respective keras neural network by calling the build_model method of the CNNNetwork object. The advantage of this design is that when creating a CNNRegressor, we do not need to rewrite the code for the main CNN; instead, we just use CNNNetwork, similar to how it is used in CNNClassifier.
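The composition described above can be sketched as follows. This is a simplified, hypothetical version: the BaseDeepClassifier and BaseDeepRegressor parents are left out, X is assumed to be a list of equal-length univariate series, and build_model returns placeholder layer tuples rather than a keras network.

```python
class CNNNetwork:
    """Hypothetical stand-in for the shared CNN builder."""

    def __init__(self, kernel_size=7):
        self.kernel_size = kernel_size

    def build_model(self, input_shape):
        # Placeholder layer tuples instead of a keras network.
        return [
            ("input", {"shape": input_shape}),
            ("conv1d", {"kernel_size": self.kernel_size}),
            ("flatten", {}),
        ]


class CNNClassifier:
    def __init__(self, kernel_size=7):
        self.kernel_size = kernel_size
        # The estimator creates its network object in __init__ ...
        self._network = CNNNetwork(kernel_size=kernel_size)

    def fit(self, X, y):
        # ... and asks it for the network only when fitting.
        input_shape = (len(X[0]), 1)  # assumes univariate series
        self.model_ = self._network.build_model(input_shape)
        return self


class CNNRegressor:
    def __init__(self, kernel_size=7):
        self.kernel_size = kernel_size
        # Same builder reused: no CNN code is duplicated here.
        self._network = CNNNetwork(kernel_size=kernel_size)

    def fit(self, X, y):
        input_shape = (len(X[0]), 1)
        self.model_ = self._network.build_model(input_shape)
        return self
```

Both estimators end up with the same network body, which is the reuse this design aims for.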
A point to note: CNNNetwork returns the keras network built with all layers except the output layer, which is added in the specific estimator, such as CNNClassifier.
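To illustrate the split between the shared network body and the estimator-specific output layer, here is a sketch with hypothetical helper functions operating on placeholder layer tuples. The choice of a softmax layer with one unit per class for classification and a single linear unit for regression is an assumption for illustration, not taken from the sktime code.

```python
def cnn_body(input_shape):
    """Hypothetical CNN body: every layer except the output layer."""
    return [
        ("input", {"shape": input_shape}),
        ("conv1d", {"kernel_size": 7}),
        ("flatten", {}),
    ]


def add_classification_head(body, n_classes):
    # Classifier-specific output: one unit per class with softmax (assumed).
    # "body + [...]" builds a new list, so the shared body is not mutated.
    return body + [("dense", {"units": n_classes, "activation": "softmax"})]


def add_regression_head(body):
    # Regressor-specific output: a single linear unit (assumed).
    return body + [("dense", {"units": 1, "activation": "linear"})]
```

Keeping the output layer out of the shared body is what allows one network definition to serve both task types.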
New Propositions (from discussions with @fkiraly and @ltsaprounis today, and @GuzalBulatova previously)
Make BaseDeepClassifier inherit from both BaseClassifier and BaseDeepNetwork. In this case, BaseDeepClassifier will have all the methods and structure of other classifiers via BaseClassifier, and it can have methods specific to Deep Learning models via BaseDeepNetwork. This would reduce the amount of redundant code written across all DL models and would make it easier for users to implement CustomNetworks (when we move towards having that feature).
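A rough sketch of the proposed multiple inheritance, assuming simplified stand-ins for BaseClassifier and BaseDeepNetwork (the real sktime classes have much richer interfaces). The summary helper and the _fit hook are hypothetical and only show where DL-specific and classifier-specific logic could live; here build_model also adds the output layer, since the classifier now is the network.

```python
from abc import ABC, abstractmethod


class BaseClassifier:
    """Hypothetical stand-in for sktime's BaseClassifier."""

    def fit(self, X, y):
        self.classes_ = sorted(set(y))
        self._fit(X, y)  # subclass-specific fitting logic
        self.is_fitted = True
        return self


class BaseDeepNetwork(ABC):
    """Hypothetical stand-in holding DL-specific behaviour."""

    @abstractmethod
    def build_model(self, input_shape, n_classes):
        """Return the full network, including the output layer."""

    def summary(self):
        # Hypothetical DL-only helper shared by every deep estimator.
        return [name for name, _ in self.model_]


class BaseDeepClassifier(BaseClassifier, BaseDeepNetwork):
    def _fit(self, X, y):
        input_shape = (len(X[0]), 1)  # assumes univariate series
        self.model_ = self.build_model(input_shape, len(self.classes_))


class CNNClassifier(BaseDeepClassifier):
    def build_model(self, input_shape, n_classes):
        # Placeholder layer tuples instead of a keras network.
        return [
            ("input", {"shape": input_shape}),
            ("conv1d", {"kernel_size": 7}),
            ("dense", {"units": n_classes, "activation": "softmax"}),
        ]
```

With this layout, Python's method resolution order (CNNClassifier, BaseDeepClassifier, BaseClassifier, BaseDeepNetwork) looks up classifier behaviour first and picks up the DL-specific methods after it.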
An example of redundant code occurring across all DL classes:
Check PR #3128, where I have implemented save model and load model functionality for DL models separately (since pickling keras networks is troublesome and not easily implementable). Here, the same save and load functions need to be implemented in BaseDeepClassifier, BaseDeepRegressor and BaseDeepForecaster, which could easily have been solved by having a common DL class that all DL networks inherit from. I am sure we will find other such redundancies as I continue to migrate DL models from sktime-dl to sktime.
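A sketch of how the save/load redundancy could disappear with a common parent. The class name BaseDeepEstimator and the method signatures are hypothetical, and a JSON file stands in for the keras model's own save/load so the example runs without TensorFlow; it is not the implementation from PR #3128.

```python
import json
import os


class BaseDeepEstimator:
    """Hypothetical common parent for all DL estimators.

    save/load are written once here. A real version would delegate to the
    keras model's own saving; this sketch stores a JSON-serialisable model_.
    """

    def save(self, path):
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "model.json"), "w") as f:
            json.dump(self.model_, f)

    @classmethod
    def load(cls, path):
        # Returns an instance of whichever subclass load is called on.
        est = cls()
        with open(os.path.join(path, "model.json")) as f:
            est.model_ = json.load(f)
        return est


class BaseDeepClassifier(BaseDeepEstimator):
    pass


class BaseDeepRegressor(BaseDeepEstimator):
    pass


class BaseDeepForecaster(BaseDeepEstimator):
    pass
```

Because save and load are defined once, BaseDeepClassifier, BaseDeepRegressor and BaseDeepForecaster all share the same function objects, so a fix to the saving logic only has to be made in one place.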
Currently, input and discussion are required to better design BaseDeepNetwork so as to minimize code redundancy and make the structure more intuitive.