gpt4 book ai didi

tensorflow - 使用可用的训练 Hook 在 tf.estimator.DNNRegressor 中实现提前停止

转载 作者:行者123 更新时间:2023-12-02 23:04:42 28 4
gpt4 key购买 nike

我是 tensorflow 新手,希望通过可用的训练 Hook 在 tf.estimator.DNNRegressor 中实现提前停止 Training Hooks对于 MNIST 数据集。如果在指定的步数内损失没有改善,早期停止钩子(Hook)将停止训练。 Tensorflow 文档仅提供 Logging hooks 的示例。有人可以编写一个代码片段来实现它吗?

最佳答案

这是一个 EarlyStoppingHook 示例实现:

import numpy as np
import tensorflow as tf
import logging
from tensorflow.python.training import session_run_hook


class EarlyStoppingHook(session_run_hook.SessionRunHook):
"""Hook that requests stop at a specified step."""

def __init__(self, monitor='val_loss', min_delta=0, patience=0,
mode='auto'):
"""
"""
self.monitor = monitor
self.patience = patience
self.min_delta = min_delta
self.wait = 0
if mode not in ['auto', 'min', 'max']:
logging.warning('EarlyStopping mode %s is unknown, '
'fallback to auto mode.', mode, RuntimeWarning)
mode = 'auto'

if mode == 'min':
self.monitor_op = np.less
elif mode == 'max':
self.monitor_op = np.greater
else:
if 'acc' in self.monitor:
self.monitor_op = np.greater
else:
self.monitor_op = np.less

if self.monitor_op == np.greater:
self.min_delta *= 1
else:
self.min_delta *= -1

self.best = np.Inf if self.monitor_op == np.less else -np.Inf

def begin(self):
# Convert names to tensors if given
graph = tf.get_default_graph()
self.monitor = graph.as_graph_element(self.monitor)
if isinstance(self.monitor, tf.Operation):
self.monitor = self.monitor.outputs[0]

def before_run(self, run_context): # pylint: disable=unused-argument
return session_run_hook.SessionRunArgs(self.monitor)

def after_run(self, run_context, run_values):
current = run_values.results

if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
else:
self.wait += 1
if self.wait >= self.patience:
run_context.request_stop()

此实现基于 Keras implementation .

与 CNN MNIST 一起使用 example创建钩子(Hook)并将其传递给train

early_stopping_hook = EarlyStoppingHook(monitor='sparse_softmax_cross_entropy_loss/value', patience=10)

mnist_classifier.train(
input_fn=train_input_fn,
steps=20000,
hooks=[logging_hook, early_stopping_hook])

这里 sparse_softmax_cross_entropy_loss/value 是该示例中损失操作的名称。

编辑1:

使用估计器时,似乎没有“官方”方法来查找损失节点(或者我找不到它)。

对于 DNNRegressor,此节点的名称为 dnn/head/weighted_loss/Sum

以下是如何在图表中找到它:

  1. 在模型目录中启动tensorboard。就我而言,我没有设置任何目录,因此估算器使用临时目录并打印此行:
    警告:tensorflow:使用临时文件夹作为模型目录:/tmp/tmpInj8SC
    启动张量板:

    tensorboard --logdir /tmp/tmpInj8SC
  2. 在浏览器中打开它并导航到“GRAPHS”选项卡。 enter image description here

  3. 在图中找出损失。按顺序展开 block :dnnheadweighted_loss 并单击 Sum 节点(请注意,有连接到它的名为 loss 的摘要节点)。 enter image description here

  4. 右侧信息“窗口”中显示的名称是所选节点的名称,需要将其传递给 monitor 参数 pf EarlyStoppingHook

默认情况下,DNNClassifier 的损失节点具有相同的名称。 DNNClassifierDNNRegressor 都有可选参数 loss_reduction 影响损失节点名称和行为(默认为 losses.Reduction.SUM >).

编辑2:

有一种方法可以不看图表就找到损失。
您可以使用 GraphKeys.LOSSES 集合来获取损失。但这种方式只有在训练开始后才有效。所以你只能在钩子(Hook)中使用它。

例如,您可以从 EarlyStoppingHook 类中删除 monitor 参数,并更改其 begin 函数以始终使用集合中的第一个损失:

self.monitor = tf.get_default_graph().get_collection(tf.GraphKeys.LOSSES)[0]

您可能还需要检查集合中是否存在丢失。

关于tensorflow - 使用可用的训练 Hook 在 tf.estimator.DNNRegressor 中实现提前停止,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48815906/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com