
machine-learning - Tensorflow: trouble reopening queues after restoring a session

Reposted · Author: 行者123 · Updated: 2023-11-30 08:36:27

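I have a trained model that I'm trying to evaluate on a separate dataset, but my input pipeline is causing problems. After restoring the session, the following error is raised when I try to load the first batch of validation data: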

tensorflow.python.framework.errors.OutOfRangeError: FIFOQueue '_2_input/batch/fifo_queue' is closed and has insufficient elements (requested 1024, current size 0)
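The error says the batching queue had already been closed and held fewer elements than one batch needs (here, none at all). The toy model below is plain Python, not TensorFlow's implementation, and mimics that rule for a `dequeue_many`-style call: an open queue would wait for more items, but a closed queue that cannot supply a full batch fails immediately. `ToyFIFOQueue` and `OutOfRange` are hypothetical names.

```python
from collections import deque


class OutOfRange(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


class ToyFIFOQueue:
    """Toy model of a batching FIFO queue; not TensorFlow's implementation."""

    def __init__(self, name):
        self.name = name
        self._items = deque()
        self._closed = False

    def enqueue(self, item):
        if self._closed:
            raise RuntimeError("cannot enqueue into a closed queue")
        self._items.append(item)

    def close(self):
        self._closed = True

    def dequeue_many(self, n):
        if len(self._items) < n:
            if self._closed:
                # Closed and too few items: fail instead of waiting.
                raise OutOfRange(
                    "FIFOQueue '%s' is closed and has insufficient elements "
                    "(requested %d, current size %d)"
                    % (self.name, n, len(self._items)))
            # A real open queue would block here until more items arrive.
            raise RuntimeError("toy model: would block waiting for more items")
        return [self._items.popleft() for _ in range(n)]


q = ToyFIFOQueue("input/batch/fifo_queue")
q.close()  # closed before anything was enqueued
try:
    q.dequeue_many(1024)
except OutOfRange as e:
    print(e)  # ... (requested 1024, current size 0)
```

In this model, "current size 0" can only mean the queue was closed while still empty, so even the first batch could not be formed.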

My code is modeled on the cifar10_eval.py example (see here).

```python
import os

import tensorflow as tf

# DATA_DIR, TRAIN_FILE, TEST_FILE, NUM_EPOCHS, SUMMARY_DIR, inference() and
# batch_training_error() are defined elsewhere in the script (not shown).


def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, record = reader.read(filename_queue)

    features = tf.parse_single_example(
        record,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })

    label = tf.cast(features['label'], tf.int32)
    image = tf.decode_raw(features['image_raw'], tf.uint8)

    image.set_shape([21*21*1])
    image = tf.cast(tf.reshape(image, (21, 21, 1)), tf.float32)

    return image, label


def inputs(train, batch_size, num_epochs):
    if train:
        filename = os.path.join(DATA_DIR, TRAIN_FILE)
    else:
        filename = os.path.join(DATA_DIR, TEST_FILE)

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [filename], num_epochs=num_epochs, shuffle=train)

        example, label = read_and_decode(filename_queue)

        min_after_dequeue = 1000
        capacity = min_after_dequeue + 3 * batch_size

        if train:
            example_batch, label_batch = tf.train.shuffle_batch(
                [example, label], batch_size=batch_size, capacity=capacity,
                min_after_dequeue=min_after_dequeue)
        else:
            example_batch, label_batch = tf.train.batch(
                [example, label], batch_size=batch_size, capacity=capacity)

        return example_batch, label_batch


def evaluate_model():
    with tf.Graph().as_default():
        images, labels = inputs(train=False, batch_size=1024,
                                num_epochs=NUM_EPOCHS)

        keep_prob = tf.Variable(1.0, name='keep_prob', trainable=False)

        logits = inference(images, keep_prob)
        training_error = batch_training_error(logits, labels)
        summary_op = tf.merge_all_summaries()

        sess = tf.Session()

        log_dir = os.path.join(SUMMARY_DIR, "eval2")
        writer = tf.train.SummaryWriter(log_dir, sess.graph)

        saver = tf.train.Saver()
        saver.restore(sess, 'checkpoint/model-1280')

        # Note: this only builds an assign op; it changes nothing unless
        # the op is passed to sess.run().
        keep_prob.assign(1.0)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        #threads = []
        #for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        #    new_threads = qr.create_threads(sess, coord=coord, daemon=True, start=True)
        #    threads.extend(new_threads)

        try:
            step = 0
            while not coord.should_stop():
                err = sess.run(training_error)
                print("Step %d, batch training error: %.3f" % (step, err))

                if step % 10 == 0:
                    summary = sess.run(summary_op)
                    writer.add_summary(summary, global_step=step)
                    print('Summary written.')

                step += 1
        #except tf.errors.OutOfRangeError:
        #    print('Done training for %d epochs, %d steps.' % (NUM_EPOCHS, step))
        finally:
            coord.request_stop()

        coord.join(threads)
        sess.close()


evaluate_model()
```
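The commented-out `except tf.errors.OutOfRangeError` clause follows the usual queue-runner idiom: when `num_epochs` is set, running out of input is signalled by that error, so the loop treats it as normal termination and still stops the coordinator in `finally`. The sketch below reproduces only that control flow in plain Python; `make_batch_source`, `evaluate` and `OutOfRange` are hypothetical stand-ins for the TensorFlow pipeline, the evaluation loop and `tf.errors.OutOfRangeError`.

```python
class OutOfRange(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


def make_batch_source(num_examples, batch_size):
    """Toy input pipeline: returns a function that yields full batches, then raises."""
    remaining = list(range(num_examples))

    def next_batch():
        # Only full batches are produced; a trailing partial batch is dropped.
        if len(remaining) < batch_size:
            raise OutOfRange("requested %d, current size %d"
                             % (batch_size, len(remaining)))
        batch = remaining[:batch_size]
        del remaining[:batch_size]
        return batch

    return next_batch


def evaluate(next_batch, log):
    """Evaluation loop in the shape of the question's try/except/finally."""
    step = 0
    try:
        while True:
            next_batch()  # stands in for sess.run(training_error)
            step += 1
    except OutOfRange:
        # Input exhausted: the normal way an epoch-limited pipeline ends.
        log.append("Done after %d steps" % step)
    finally:
        log.append("request_stop")  # stands in for coord.request_stop()
    return step


log = []
print(evaluate(make_batch_source(2500, 1024), log))  # 2
print(log)  # ['Done after 2 steps', 'request_stop']
```

With an empty source (`make_batch_source(0, 1024)`) the same loop stops after zero steps. In the question, where the `except` clause is commented out, that situation surfaces as the uncaught `OutOfRangeError` instead.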

I'm new to TensorFlow and can't work out where I'm going wrong. Any help would be greatly appreciated.

Best answer

Try replacing

```python
saver = tf.train.Saver()
```

with

```python
saver = tf.train.Saver(tf.trainable_variables())
```

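That did the trick for me. I stand by the explanation in the comments: you need to avoid restoring the state of the queue (the input_producer). I also had to append non-trainable variables I wanted to keep track of, such as "global_step", to that list.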
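The answer's reasoning can be illustrated with a plain-Python model in which a checkpoint is just a dict of variable values. Restoring everything also restores the input producer's epoch counter; if the restored counter has already reached `num_epochs` (assumed here to be the state at the end of training), the producer has nothing left to emit and the first batch fails. Restoring only the trainable variables plus `global_step` leaves the counter at its freshly initialized value. All variable names and values below are hypothetical. Two caveats: in TensorFlow 1.x the counter created by `string_input_producer(num_epochs=...)` is a local variable, which `Saver()` does not save by default and which must be initialized with `tf.local_variables_initializer()`, so whether it lands in a checkpoint depends on the version; and any variable left out of the `var_list`, such as the question's non-trainable `keep_prob`, is not touched by `restore()` and needs its own initializer or an explicit entry in the list.

```python
EPOCHS_VAR = "input/input_producer/limit_epochs/epochs"  # illustrative name

# Hypothetical checkpoint saved at the end of training: the epoch counter
# has already reached the epoch limit.
CHECKPOINT = {
    "conv1/weights": [0.1, -0.2, 0.3],
    "global_step": 1280,
    EPOCHS_VAR: 10,
}
TRAINABLE = ["conv1/weights"]


def fresh_state():
    """Variable values right after running the initializers in a new session."""
    return {"conv1/weights": [0.0, 0.0, 0.0], "global_step": 0, EPOCHS_VAR: 0}


def restore(state, checkpoint, var_list):
    """Toy Saver.restore: overwrite only the variables named in var_list."""
    for name in var_list:
        state[name] = checkpoint[name]
    return state


def batches_available(state, num_epochs, batches_per_epoch):
    """How many batches the toy epoch-limited producer can still emit."""
    remaining_epochs = max(num_epochs - state[EPOCHS_VAR], 0)
    return remaining_epochs * batches_per_epoch


# Like Saver() with no var_list: the exhausted counter comes back too.
everything = restore(fresh_state(), CHECKPOINT, list(CHECKPOINT))
# In the spirit of Saver(tf.trainable_variables() + [global_step]).
selected = restore(fresh_state(), CHECKPOINT, TRAINABLE + ["global_step"])

print(batches_available(everything, num_epochs=10, batches_per_epoch=5))  # 0
print(batches_available(selected, num_epochs=10, batches_per_epoch=5))    # 50
```

In both cases the weights come from the checkpoint; only the selective restore leaves the input pipeline able to produce data.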

For "machine-learning - Tensorflow: trouble reopening queues after restoring a session", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/37632102/
