gpt4 book ai didi

python - 如何保持多个 tensorflow 队列同步

转载 作者:太空宇宙 更新时间:2023-11-03 16:06:20 24 4
gpt4 key购买 nike

在我的tensorflor管道中,我创建了两个批处理队列:一个用于示例(图像),另一个用于标签(整数),本质上与 cifar10_input.py 相同。处理输入:

<read files>
image = tf.image.decode_png(file_contents)
label = tf.string_to_number(label_str, out_type=tf.int32)
# Batch examples here into two queues.
image_batch, label_batch = tf.train.shuffle_batch(
[image, label], ... )
model = build_model(image_batch)
loss = build_loss(model, label_batch)

模型然后将图像队列作为输入,损失评估模型输出和标签之间的差异。

我担心的是,如果我在一批图像上评估模型,标签队列将不再“对齐”,并且两个队列(图像和标签)将出现分歧。

(model_output, ) = sess.run([model])  # Uses the image queue as input
(model_output, true_labels) = sess.run([model, label_batch]) # Are image/labels pairs valid?

如何确保两个队列保持同步,以便从每个队列获取元素始终返回正确的图像/标签对?

最佳答案

Yaroslav 的评论是正确的 - 您必须确保它们始终一起出列。

如果您想以与训练步骤交织的方式评估非标记数据的模型,那么您可以实例化该模型的另一个副本(共享权重 - 设置重用 = true)并从您的备用的、非标记的数据源。如果您想同时进行训练和推理,那么这并不是一个不寻常的模式。

train_predict = model(train_input)
do something with the label here...
alternate:
use_predict = model(use_input, reuse=true)

(reuse=true 指的是变量范围)。

关于python - 如何保持多个 tensorflow 队列同步,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39742185/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com