gpt4 book ai didi

logging - Tensorflow Estimator 打印损失时使用什么数据集

转载 作者:行者123 更新时间:2023-11-30 08:45:38 24 4
gpt4 key购买 nike

当使用 Tensorflow Estimator 时,它会在 python 控制台上打印(每 100 步)全局步数和损失(此外,它还会打印学习率、交叉熵和 MAE(这是我的评估指标),并打印这些3 个值在不同的行中,我认为这是由于某些包装函数造成的,该函数不是原始 Estimator API 的一部分,因为我正在使用 google 开发人员的 ResNet 实现)。它看起来像这样:

    I0530 19:20:42.748463 10964 tf_logging.py:116] learning_rate = 3.552962e-05, cross_entropy = 2.2080934, MAE = 5.135024 (62.295 sec)   
I0530 19:20:42.749458 10964 tf_logging.py:116] loss = 2.2080934, step = 76066 (62.295 sec)

我的问题是,正在计算什么损失(或者正在计算什么 MAE)?
是否是日志记录发生时特定步骤中仅一个示例的损失?
是记录发生时特定步骤的批处理损失吗?
或者也许是整个列车组的损失?

另外,如果我的假设有问题,请纠正我。我在这个领域还是个新手。
谢谢。

最佳答案

tf.Estimator 自动为损失和全局步骤设置一个 LoggingTensorHook。据推测,您运行的代码为其他值(学习率、交叉熵(这只是损失)和 MAE)设置了一个单独的钩子(Hook),这就是为什么它们被打印在单独的行上。

至于使用哪些数据来生成值:它是“当前”批处理的数据,即完成日志记录的步骤中使用的批处理。因此,在您提出的三个选项中,第二个是正确的。

这可以通过the source code确认,因为钩子(Hook)会“运行后”进行日志记录,并在 run_values 中接收最后一次 session.run() 调用的结果(一次仅获取一批)。

关于logging - Tensorflow Estimator 打印损失时使用什么数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50609555/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com