@Innixma Innixma commented Feb 19, 2025

Issue #, if available:

Description of changes:

  • Add ag_key, ag_name and ag_priority class variables to all AbstractModel subclasses.
  • Add ag_model_register and ModelRegister to simplify model bookkeeping and make it easy to register custom models.
  • Add 134 unit tests for ModelRegister to ensure matching functionality with mainline.
  • This logic will be extremely helpful with making TabRepo 2.0 adoption lightweight.
  • Add a name for DummyModel ("Dummy") and update the relevant unit tests. Previously it was the only unnamed model.

WIP:

  • Replace logic in autogluon.tabular.trainer.model_presets.presets to use ag_model_register. I decided not to do that in this PR as it would expand the scope to the point of being hard to review. So I will do this as a follow-up PR.

Example:

import pandas as pd

from autogluon.core.models import AbstractModel
from autogluon.tabular.register import ag_model_register


class MyCustomModel(AbstractModel):
    ag_key = "MYMODEL"
    ag_name = "MyModel"
    ag_priority = 70


ag_model_register.add(MyCustomModel)
print(f"Class mapped to 'MYMODEL' key: {ag_model_register.key_to_cls(key='MYMODEL')}")

print("Contents of ag_model_register:")
with pd.option_context("display.max_rows", None, "display.max_columns", None, "display.width", 1000):
    print(ag_model_register.to_frame())

Output:

Class mapped to 'MYMODEL' key: <class '__main__.MyCustomModel'>
Contents of ag_model_register:
                                       model_cls                    ag_name ag_priority
ag_key                                                                                 
RF                                       RFModel               RandomForest          80
XT                                       XTModel                 ExtraTrees          60
KNN                                     KNNModel                 KNeighbors         100
GBM                                     LGBModel                   LightGBM          90
CAT                                CatBoostModel                   CatBoost          70
XGB                                 XGBoostModel                    XGBoost          40
NN_TORCH              TabularNeuralNetTorchModel             NeuralNetTorch          25
LR                                   LinearModel                LinearModel          30
FASTAI                      NNFastAiTabularModel            NeuralNetFastAI          50
TRANSF                       TabTransformerModel                Transformer           0
AG_TEXT_NN                    TextPredictorModel              TextPredictor           0
AG_IMAGE_NN                  ImagePredictorModel             ImagePredictor           0
AG_AUTOMM               MultiModalPredictorModel        MultiModalPredictor           0
FT_TRANSFORMER                FTTransformerModel              FTTransformer           0
TABPFN                               TabPFNModel                     TabPFN         110
TABPFNMIX                         TabPFNMixModel                  TabPFNMix          45
FASTTEXT                           FastTextModel                   FastText           0
VW                             VowpalWabbitModel               VowpalWabbit          10
ENS_WEIGHTED         GreedyWeightedEnsembleModel           WeightedEnsemble           0
SIMPLE_ENS_WEIGHTED  SimpleWeightedEnsembleModel           WeightedEnsemble           0
IM_RULEFIT                          RuleFitModel                    RuleFit           0
IM_GREEDYTREE                    GreedyTreeModel                 GreedyTree           0
IM_FIGS                                FigsModel                       Figs           0
IM_HSTREE                            HSTreeModel  HierarchicalShrinkageTree           0
IM_BOOSTEDRULES                BoostedRulesModel               BoostedRules           0
DUMMY                                 DummyModel                      Dummy           0
MYMODEL                            MyCustomModel                    MyModel          70

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@Innixma Innixma added the API & Doc, enhancement, and module: tabular labels Feb 19, 2025
@Innixma Innixma added this to the 1.3 Release milestone Feb 19, 2025
@Innixma
Contributor Author

Innixma commented Feb 19, 2025

@shchur This logic might be applicable to TimeSeries as well. Curious on your thoughts.

@LennartPurucker
Collaborator

LennartPurucker commented Feb 20, 2025

Nit: How far is the ag_priority of the model class connected or disconnected from the priority of a model configuration?
Would it be clearer to name this the ag_default_priority?

Collaborator

@LennartPurucker LennartPurucker left a comment

LGTM!

@Innixma
Contributor Author

Innixma commented Feb 20, 2025

> Nit: How far is the ag_priority of the model class connected or disconnected from the priority of a model configuration? Would it be clearer to name this the ag_default_priority?

Model configs can specify their own priority via "ag_args": {"priority": N}, and the same goes for the model name.

While it is true that this is only the default, it would be overly verbose to put "default" in the name IMO. The same argument could be made for ag_name, and ag_default_name just feels a bit clunky.
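
To make the relationship between the two priorities concrete, here is a minimal sketch of how a class-level default and a per-config override could interact. `ExampleModel` and `resolve_priority` are hypothetical names used only for illustration; the actual resolution logic in AutoGluon's trainer may differ.

```python
class ExampleModel:
    # Class-level default, analogous to ag_priority in this PR.
    ag_priority = 70


def resolve_priority(model_cls, config: dict) -> int:
    """Hypothetical helper: prefer the config's priority, else the class default."""
    ag_args = config.get("ag_args", {})
    return ag_args.get("priority", model_cls.ag_priority)


# No override in the config: the class default applies.
default_priority = resolve_priority(ExampleModel, {})
# Explicit override via "ag_args": the config value wins.
overridden_priority = resolve_priority(ExampleModel, {"ag_args": {"priority": 95}})
print(default_priority, overridden_priority)
```

Under this reading, ag_priority only matters when the config is silent, which is why "default" is implied rather than spelled out in the name.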

@LennartPurucker
Collaborator

Agree. Maybe then add a comment somewhere clarifying this for users who have to set this value.

Job PR-4913-4a17edf is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-4913/4a17edf/index.html

Collaborator

@shchur shchur left a comment

Looks good to me, just several tangential thoughts / ideas and a few nits

ag_key: str | None = None # set to string value for subclasses for use in AutoGluon
ag_name: str | None = None # set to string value for subclasses for use in AutoGluon
ag_priority: int = 0 # set to int value for subclasses for use in AutoGluon
ag_priority_by_problem_type: dict[str, int] = {} # if not set, we fall back to ag_priority
Collaborator

Nit: Setting a mutable object as the default value might lead to some problems in the future.
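
A minimal sketch of the kind of problem this nit refers to, using hypothetical class names: a dict defined at class level is one shared object, so a mutation made through one subclass is visible through every other class that inherits it.

```python
class Base:
    # One dict object, shared by Base and every subclass that does not override it.
    priority_by_problem_type: dict[str, int] = {}


class ModelA(Base):
    pass


class ModelB(Base):
    pass


# Mutating through one subclass silently changes what the others see.
ModelA.priority_by_problem_type["binary"] = 100

leaked = ModelB.priority_by_problem_type.get("binary")
same_object = ModelA.priority_by_problem_type is Base.priority_by_problem_type
print(leaked, same_object)
```

Reassigning the attribute on a subclass (`ModelA.priority_by_problem_type = {...}`) is safe; only in-place mutation of the inherited dict leaks.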

Contributor Author

Any thoughts on how to avoid this? We could move it to a property instead, but it would look less elegant.

Collaborator

@shchur shchur Feb 25, 2025

One alternative idea that comes to mind is turning this into an instance attribute and calling copy during __init__. This also feels less elegant though.

Another option is to use the built-in types.MappingProxyType to make the dict read-only.

from types import MappingProxyType

mapping = MappingProxyType({"a": 1, "b": 2})
mapping["c"] = 3  # raises TypeError: mappingproxy does not support item assignment

Contributor Author

MappingProxyType looks reasonable, I've updated the PR to use it
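
For illustration, a sketch of what a read-only class-level mapping with a fallback to ag_priority could look like. The class name, the `get_priority` helper, and the priority values are assumptions made for this example, not the code merged in the PR.

```python
from types import MappingProxyType


class ExampleModel:
    ag_priority = 50
    # Read-only view: in-place mutation through the class attribute is rejected.
    ag_priority_by_problem_type = MappingProxyType({"regression": 100})

    @classmethod
    def get_priority(cls, problem_type=None):  # hypothetical helper name
        """Use the problem-type-specific priority if present, else ag_priority."""
        return cls.ag_priority_by_problem_type.get(problem_type, cls.ag_priority)


specific = ExampleModel.get_priority("regression")
fallback = ExampleModel.get_priority("binary")

try:
    ExampleModel.ag_priority_by_problem_type["binary"] = 1
    mutated = True
except TypeError:
    # mappingproxy objects do not support item assignment.
    mutated = False

print(specific, fallback, mutated)
```

A subclass that needs different values would assign a new MappingProxyType rather than editing the inherited one.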

]

# TODO: Replace logic in `autogluon.tabular.trainer.model_presets.presets` with `ag_model_register`
ag_model_register = ModelRegister(model_cls_list=REGISTERED_MODEL_CLS_LST)
Collaborator

A potentially less verbose way to implement this pattern is using a decorator.

# in _model_register.py
ALIAS_TO_CLS: dict[str, type] = {}

def register(alias: str):
    def decorator(cls):
        ALIAS_TO_CLS[alias] = cls
        return cls
    return decorator

# in lgb_model.py
@register(alias="LGB")
class LGBModel(AbstractTimeSeriesModel):
    ...  # model implementation elided

This could also be easier to use by the users that implement custom models.

ALIAS_TO_CLS can be replaced with something more advanced like the actual ModelRegister class, or a dict where each alias corresponds to a tuple [model_cls, priority, priority_dict].

Contributor Author

One possible downside is this links all models to a global object inherently, whereas the ModelRegister is something that isn't inherently global and theoretically a user could pass their own ModelRegister object to the predictor. No idea if that is something that is desirable though.
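
The distinction can be sketched with a small stand-in registry. `SimpleRegister` below is a hypothetical simplification written for this example, not the PR's ModelRegister; it only borrows the `add` and `key_to_cls` method names from the PR description.

```python
class SimpleRegister:
    """Hypothetical stand-in for a registry object that is not inherently global."""

    def __init__(self, model_cls_list=None):
        self._key_to_cls = {}
        for model_cls in model_cls_list or []:
            self.add(model_cls)

    def add(self, model_cls):
        key = model_cls.ag_key
        if key in self._key_to_cls:
            existing = self._key_to_cls[key].__name__
            raise ValueError(f"Key {key!r} is already registered to {existing}")
        self._key_to_cls[key] = model_cls

    def key_to_cls(self, key):
        return self._key_to_cls[key]


class ModelA:
    ag_key = "A"


class ModelB:
    ag_key = "B"


# Two independent registries: adding to one does not affect the other.
default_register = SimpleRegister([ModelA])
user_register = SimpleRegister([ModelA, ModelB])

print(sorted(user_register._key_to_cls), sorted(default_register._key_to_cls))
```

With a decorator that writes to a module-level dict, every decorated class lands in the same place; with an explicit object, a caller could in principle build and pass a registry of their own.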

Contributor Author

Merging the current logic as-is; we can revisit this alternative if we think it is cleaner in the future.

Comment on lines +91 to +207
ag_key: str | None = None # set to string value for subclasses for use in AutoGluon
ag_name: str | None = None # set to string value for subclasses for use in AutoGluon
Collaborator

I feel that both ag_key and ag_name serve a similar purpose, and one of them could potentially be removed. For example, we could generate the ag_name automatically from the class name with something like self.name = re.sub(r"Model$", "", self.__class__.__name__).

Contributor Author

Agreed that they serve a similar purpose. For now I will keep both for legacy reasons, but we can consider standardizing. Maybe we can have self.ag_name default to your regex logic if it is None, such as self._ag_name = None with def ag_name() as a property?
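
A rough sketch of that idea, with hypothetical class definitions standing in for the real models: an explicit `_ag_name` wins, and otherwise the name is derived from the class name.

```python
import re


class BaseModel:
    _ag_name = None  # subclasses may set an explicit name

    @property
    def ag_name(self) -> str:
        if self._ag_name is not None:
            return self._ag_name
        # Fall back to the class name with a trailing "Model" stripped.
        return re.sub(r"Model$", "", self.__class__.__name__)


class CatBoostModel(BaseModel):
    pass  # no explicit name: derived from the class name


class LGBModel(BaseModel):
    _ag_name = "LightGBM"  # explicit name overrides the derived one


derived = CatBoostModel().ag_name
explicit = LGBModel().ag_name
print(derived, explicit)
```

One trade-off of a plain property is that it only resolves on instances; accessing `CatBoostModel.ag_name` on the class itself returns the property object, so a registry that reads names from classes would need a classmethod or similar instead.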

Contributor Author

Note that AG Tabular doesn't have the same orderly name structure of the classes as TimeSeries (but I probably should do it soon...):

For example:

RFModel -> RandomForest
XTModel -> ExtraTrees
LGBModel -> LightGBM
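
Applying the suggested regex to these class names shows the mismatch directly. The name pairs are taken from the register table in the PR description; nothing else is assumed.

```python
import re

# Class name -> current ag_name, as listed in the PR description's table.
current_names = {
    "RFModel": "RandomForest",
    "XTModel": "ExtraTrees",
    "LGBModel": "LightGBM",
    "CatBoostModel": "CatBoost",
}

mismatches = []
for cls_name, ag_name in current_names.items():
    derived = re.sub(r"Model$", "", cls_name)
    if derived != ag_name:
        mismatches.append(cls_name)
    print(f"{cls_name}: derived={derived!r}, current={ag_name!r}")

print("Would change name:", mismatches)
```

Only classes whose names already follow the `<Name>Model` pattern, such as CatBoostModel, would keep their current name under automatic derivation.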

@canerturkmen
Contributor

> @shchur This logic might be applicable to TimeSeries as well. Curious on your thoughts.

I was about to start doing this.

@canerturkmen
Contributor

Another common way to do this in Python. cf. here.

i.e., you don't even need a decorator. anyone who inherits from your Model class will be registered. You can still have a decorator if this is important though.
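
One standard Python mechanism for this kind of register-on-inherit behavior is `__init_subclass__`. The sketch below assumes that mechanism, which may or may not be the one the link pointed to; all class names are hypothetical.

```python
class AutoRegisteredModel:
    """Hypothetical base class: subclasses that set ag_key are recorded on definition."""

    registry: dict = {}  # intentionally shared: this is the global registry
    ag_key = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Only register classes that define their own key (skips intermediates).
        key = cls.__dict__.get("ag_key")
        if key is not None:
            AutoRegisteredModel.registry[key] = cls


class MyCustomModel(AutoRegisteredModel):
    ag_key = "MYMODEL"


class IntermediateBase(AutoRegisteredModel):
    pass  # no ag_key of its own, so it is not registered


registered_keys = sorted(AutoRegisteredModel.registry)
print(registered_keys)
```

Registration happens as a side effect of the class statement running, so no explicit `add` call or decorator is needed.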

@canerturkmen
Contributor

Another design consideration is if "priority" is an intrinsic property of the model itself, or is it something the Trainer assigns to the model?

@Innixma
Contributor Author

Innixma commented Feb 24, 2025

> Another design consideration is if "priority" is an intrinsic property of the model itself, or is it something the Trainer assigns to the model?

@canerturkmen This is a good question. For now I think it can be a value that is overwritten by the trainer if needed, but the reason I have it as part of the model is so that a model contribution would only require the contributor to make edits to the model file and nowhere else. Also, this means someone can bring their custom model to AutoGluon and easily specify the priority relative to other models when fitting.

Luckily this logic is rather easy to change later and deprecate if needed.

@Innixma
Contributor Author

Innixma commented Feb 24, 2025

> Another common way to do this in Python. cf. here.
>
> i.e., you don't even need a decorator. anyone who inherits from your Model class will be registered. You can still have a decorator if this is important though.

@canerturkmen Do you have unit tests / example code showing what you can do with this?

One thing I might be wary of with automatic register on import is that I want to ensure no duplicate keys, and also it becomes a bit weird in some ways because whether the user's code works or not will now depend on what has been imported elsewhere, something the IDE wouldn't track and could make script sharing confusing. Just a hypothetical though, I don't know how real this concern is.

Job PR-4913-dc11e1a is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-4913/dc11e1a/index.html

Job PR-4913-062c3e4 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-4913/062c3e4/index.html

@canerturkmen
Contributor

>> Another common way to do this in Python. cf. here.
>> i.e., you don't even need a decorator. anyone who inherits from your Model class will be registered. You can still have a decorator if this is important though.
>
> @canerturkmen Do you have unit tests / example code showing what you can do with this?
>
> One thing I might be wary of with automatic register on import is that I want to ensure no duplicate keys, and also it becomes a bit weird in some ways because whether the user's code works or not will now depend on what has been imported elsewhere, something the IDE wouldn't track and could make script sharing confusing. Just a hypothetical though, I don't know how real this concern is.

This is the caller.

You could de-duplicate keys in the __new__ method, and raise an informative exception if there is a clash. The IDE will know where to find the exception. Other than that, I think it's a usability issue, I guess an argument could be made it's "too magic" but I would beg to differ. It's just one extra thing to remember and one extra place to touch when adding a new model.

I'm also pretty surprised this is automerging. I touched abstract_model quite a bit and it might be worth a rebase.

Other random consideration: I would pull it up in the namespace if you would like it to be used often and it's a part of the public API: autogluon.tabular.model_register, etc.
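
As an illustration of the de-duplication idea, here is a sketch of a metaclass whose `__new__` registers each class by key and raises on a clash. `RegisteringMeta` and the model classes are hypothetical; this is one possible shape, not the code behind the link above.

```python
class RegisteringMeta(type):
    """Hypothetical metaclass: registers classes by ag_key and rejects duplicates."""

    registry = {}

    def __new__(mcls, name, bases, namespace, **kwargs):
        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
        key = namespace.get("ag_key")
        if key is not None:
            if key in mcls.registry:
                existing = mcls.registry[key]
                raise ValueError(
                    f"ag_key {key!r} for {name} is already used by {existing.__name__}"
                )
            mcls.registry[key] = cls
        return cls


class Model(metaclass=RegisteringMeta):
    ag_key = None  # base class itself is not registered


class FirstModel(Model):
    ag_key = "DUP"


try:
    # Defining a second class with the same key fails at class-definition time.
    class SecondModel(Model):
        ag_key = "DUP"

    clash_message = None
except ValueError as err:
    clash_message = str(err)

print(clash_message)
```

Because the exception is raised while the offending class statement executes, the traceback points at the file that introduced the duplicate key.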

@Innixma
Contributor Author

Innixma commented Feb 25, 2025

@canerturkmen re namespace: The example I had was outdated, here is the current path:

from autogluon.tabular.register import ag_model_register

I could maybe move it to

from autogluon.tabular import ag_model_register

but will wait until I figure out how exactly I want users to interact with the register in the follow-up PR.

@Innixma
Contributor Author

Innixma commented Feb 25, 2025

> This is the caller.
>
> You could de-duplicate keys in the __new__ method, and raise an informative exception if there is a clash. The IDE will know where to find the exception. Other than that, I think it's a usability issue, I guess an argument could be made it's "too magic" but I would beg to differ. It's just one extra thing to remember and one extra place to touch when adding a new model.
>
> I'm also pretty surprised this is automerging. I touched abstract_model quite a bit and it might be worth a rebase.
>
> Other random consideration: I would pull it up in the namespace if you would like it to be used often and it's a part of the public API: autogluon.tabular.model_register, etc.

This is a solid option, the only concern really is the "magic" aspect and oddities related to seemingly unused imports actually impacting whether a script works or not. I'll keep as is for now but we can always change it to be auto-add in a follow-up PR if we decide it is ultimately more user friendly.

The following situation is what I'm not a huge fan of:

# Fails due to not knowing `CUSTOM_MODEL_KEY`
from autogluon.tabular import TabularPredictor

predictor = TabularPredictor(label).fit(..., hyperparameters={"CUSTOM_MODEL_KEY": {...}})

# Succeeds because of the "unused" AGCustomModel import, which adds AGCustomModel to the registry
from autogluon.tabular import TabularPredictor

from ag_custom_model import AGCustomModel

predictor = TabularPredictor(label).fit(..., hyperparameters={"CUSTOM_MODEL_KEY": {...}})

Contributor

@canerturkmen canerturkmen left a comment

LGTM!

@Innixma Innixma merged commit 6448dcc into autogluon:master Feb 25, 2025
14 checks passed
Job PR-4913-22076b6 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-4913/22076b6/index.html

@Innixma Innixma deleted the ag_model_register branch April 16, 2025 21:24