fix: Correct plotting of trees with more than 1 output #2668

Closed

@tincher commented May 21, 2024

Before submitting a pull request, please do the following steps:

I hereby agree to the terms of the CLA available at: https://yandex.ru/legal/cla/?lang=en.

  1. Read instructions for contributors.
  2. Make sure the code builds.
  3. If you add new functionality add tests to check it.
  4. Run existing tests to make sure you haven't broken anything.
  5. If you haven't already, sign the Contributor License Agreement.

#------

If a model is trained with multiple outputs, `plot_tree` doesn't work for all trees because the indices are too big.
I tested this with several different outputs and with different numbers of outputs.
I attached a minimal example which crashes:

import numpy as np
from catboost import CatBoostClassifier, Pool
from catboost.datasets import titanic

titanic_df = titanic()

# Copy the label columns so the assignment below does not write into a slice of the original frame.
y = titanic_df[0][["Survived", "Sex"]].copy()
y["Sex"] = y["Sex"].map({"male": 1, "female": 0})
X = titanic_df[0].drop(["Survived", "Sex"], axis=1)

is_cat = X.dtypes != float
for feature, feat_is_cat in is_cat.to_dict().items():
    if feat_is_cat:
        # Assign back instead of a chained inplace fillna, which may not modify X in recent pandas.
        X[feature] = X[feature].fillna("NAN")

cat_features_index = np.where(is_cat)[0]
pool = Pool(X, y, cat_features=cat_features_index, feature_names=list(X.columns))

parameters = {
    "iterations": 10,
    "depth": 3,
    "grow_policy": "Depthwise",
    "loss_function": "MultiLogloss",
    "random_seed": 0,
}
model = CatBoostClassifier(**parameters)
model.fit(pool)
model.plot_tree(0)
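
Since the failure does not necessarily hit every tree, it may help to check which tree indices are affected. The helper below is only a sketch: `find_unplottable_trees` is a hypothetical name, not part of CatBoost, and it assumes the fitted model exposes `tree_count_` and `plot_tree(tree_idx)`. The exact exception type is not stated in this report, so the sketch catches any exception.

```python
def find_unplottable_trees(model):
    """Return {tree_idx: error description} for every tree plot_tree fails on.

    Hypothetical helper: assumes `model.tree_count_` and
    `model.plot_tree(tree_idx)` exist on the fitted model.
    """
    failures = {}
    for tree_idx in range(model.tree_count_):
        try:
            model.plot_tree(tree_idx)
        except Exception as exc:
            # Broad catch on purpose: the exception type is an unknown here.
            failures[tree_idx] = f"{type(exc).__name__}: {exc}"
    return failures
```

With the model from the example above, `find_unplottable_trees(model)` would return an empty dict if every tree can be plotted.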

robot-piglet pushed a commit that referenced this pull request Aug 20, 2024
…ultidimensional approx with non-oblivious trees: #2668).

e47255eea952cef26d1cce2b8a960ad0bf3af6f8
@andrey-khropov
Member

Thank you for the bug report. I've fixed the bug in the more general case (includes MultiRegression as well) in b45e7a5. The fix will be included in the next release.

@tincher deleted the fix-tree_plotting branch August 31, 2024 05:50