gpt4 book ai didi

python - 为什么 sklearn.metrics.confusion_matrix 和 sklearn.metrics.plot_confusion_matrix 的函数定义不一致?

转载 作者:行者123 更新时间:2023-12-04 10:12:16 28 4
gpt4 key购买 nike

我正在使用 sklearn,我注意到 sklearn.metrics.plot_confusion_matrix 的参数和 sklearn.metrics.confusion_matrix不一致。 plot_confusion_matrix用途 estimatorX构建 y_pred , 而 confusion_matrixy_pred直接作为论据。

这种不一致的原因可能是什么?

部分函数定义:

  • sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, ...) [其中 X 应该是 X_test]
  • sklearn.metrics.confusion_matrix(y_true, y_pred, ...)

  • 资料来源:
  • plot_confusion_matrix
  • confusion_matrix
  • 最佳答案

    是的,您是对的,没有一致的 API 设计,但对此问题正在进行讨论 here .

    一种快速解决方法是 ConfusionMatrixDisplay .

    例子:

    from sklearn.datasets import make_classification
    from sklearn.preprocessing import StandardScaler
    from sklearn.pipeline import make_pipeline
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    X, y = make_classification(random_state=1)
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)

    clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0))
    clf.fit(X_train, y_train)

    from sklearn.metrics import confusion_matrix
    from sklearn.metrics import ConfusionMatrixDisplay

    y_pred = clf.predict(X_test)
    cm = confusion_matrix(y_test, y_pred)

    cm_display = ConfusionMatrixDisplay(cm, [0,1]).plot()

    关于python - 为什么 sklearn.metrics.confusion_matrix 和 sklearn.metrics.plot_confusion_matrix 的函数定义不一致?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61268535/

    28 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com