
[BUG]: return_train_score not passed correctly #3839

@Nelsaur

Description


pycaret version checks

Issue Description

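The `return_train_score` argument of `tune_model` is not passed through to the hyperparameter search object in `supervised_experiment.py`. As a result, the per-split training-score keys (`split0_train_score`, `split1_train_score`, ...) are never added to the search object's `cv_results_` dictionary (`model_grid.cv_results_`).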
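For context, PyCaret's default tuning path (`search_library="scikit-learn"`) relies on scikit-learn's search classes. Those classes only record per-split training scores when `return_train_score=True` is given to the search constructor, and scikit-learn's own default is `False`. The sketch below uses plain scikit-learn (no PyCaret). It assumes `RandomizedSearchCV` on a small synthetic dataset as a stand-in for the search object PyCaret builds, and shows that `split0_train_score` appears only when the flag reaches the constructor:

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import RandomizedSearchCV

# Small synthetic regression problem so the example runs quickly
X, y = make_regression(n_samples=200, n_features=5, noise=0.1, random_state=123)

param_distributions = {"n_estimators": [10, 20, 30], "max_depth": [2, 3]}


def run_search(return_train_score):
    # Build and fit a randomized search; the flag goes straight to the constructor
    search = RandomizedSearchCV(
        GradientBoostingRegressor(random_state=123),
        param_distributions=param_distributions,
        n_iter=3,
        cv=3,
        random_state=123,
        return_train_score=return_train_score,
    )
    search.fit(X, y)
    return search


without_flag = run_search(return_train_score=False)
with_flag = run_search(return_train_score=True)

print("split0_train_score" in without_flag.cv_results_)  # False
print("split0_train_score" in with_flag.cv_results_)     # True
```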

Reproducible Example

```python
from pycaret.datasets import get_data
from pycaret.regression import setup, create_model, tune_model

data = get_data('insurance')
s = setup(data, target='charges', session_id=123)

# Baseline Gradient Boosting Regressor, keeping training-fold scores
new_model = create_model('gbr', return_train_score=True)

# Tune it and also return the underlying search object (second return value)
tuned_model, all_params = tune_model(
    estimator=new_model,
    return_train_score=True,
    verbose=True,
    tuner_verbose=True,
    return_tuner=True,
)

# Fails: the per-split training scores are missing from the search results
assert "split0_train_score" in all_params.cv_results_
```

Expected Behavior

`split0_train_score` should exist as a key in `all_params.cv_results_`.

Actual Results

```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
c:\Users\user\source\pycaret_fix\test\Tutorial - Regression.ipynb Cell 2 line 8
      5 new_model = create_model('gbr', return_train_score=True)
      6 tuned_model, all_params = tune_model(estimator=new_model, return_train_score=True, verbose=True, tuner_verbose=True, return_tuner=True)
----> 8 assert "split0_train_score" in all_params.cv_results_

AssertionError:
```
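A fix would need the value of `return_train_score` to be forwarded into the search object's constructor. PyCaret's actual code in `supervised_experiment.py` is not reproduced here. As a hedged illustration only, `build_search` below is a hypothetical helper (not part of PyCaret's API) that shows the forwarding pattern with scikit-learn's `RandomizedSearchCV`; once the flag is forwarded, the expected keys are present:

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import RandomizedSearchCV


def build_search(estimator, param_distributions, n_iter=10, cv=5,
                 return_train_score=False, **search_kwargs):
    """Hypothetical helper (not PyCaret's API): builds a randomized search
    and forwards return_train_score instead of silently dropping it."""
    return RandomizedSearchCV(
        estimator,
        param_distributions=param_distributions,
        n_iter=n_iter,
        cv=cv,
        return_train_score=return_train_score,  # the argument reported as lost
        **search_kwargs,
    )


X, y = make_regression(n_samples=200, n_features=5, noise=0.1, random_state=123)
tuner = build_search(
    GradientBoostingRegressor(random_state=123),
    {"n_estimators": [10, 20, 30], "max_depth": [2, 3]},
    n_iter=3,
    cv=3,
    return_train_score=True,
    random_state=123,
)
tuner.fit(X, y)

# With the flag forwarded, every CV split reports a training score
train_keys = [k for k in tuner.cv_results_
              if k.startswith("split") and k.endswith("_train_score")]
print(sorted(train_keys))  # ['split0_train_score', 'split1_train_score', 'split2_train_score']
```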

Installed Versions

System:
python: 3.10.13 | packaged by Anaconda, Inc. | (main, Sep 11 2023, 13:24:38) [MSC v.1916 64 bit (AMD64)]
executable: c:\Users\user\anaconda3\envs\pycaret_fix\python.exe
machine: Windows-10-10.0.19045-SP0

PyCaret required dependencies:
pip: 23.3.1
setuptools: 68.0.0
pycaret: 3.2.0
IPython: 8.18.1
ipywidgets: 8.1.1
tqdm: 4.66.1
numpy: 1.25.2
pandas: 1.5.3
jinja2: 3.1.2
scipy: 1.10.1
joblib: 1.3.2
sklearn: 1.2.2
pyod: 1.1.2
imblearn: 0.11.0
category_encoders: 2.6.3
lightgbm: 4.1.0
numba: 0.58.1
requests: 2.31.0
matplotlib: 3.6.0
scikitplot: 0.3.7
yellowbrick: 1.5
plotly: 5.18.0
plotly-resampler: Not installed
kaleido: 0.2.1
schemdraw: 0.15
statsmodels: 0.14.0
sktime: 0.21.1
tbats: 1.1.3
pmdarima: 2.0.4
psutil: 5.9.6
markupsafe: 2.1.3
pickle5: Not installed
cloudpickle: 3.0.0
deprecation: 2.1.0
xxhash: 3.4.1
wurlitzer: Not installed

PyCaret optional dependencies:
shap: Not installed
interpret: Not installed
umap: Not installed
ydata_profiling: Not installed
explainerdashboard: Not installed
autoviz: Not installed
fairlearn: Not installed
deepchecks: Not installed
xgboost: Not installed
catboost: Not installed
kmodes: Not installed
mlxtend: Not installed
statsforecast: Not installed
tune_sklearn: Not installed
ray: Not installed
hyperopt: Not installed
optuna: Not installed
skopt: Not installed
mlflow: Not installed
gradio: Not installed
fastapi: Not installed
uvicorn: Not installed
m2cgen: Not installed
evidently: Not installed
fugue: Not installed
streamlit: Not installed
prophet: Not installed

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests