gpt4 book ai didi

python - 使用python制作ROC曲线进行多分类

转载 作者:行者123 更新时间:2023-11-28 16:36:43 24 4
gpt4 key购买 nike

从这里跟进:Converting a 1D array into a 2D class-based matrix in python

我想为我的 46 个类(class)中的每一个类(class)绘制 ROC 曲线。我有 300 个测试样本,我已经运行我的分类器来进行预测。

y_test 是真正的类,y_pred 是我的分类器预测的。

这是我的代码:

    from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
import numpy as np

y_test_bi = label_binarize(y_test, classes=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18, 19,20,21,2,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,3,40,41,42,43,44,45])
y_pred_bi = label_binarize(y_pred, classes=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18, 19,20,21,2,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,3,40,41,42,43,44,45])
# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(2):
fpr[i], tpr[i], _ = roc_curve(y_test_bi, y_pred_bi)
roc_auc[i] = auc(fpr[i], tpr[i])

但是,现在我收到以下错误:

Traceback (most recent call last):
File "C:\Users\app\Documents\Python Scripts\gbc_classifier_test.py", line 152, in <module>
fpr[i], tpr[i], _ = roc_curve(y_test_bi, y_pred_bi)
File "C:\Users\app\Anaconda\lib\site-packages\sklearn\metrics\metrics.py", line 672, in roc_curve
fps, tps, thresholds = _binary_clf_curve(y_true, y_score, pos_label)
File "C:\Users\app\Anaconda\lib\site-packages\sklearn\metrics\metrics.py", line 505, in _binary_clf_curve
y_true = column_or_1d(y_true)
File "C:\Users\app\Anaconda\lib\site-packages\sklearn\utils\validation.py", line 265, in column_or_1d
raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (300L, 46L)

最佳答案

roc_curve 采用形状为 [n_samples] ( link ) 的参数,以及您的输入(y_test_biy_pred_bi) 的形状是 (300, 46)。注意第一个

我认为问题在于 y_pred_bi 是一个概率数组,通过调用 clf.predict_proba(X) 创建(请确认)。由于您的分类器针对所有 46 个类别进行了训练,它会为每个数据点输出一个 46 维向量,label_binarize 对此无能为力。

我知道有两种解决方法:

  1. 通过在 clf.fit() 之前调用 label_binarize 训练 46 个二元 分类器,然后计算 ROC 曲线
  2. 对 300×46 输出数组的每一列进行切片,并将其作为第二个参数传递给 roc_curve。这是我的首选方法,因为我假设 y_pred_bi 包含概率

关于python - 使用python制作ROC曲线进行多分类,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/25133718/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com