[core] refactor AbstractTrainer #4804
Conversation
from autogluon.core.augmentation.distill_utils import augment_data, format_distillation_labels
from autogluon.core.calibrate import calibrate_decision_threshold
from autogluon.core.calibrate.conformity_score import compute_conformity_score
from autogluon.core.calibrate.temperature_scaling import apply_temperature_scaling, tune_temperature_scaling
from autogluon.core.callbacks import AbstractCallback
from autogluon.core.constants import BINARY, MULTICLASS, QUANTILE, REFIT_FULL_NAME, REGRESSION, SOFTCLASS
from autogluon.core.data.label_cleaner import LabelCleanerMulticlassToBinary
from autogluon.core.metrics import Scorer, compute_metric, get_metric
from autogluon.core.models import (
    AbstractModel,
    BaggedEnsembleModel,
    GreedyWeightedEnsembleModel,
    SimpleWeightedEnsembleModel,
    StackerEnsembleModel,
    WeightedEnsembleModel,
)
from autogluon.core.pseudolabeling.pseudolabeling import assert_pseudo_column_match
from autogluon.core.ray.distributed_jobs_managers import ParallelFitManager
from autogluon.core.utils import (
Why switch from relative to absolute imports? Tabular, core, common, and features all use relative imports.
Is there a general guideline you are following for favoring absolute over relative?
Nothing particular in this case, it's just that these will move into tabular anyway. I can switch to relative imports for the ones that remain after that.
PEP8 recommends absolute imports but I would vote in favor of consistency with other modules here.
Very nice! Added some minor comments but overall this is a great step in the right direction for unifying the API of trainer.
if not isinstance(model, str):
    model = model.name
This logic and variants of it can probably be put into utility methods since they are so common. Maybe as a fast follow PR. Will change a lot of 2-4 line logic into 1 line, and improve type hinting in the IDE so we can be more explicit about the type of the variable in the code.
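Such a helper might look like the following (a minimal sketch; the name and placement are hypothetical, not part of this PR):

```python
def resolve_model_name(model) -> str:
    """Hypothetical helper: accept either a model object or its name,
    and always return the name as a string. Collapses the common
    2-4 line isinstance dance into a single call."""
    if isinstance(model, str):
        return model
    return model.name
```

With a helper like this, call sites reduce to `model = resolve_model_name(model)`, and the return type is unambiguous for the IDE.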
if not self.low_memory:
    self.models[model.name] = model
The low_memory logic is something I implemented a very long time ago but at a certain point gave up on testing/using, since users didn't really care too much that files were being saved on disk, and it was a hassle trying to get everything working in-memory.
If time-series isn't using it, we may consider removing the low_memory logic entirely in a follow-up PR; it would make the code simpler and avoid us having to juggle multiple input/output variable types depending on the low_memory setting.
agreed. 👍
@property
def path_root(self) -> str:
    """directory containing learner.pkl"""
    return os.path.dirname(self.path)
I remember wanting to refactor some of this logic, as currently we have `self.path` and `self.path_root`, which is a bit confusing. I might look into it closer as a follow-up.
The overall problem is that `self.path` doesn't contain all of the artifacts needed by the Trainer, such as those from `path_utils`. In an ideal world, having it be self-contained would be nice, kind of how Predictor is self-contained. But I would need to think about it more, because I also thought about a world where we could have multiple trainers for a single predictor, in which case it would be good for trainers to re-use certain artifacts between them.
Thank you for this herculean effort! Only several questions and minor comments.
model = model.name
self.model_graph.nodes[model][attribute] = val

def get_minimum_model_set(self, model, include_self=True) -> list:
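For context, `get_minimum_model_set` amounts to a transitive-closure walk over the model dependency graph. A stdlib-only sketch of the idea, using a plain parent mapping in place of the trainer's actual `self.model_graph` (names and signature are illustrative, not the real implementation):

```python
def get_minimum_model_set(parents: dict, model: str, include_self: bool = True) -> list:
    """Sketch: the minimum set of models needed to predict with `model` is
    the model itself plus the transitive closure of its base models.
    `parents` maps each model name to the names of its direct base models."""
    required = set()
    stack = list(parents.get(model, []))
    while stack:
        name = stack.pop()
        if name not in required:
            required.add(name)
            stack.extend(parents.get(name, []))
    if include_self:
        required.add(model)
    return sorted(required)
```

A weighted ensemble's minimum set, for example, would include all of its base models and their own dependencies.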
Nit: Some methods are missing type hints.
Not a nit at all, and thanks for pointing it out.
A heads up: one thing that will come up in the next PR is how we handle this `AbstractModel` type being passed around. TimeSeriesTrainer would like to only work with `AbstractTimeSeriesModel`; however, if we constrain it trivially in the function signature then pyright will complain that the LSP is violated. We will then have to revisit this class to use generics...
from typing import Generic, TypeVar

T = TypeVar("T", bound=AbstractModel)

class AbstractTrainer(Generic[T]):
    def work_with_model(self, model: T):
        ...

class AbstractTimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
    ...
note the syntax gets much lighter for working with generics starting from 3.12 (PEP 695).
I was really hoping that by specializing in Python instead of C++ I would be able to avoid learning what a generic is... welp.
Issue #, if available:

Description of changes:
This is the first PR of a few for a major refactor of `AbstractTrainer`, unifying the interfaces of `AbstractTimeSeriesTrainer` and tabular's Trainer class into one abstract class in core.

This PR factors out `AbstractTrainer` in `core`, including model management behavior that is common to time series and tabular, as well as common interfaces.

Subsequent PRs will:
- move `AbstractTrainer` in core to tabular, since it is now almost exclusively tabular-specific concretizations;
- move `SimpleAbstractTrainer` from timeseries to core.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.