gpt4 book ai didi

python - plot_confusion_matrix() 使用 sklearn 得到了一个意外的关键字参数 'classes'

转载 作者:行者123 更新时间:2023-12-04 08:07:56 25 4
gpt4 key购买 nike

我是 python 和深度学习的新手,我训练了一个多分类器模型并想绘制一个混淆矩阵,但我遇到了一个错误这是我的代码

from sklearn.metrics import plot_confusion_matrix
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
Y_pred = model.predict_generator(test_generator)
y_pred = np.argmax(Y_pred, axis=1)
category_names = sorted(os.listdir('D:/DiabaticRetinopathy/mq_dataset/DR_Normal/train'))
print(category_names)
cm = confusion_matrix(test_generator.classes, y_pred)
plot_confusion_matrix(cm, classes = category_names, title='Confusion Matrix', normalize=False, figname = 'Confusion_matrix_concrete.jpg')

我将我的 sklearn 更新到了 0.24 版本。更新后我重新启动了内核,但仍然出现错误:

TypeError: plot_confusion_matrix() got an unexpected keyword argument 'classes'

最佳答案

使用 labels 而不是类,然后 Remove title, figname

plot_confusion_matrix(X = test_generator.classes, y_true = y_pred,labels= category_names, normalize=False)

文档:https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html

关于python - plot_confusion_matrix() 使用 sklearn 得到了一个意外的关键字参数 'classes',我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66132654/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com