gpt4 book ai didi

tensorflow 模型。评估和模型。预测非常不同的结果

转载 作者:行者123 更新时间:2023-12-03 14:47:35 27 4
gpt4 key购买 nike

我正在构建一个用于二进制图像分类的简单CNN,并且从model.evaluate()获得的 AUC远远高于从model.predict()+ roc_auc_score()获得的AUC。

整个笔记本为here

为model.fit()编译模型和输出:

model.compile(loss='binary_crossentropy',
optimizer=RMSprop(lr=0.001),
metrics=['AUC'])

history = model.fit(
train_generator,
steps_per_epoch=8,
epochs=5,
verbose=1)

时代1/5
8/8 [=============================]-21s 3s/step-损耗:6.7315-auc:0.5143

时代2/5
8/8 [=============================]-15s 2s/step-损失:0.6626-auc:0.6983

时代3/5
8/8 [=============================]-18s 2s/step-损耗:0.4296-auc:0.8777

时代4/5
8/8 [=============================]-14s 2s/step-损失:0.2330-auc:0.9606

时代5/5
8/8 [=============================]-18s 2s/step-损耗:0.1985-auc:0.9767

然后,model.evaluate()给出类似的内容:
model.evaluate(train_generator)

9/9 [==============================]-10s 1s/step-损耗:0.3056-auc:0.9956

但是直接从model.predict()方法计算出的AUC降低了两倍:
from sklearn import metrics

x = model.predict(train_generator)
metrics.roc_auc_score(train_generator.labels, x)

0.5006148007590132

我已经阅读了有关类似问题的几篇文章(例如 thisthisthisextensive discussion on github),但它们描述了与我的情况无关的原因:
  • 将binary_crossenthropy用于多类任务(不是我的情况)
  • 由于使用批处理vs整体,评估和预测之间的差异
    数据集(不应该引起像我这样的急剧下降)
  • 使用批处理规范化和规范化(不是我的情况,也应该
    不会造成如此大的下降)

  • 任何建议,不胜感激。谢谢!

    编辑!解决方案
    我已经创建了解决方案 here,我只需要调用
    train_generator.reset()

    在model.predict之前,并在flow_from_directory()函数中设置shuffle = False。产生差异的原因是,生成器从不同的位置开始输出批次,因此标签和预测将不匹配,因为它们与不同的对象相关。因此,问题不在于评估或预测方法,而在于生成器。

    编辑2
    如果使用flow_from_directory()创建生成器,则使用train_generator.reset()并不方便,因为它需要在flow_from_directory中设置shuffle = False,但这会在训练过程中创建包含单个类的批处理,这会影响学习。因此,在运行预测之前,我最终重新定义了train_generator。

    最佳答案

    tensorflow.keras AUC通过Riemann和计算近似AUC(曲线下的面积),这与scikit-learn的实现方式不同。

    如果要查找带有tensorflow.keras的AUC,请尝试:

    import tensorflow as tf

    m = tf.keras.metrics.AUC()

    m.update_state(train_generator.labels, x) # assuming both have shape (N,)

    r = m.result().numpy()

    print(r)

    关于 tensorflow 模型。评估和模型。预测非常不同的结果,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61899676/

    27 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com