gpt4 book ai didi

python - tensorflow:使用队列运行器有效地提供 eval/train 数据

转载 作者:太空狗 更新时间:2023-10-29 17:49:42 27 4
gpt4 key购买 nike

我正在尝试运行 tensorflow 图来训练模型并使用单独的评估数据集定期进行评估。训练和评估数据都是使用队列运行器实现的。

我当前的解决方案是在同一个图中创建两个输入,并使用依赖于 is_training 占位符的 tf.cond。以下代码突出显示了我的问题:

import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
from time import time


def get_train_inputs(is_training):
return cifar10.inputs(False)


def get_eval_inputs(is_training):
return cifar10.inputs(True)


def get_mixed_inputs(is_training):
train_inputs = get_train_inputs(None)
eval_inputs = get_eval_inputs(None)

return tf.cond(is_training, lambda: train_inputs, lambda: eval_inputs)


def time_inputs(inputs_fn, n_runs=10):
graph = tf.Graph()
with graph.as_default():
is_training = tf.placeholder(dtype=tf.bool, shape=(),
name='is_training')
images, labels = inputs_fn(is_training)

with tf.Session(graph=graph) as sess:
coordinator = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
t = time()
for i in range(n_runs):
im, l = sess.run([images, labels], feed_dict={is_training: True})
dt = time() - t
coordinator.request_stop()
coordinator.join(threads)

return dt / n_runs

print('Train inputs: %.3f' % time_inputs(get_train_inputs))
print('Eval inputs: %.3f' % time_inputs(get_eval_inputs))
print('Mixed inputs: %.3f' % time_inputs(get_mixed_inputs))

我还必须注释掉 tensorflow/models/image/cifar10/cifar10_inputs.pyimage_summary133

这产生了以下结果:

Train inputs: 0.055
Eval inputs: 0.050
Mixed inputs: 0.105

在混合情况下,似乎两个输入都被读取/解析,即使只使用了 1 个。有没有办法避免这种冗余计算?或者是否有更好的方法在仍然利用队列运行器设置的训练/评估数据之间切换?

最佳答案

你读过这个 link 的最后一节了吗?关于多输入?我认为您可以在输入函数中添加一个 is_training 参数来区分训练数据和评估数据。然后,您可以重用共享变量来获取评估数据的逻辑值并为评估构建一个操作。然后在您的图表中,运行 valudation_accuracy=sess.run(eval_op) 以获得 eval 准确性。


更新:

你好,据我了解,如果你想训练 n 个批处理,评估、训练、评估,你可以在同一个图中保留两个操作,不需要构建一个新的。假设您已经构建了所有需要的功能,那么代码应该是这样的:

#the following two steps will add train and eval input queue to the graph
train_inputs,train_labels = inputs(is_train=True)
eval_inputs,eval_labels = inputs(is_train=False)

with tf.variable_scope("inference") as scope:
train_logits = inference(train_inputs)
scope.reuse_variables()
eval_logits = inference(eval_inputs)

loss = loss(train_logits,train_labels)
eval_accuracy = accuracy(eval_logits,eval_labels)

#...add train op here,start queue runner and train it ...

关于python - tensorflow:使用队列运行器有效地提供 eval/train 数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39187764/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com