gpt4 book ai didi

machine-learning - TensorFlow:当批处理完成训练后,tf.train.batch 是否会自动加载下一批?

转载 作者:行者123 更新时间:2023-11-30 08:22:06 25 4
gpt4 key购买 nike

例如,在我创建操作、通过操作提供批处理数据并运行操作后,tf.train.batch 是否会自动向 session 提供另一批数据?

我问这个是因为 tf.train.batch 有一个 allow_smaller_final_batch 属性,这使得最终批处理加载的大小可以小于指定的批处理大小。这是否意味着即使没有循环,下一批也可以自动送料?从教程代码来看,我很困惑。当我加载单个批处理时,我实际上得到了形状为 [batch_size, height, width, num_channels] 的单个批处理大小,但是 documentation说它在张量中创建批量张量。另外,当我阅读 tf-slim walkthrough tutorial 中的教程代码时,其中有一个名为 load_batch 的函数,仅返回 3 个张量:images、images_raw、labels。文档中解释的“批处理”数据在哪里?

感谢您的帮助。

最佳答案

... does tf.train.batch automatically feeds in another batch of data to the session?

没有。没有什么是自动发生的。您必须再次调用 sess.run(...) 才能加载新批处理。

Does this mean even without a loop, the next batch could be automatically fed?

没有。 tf.train.batch(..) 将始终加载 batch_size 张量。例如,如果您有 100 个图像和一个 batch_size=30 那么您将拥有 3*30 个批处理,因为您可以在输入之前调用 sess.run(batch) 三次队列将从头开始(如果epoch=1则停止)。这意味着您在训练中错过了 100-3*30=10 个样本。如果您不想错过它们,您可以执行 tf.train.batch(...,allow_smaller_final_batch=True) 所以现在您将拥有 3x 30-sample-batches 和 1x 10-sample-batches在输入队列重新启动之前进行批处理。

让我用一个代码示例来详细说明:

queue = tf.train.string_input_producer(filenames,
num_epochs=1) # only iterate through all samples in dataset once

reader = tf.TFRecordReader() # or any reader you need
_, example = reader.read(queue)

image, label = your_conversion_fn(example)

# batch will now load up to 100 image-label-pairs on sess.run(...)
# most tf ops are tuned to work on batches
# this is faster and also gives better result on e.g. gradient calculation
batch = tf.train.batch([image, label], batch_size=100)

with tf.Session() as sess:
# "boilerplate" code
sess.run([
tf.local_variables_initializer(),
tf.global_variables_initializer(),
])
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
# in most cases coord.should_stop() will return True
# when there are no more samples to read
# if num_epochs=0 then it will run for ever
while not coord.should_stop():
# will start reading, working data from input queue
# and "fetch" the results of the computation graph
# into raw_images and raw_labels
raw_images, raw_labels = sess.run([images, labels])
finally:
coord.request_stop()
coord.join(threads)

关于machine-learning - TensorFlow:当批处理完成训练后,tf.train.batch 是否会自动加载下一批?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41673889/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com