gpt4 book ai didi

keras - 这是 sklearn 分类报告对多标签分类报告的正确使用吗?

转载 作者:行者123 更新时间:2023-12-04 13:26:03 40 4
gpt4 key购买 nike

我正在用 tf-keras 训练一个神经网络。这是一个多标签分类,其中每个样本属于多个类别 [1,0,1,0..etc] .. 最终模型线(只是为了清楚起见)是:

model.add(tf.keras.layers.Dense(9, activation='sigmoid'))#final layer

model.compile(loss='binary_crossentropy', optimizer=optimizer,
metrics=[tf.keras.metrics.BinaryAccuracy(),
tfa.metrics.F1Score(num_classes=9, average='macro',threshold=0.5)])
我需要为这些生成精确度、召回率和 F1 分数(我已经在训练期间获得了 F1 分数)。为此,我使用 sklearns 分类报告,但我需要确认我在多标签设置中正确使用它。
from sklearn.metrics import classification_report

pred = model.predict(x_test)
pred_one_hot = np.around(pred)#this generates a one hot representation of predictions

print(classification_report(one_hot_ground_truth, pred_one_hot))
这很好用,我得到了每个类(class)的完整报告,包括与来自 tensorflow 插件的 F1score 指标相匹配的 F1 分数(对于宏 F1)。抱歉,这篇文章很冗长,但我不确定的是:
在多标签设置的情况下,预测需要进行单热编码是否正确?如果我传入正常的预测分数(sigmoid 概率),则会抛出错误...
谢谢你。

最佳答案

正确使用classification_report用于二元、多类和多标签分类。
在多类分类的情况下,标签不是单热编码的。他们只需要是 indiceslabels .
您可以看到下面的两个代码产生相同的输出:
索引示例

from sklearn.metrics import classification_report
import numpy as np

labels = np.array(['A', 'B', 'C'])


y_true = np.array([1, 2, 0, 1, 2, 0])
y_pred = np.array([1, 2, 1, 1, 1, 0])
print(classification_report(y_true, y_pred, target_names=labels))
带标签的示例
from sklearn.metrics import classification_report
import numpy as np

labels = np.array(['A', 'B', 'C'])

y_true = labels[np.array([1, 2, 0, 1, 2, 0])]
y_pred = labels[np.array([1, 2, 1, 1, 1, 0])]
print(classification_report(y_true, y_pred))
两者都返回
              precision    recall  f1-score   support

A 1.00 0.50 0.67 2
B 0.50 1.00 0.67 2
C 1.00 0.50 0.67 2

accuracy 0.67 6
macro avg 0.83 0.67 0.67 6
weighted avg 0.83 0.67 0.67 6
在多标签分类的背景下, classification_report可以像下面的例子一样使用:
from sklearn.metrics import classification_report
import numpy as np

labels =['A', 'B', 'C']

y_true = np.array([[1, 0, 1],
[0, 1, 0],
[1, 1, 1]])
y_pred = np.array([[1, 0, 0],
[0, 1, 1],
[1, 1, 1]])

print(classification_report(y_true, y_pred, target_names=labels))

关于keras - 这是 sklearn 分类报告对多标签分类报告的正确使用吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68374999/

40 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com