
@Nexisato (Collaborator) commented Jul 27, 2025

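  • A simple implementation of ML-related metric charts built on scikit-learn and pyecharts.
  • The PR and ROC curves contain many points, so the point symbols are hidden. AUC is computed but not shown in the title, mainly because it looked too ugly.
  • Note: sklearn's confusion_matrix puts the origin (index 0 on both the X and Y axes) at the top-left corner, whereas pyecharts' heatmap puts it at the bottom-left corner. The row order can be forcibly reversed, but then the class_names on the horizontal axis may no longer line up (though this is probably not a big problem?).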
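One way to handle the orientation mismatch in the last bullet is to flip only the row index and reverse only the y-axis labels, leaving the x-axis labels in class order. The sketch below assumes the matrix follows sklearn's layout (row i = true class i, column j = predicted class j) and that the heatmap's category y-axis starts at the bottom; `cm_to_heatmap_data` is a hypothetical helper, not the code merged in this PR.

```python
from typing import List, Sequence, Tuple


def cm_to_heatmap_data(
    cm: Sequence[Sequence[int]], class_names: Sequence[str]
) -> Tuple[List[str], List[str], List[List[int]]]:
    """Map an sklearn-style confusion matrix onto a bottom-origin heatmap.

    Assumed layout: row i = true class i, column j = predicted class j,
    with row 0 meant to appear at the top. Because the category y-axis
    starts at the bottom, row i is placed at y index n - 1 - i and the
    y labels are reversed to match.
    """
    n = len(class_names)
    x_labels = list(class_names)            # predicted classes, left to right
    y_labels = list(reversed(class_names))  # true classes, bottom to top
    data = []
    for i, row in enumerate(cm):
        for j, value in enumerate(row):
            # [x index, y index, count]; flip the row so row 0 ends up on top
            data.append([j, n - 1 - i, int(value)])
    return x_labels, y_labels, data
```

With this mapping, `y_labels[n - 1 - i]` is always `class_names[i]`, so both axes stay aligned with the class names. The three return values could then be passed to something like pyecharts' `HeatMap().add_xaxis(x_labels).add_yaxis(series_name, y_labels, data)`; that wiring is an assumption. An alternative worth trying is ECharts' `yAxis.inverse` option, which would keep the rows in sklearn order without remapping.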

Reference test code:

from sklearn.datasets import make_classification
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import xgboost as xgb
import swanlab


def get_roc_pr_curve():
    X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=10, random_state=42)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    # use_label_encoder is deprecated in recent XGBoost releases, so it is omitted
    model = xgb.XGBClassifier(eval_metric='logloss')
    model.fit(X_train, y_train)

    y_pred_proba = model.predict_proba(X_test)[:, 1]
    return y_test, y_pred_proba


def get_cm_mock():
    iris_data = load_iris()
    X = iris_data.data
    y = iris_data.target
    class_names = iris_data.target_names.tolist()

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    dtrain = xgb.DMatrix(X_train, label=y_train)
    dtest = xgb.DMatrix(X_test, label=y_test)

    param = {
        'objective': 'multi:softmax',
        'num_class': len(class_names),
        'eta': 0.1,
        'max_depth': 3,
        'eval_metric': 'mlogloss',
    }
    num_round = 100

    bst = xgb.train(param, dtrain, num_round)
    # multi:softmax returns predicted class indices as floats; cast to int labels
    preds = bst.predict(dtest).astype(int)
    return y_test, preds, class_names


if __name__ == '__main__':

    y_test, y_pred_proba = get_roc_pr_curve()
    cm_test, cm_pred, class_names = get_cm_mock()
    print(cm_test, cm_pred, class_names)

    swanlab.login(api_key="<API_KEY>", host="<DEV_HOST>")
    swanlab.init(project="Echarts-Metrics-Demo", experiment_name="PR&ROC Curve")

    swanlab.log(
        {
            "pr_curve": swanlab.pr_curve(y_test, y_pred_proba),
            "roc_curve": swanlab.roc_curve(y_test, y_pred_proba),
        }
    )

    swanlab.log({"confusion_matrix": swanlab.confusion_matrix(cm_test, cm_pred, class_names)})

    swanlab.finish()

Reference link:
https://dev001.swanlab.cn/@nexisato/Echarts-Metrics-Demo/runs/5thtkcqquhn5ttcjkedhq/chart

Issue: #1160

@Nexisato Nexisato requested a review from Zeyi-Lin July 27, 2025 13:30
@Nexisato Nexisato self-assigned this Jul 27, 2025
@Nexisato Nexisato added the 💪 enhancement New feature or request label Jul 27, 2025
@Zeyi-Lin (Member)

I suggest not adding scikit-learn to the requirements. Instead, check whether scikit-learn can be imported when APIs such as pr_curve are called.
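A minimal sketch of the lazy import check suggested above, assuming it lives inside the metric functions themselves. `require_optional` and `pr_curve_points` are hypothetical names, not SwanLab's actual API. The idea is that scikit-learn is imported only when a metric function is called, so a missing install fails with an install hint there rather than being a hard dependency of the whole package.

```python
import importlib
from types import ModuleType
from typing import Optional


def require_optional(module: str, pip_name: Optional[str] = None) -> ModuleType:
    """Import an optional dependency on demand, with an install hint on failure."""
    try:
        return importlib.import_module(module)
    except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
        hint = pip_name or module.split(".")[0]
        raise ImportError(
            f"'{module}' is required for this feature; "
            f"install it with `pip install {hint}`."
        ) from exc


def pr_curve_points(y_true, y_score):
    """Hypothetical pr_curve backend: scikit-learn is imported only when called."""
    metrics = require_optional("sklearn.metrics", "scikit-learn")
    precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
    auc_value = metrics.auc(recall, precision)
    return recall.tolist(), precision.tolist(), float(auc_value)
```

Using `import_module` inside `try` (rather than `importlib.util.find_spec`) also covers dotted names like `sklearn.metrics`, where `find_spec` would itself raise if the parent package is missing.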

@Zeyi-Lin (Member)

Giving the PR and ROC curves a title is a good idea. It could be controlled by a title parameter, for example defaulting to title=True.
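A sketch of how a `title` flag might reuse the AUC that the PR description says is already computed. `trapezoid_auc` and `curve_title` are hypothetical helpers, and the title format is an assumption. The AUC is plain trapezoidal integration over the curve points (the same rule `sklearn.metrics.auc` documents), written without dependencies and assuming the x values are monotonic, as they are for `roc_curve` and `precision_recall_curve` output.

```python
from typing import Optional, Sequence


def trapezoid_auc(x: Sequence[float], y: Sequence[float]) -> float:
    """Area under the polyline through (x, y) using the trapezoidal rule.

    Assumes x is monotonic. Decreasing x (e.g. recall as returned by
    precision_recall_curve) is reversed first so the area comes out positive.
    """
    xs, ys = list(x), list(y)
    if len(xs) >= 2 and xs[0] > xs[-1]:
        xs.reverse()
        ys.reverse()
    area = 0.0
    for i in range(1, len(xs)):
        area += (xs[i] - xs[i - 1]) * (ys[i] + ys[i - 1]) / 2.0
    return area


def curve_title(name: str, auc_value: float, title: bool = True) -> Optional[str]:
    """Hypothetical `title` behaviour: on by default and shows the AUC; off gives None."""
    if not title:
        return None
    return f"{name} (AUC = {auc_value:.4f})"
```

For PR curves, sklearn's `average_precision_score` is a common alternative summary to the trapezoidal area.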

@Zeyi-Lin Zeyi-Lin merged commit a6cbadf into main Jul 27, 2025
5 checks passed
@Zeyi-Lin Zeyi-Lin deleted the feat/ml-metric-chart branch August 11, 2025 04:33