gpt4 book ai didi

tensorflow - 使用 tf.data.Dataset 评估性能的最佳方式

转载 作者:行者123 更新时间:2023-12-05 04:58:54 24 4
gpt4 key购买 nike

我训练了一个模型,现在想评估它在测试集上的表现。测试集被加载为 tf.data.TFRecordDataset 对象(来自多个 TFRecord,每个 TFRecord 中有多个示例),其中包含元组(图像,标签)形式的约百万个示例,数据是批处理的。然后将原始标签映射到模型需要预测的目标整数(单热编码)。

我知道我可以将 Dataset 对象作为输入传递给 model.predict(),它将输出数据集中每个示例的预测。然而,为了计算一些指标,我需要将真实目标值与预测值进行比较,并获得前者我需要遍历 Dataset,因为所有真实标签都存储在那里。

这似乎是一项常见任务,但我找不到适用于 TFRecord 格式的大型数据集的直接解决方案。例如,在这种情况下,计算每个类的 AUC 的最佳方法是什么?我应该将回调与 model.predict(test_dataset) 一起使用吗?或者我应该在一个循环中一个一个地处理每个示例,将真实值和预测值保存到数组中,然后使用例如 sklearn.metrics.roc_auc_score() 来计算两个数组的 AUC 分数?或者我可能遗漏了一些明显的方法?

提前致谢!

最佳答案

如果您需要所有标签,为什么不只:

model.evaluate(test_dataset.take(-1))

或者如果您的 ds 对于此操作来说太大,只需遍历您的数据集,最后计算您的指标和平均值。

关于tensorflow - 使用 tf.data.Dataset 评估性能的最佳方式,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63789585/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com