gpt4 book ai didi

tensorflow - 这是在训练期间重用 tf.slim 图进行验证的正确方法吗?

转载 作者:行者123 更新时间:2023-12-04 16:03:51 29 4
gpt4 key购买 nike

我正在尝试在训练期间的每个时期之后进行验证。

我正在创建如下图表:

import tensorflow as tf
from networks import densenet
from networks.densenet_utils import dense_arg_scope

with tf.variable_scope('scope') as scope:
with slim.arg_scope(dense_arg_scope()):
logits_train, _ = densenet(images, blocks=networks[
'densenet_265'], num_classes=1000, data_name='imagenet', is_training=True, scope='densenet265',
reuse=tf.AUTO_REUSE)
scope.reuse_variables()
with slim.arg_scope(dense_arg_scope()):
logits_val, _ = densenet(images, blocks=networks[
'densenet_265'], num_classes=1000, data_name='imagenet', is_training=False, scope='densenet265',
reuse=tf.AUTO_REUSE)

为了在训练或验证期间获得 logits,我执行以下操作:

is_training = tf.Variable(True, trainable=False, dtype=tf.bool)
training_mode = tf.assign(is_training, True)
validation_mode = tf.assign(is_training, False)
logits = tf.cond(tf.equal(is_training, tf.constant(True, dtype=tf.bool)), lambda: logits_train,
lambda: logits_val)

但是,当我运行我的代码时,出现 OOM 错误。我确信这不是因为批量大。这是因为,之前我犯了一个错误,在训练和验证期间使用了相同的图表。当时批处理大小为 32,图像大小为 224x224x3,代码运行良好。

我怀疑我在使用 is_training=False 验证期间尝试重用图表时犯了一些错误。

densenet 的代码取自以下两个文件: densenet_utils.py densenet.py

最佳答案

您在 logits_train 和 logits_val 中创建了两个独立的网络,因此这会占用您的网络原本占用的内存的两倍。 (我假设它设置正确并且变量共享正确,这可能是另一个问题,但这不太可能导致 OOM,大数据是激活,而不是权重。)

没有必要这样做。也使用相同的网络 logits_train 进行验证。事实证明,参数 is_training 也可以采用 bool 标量张量,因此您可以即时切换训练或推理模式。

所以就在您设置images 占位符的地方,将这一行作为下一行:

training_mode = tf.placeholder( shape = None, dtype = tf.bool )

然后在上面的代码中,像这样设置你的网络:

logits_train, _ = densenet(images, blocks=networks['densenet_265'],
num_classes=1000, data_name='imagenet', is_training=training_mode,
scope='densenet265', reuse=tf.AUTO_REUSE)

请注意,is_training 参数的值由上面的张量 training_mode 填充!

然后当您执行 sess.run( [ ... ] ) 命令(在上面的代码中不可见)时,您应该在 training_mode 中包含 code>feed_dict 像这样(伪代码):

result = sess.run( [ ??? ], feed_dict = { images : ???, training_mode : True / False } )

请注意,training_mode 张量现在根据您是否正在进行训练填充为 False 或 True。

这是基于我对batch_normalizationdropout 层的研究。

关于tensorflow - 这是在训练期间重用 tf.slim 图进行验证的正确方法吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49598717/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com