gpt4 book ai didi

keras - 使用 model.fit_generator 时如何获取混淆矩阵

转载 作者:行者123 更新时间:2023-12-04 02:16:34 24 4
gpt4 key购买 nike

我正在使用 model.fit_generator 来训练我的二进制(两类)模型并获得结果,因为我直接从我的文件夹中提供输入图像。在这种情况下如何获得混淆矩阵(TP、TN、FP、FN),因为我通常使用 confusion_matrix sklearn.metrics的命令得到它,这需要predicted , 和 actual标签。但在这里我没有两者。可能我可以从 predict=model.predict_generator(validation_generator) 计算预测标签命令。但我不知道我的模型是如何从我的图像中获取输入标签的。我的输入文件夹的一般结构是:

train/
class1/
img1.jpg
img2.jpg
........
class2/
IMG1.jpg
IMG2.jpg
test/
class1/
img1.jpg
img2.jpg
........
class2/
IMG1.jpg
IMG2.jpg
........

我的一些代码块是:
train_generator = train_datagen.flow_from_directory('train',  
target_size=(50, 50), batch_size=batch_size,
class_mode='binary',color_mode='grayscale')


validation_generator = test_datagen.flow_from_directory('test',
target_size=(50, 50),batch_size=batch_size,
class_mode='binary',color_mode='grayscale')

model.fit_generator(
train_generator,steps_per_epoch=250 ,epochs=40,
validation_data=validation_generator,
validation_steps=21 )

所以上面的代码自动接受两个类输入,但我不知道它考虑哪个类 0 和哪个类 1。

最佳答案

probabilities = model.predict_generator(generator=test_generator)

会给我们一组概率。
y_true = test_generator.classes

会给我们真正的标签。

因为这是一个二元分类问题,所以你必须找到预测的标签。为此,您可以使用
y_pred = probabilities > 0.5

然后我们在测试数据集上有真实标签和预测标签。因此,混淆矩阵由下式给出
font = {
'family': 'Times New Roman',
'size': 12
}
matplotlib.rc('font', **font)
mat = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(conf_mat=mat, figsize=(8, 8), show_normed=False)

关于keras - 使用 model.fit_generator 时如何获取混淆矩阵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47907061/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com