gpt4 book ai didi

python - Tensorflow:批处理输入队列然后更改队列源

转载 作者:太空狗 更新时间:2023-10-30 01:57:46 25 4
gpt4 key购买 nike

我有一个模型,它运行一组图像并使用它们计算一些统计数据 - 为简单起见,它只输出该组图像的平均图像(它在实践中做的比这更多)。我有许多包含图像的目录,我想从每个目录获取输出。每个目录中都有数量可变的图像。

我已经为我的脚本构造了一次图形、输出变量和损失函数。使用稍微调整的 code from here 对输入进行批处理.我对其进行了调整,使其采用一组路径,我使用可变大小的占位符将其输入。我得到了灵感 from here .

然后我遍历目录并运行以下命令:

  1. 初始化变量(这会根据先前目录的计算结果重置先前的输出变量)
  2. 将图像路径变量设置为新目录中的当前文件数组:sess.run(image_paths.initializer, feed_dict={image_paths_initializer: image_paths})
  3. 启动队列运行:queue_threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  4. 运行多个 epoch 以获得结果
  5. 关闭线程 coord.request_stop(); coord.join(queue_threads); coord.clear_stop()
  6. 返回结果,保存结果,移动到下一个目录...

问题是,当涉及到第二个目录时,队列运行器线程拒绝启动(我可以通过调试 queue_threads 变量看到这一点)。这会产生如下错误:

Compute status: Aborted: FIFOQueue '_1_input_producer' is closed.
Compute status: Aborted: RandomShuffleQueue '_0_shuffle_batch/random_shuffle_queue' is closed.

如果我不关闭线程(并且不再次启动它们),那么它们就不会从新目录生成文件——它们会忽略 (2) 中的变量赋值操作。像这样重新启动队列是不可能的吗?

我已经尝试在他们自己的单独 session 中设置队列并从中提取批处理,但这给了我各种 CUDA/内存错误。如果我这样做并添加调试停止,我可以让它在它达到这个之前运行很远 - 但我不知道是否可以在不相交的 session /图形之间添加控制依赖性?

每个新目录都可以从头开始,但这会增加我试图避免的过程的大量开销。我已经在没有队列的情况下完成了类似的事情(即,重置变量并使用不同的输入重新运行)并且它节省了很多时间,所以我知道这个位有效。

你们这些优秀的 SO 人能想出解决办法吗?

最佳答案

string_input_producer 是一个 FIFOQueue + QueueRunner。如果您使用 FIFOQueue 并手动排队,您将获得更多控制权。像这样

filename_queue = tf.FIFOQueue(100, tf.string)
enqueue_placeholder = tf.placeholder(dtype=tf.string)
enqueue_op = filename_queue.enqueue(enqueue_placeholder)

config = tf.ConfigProto()
config.operation_timeout_in_ms=2000 # for debugging queue hangs
sess = tf.InteractiveSession(config=config)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)

sess.run([enqueue_op], feed_dict={enqueue_placeholder:"/temp/dir1/0"})
sess.run([enqueue_op], feed_dict={enqueue_placeholder:"/temp/dir1/1"})

# do stats for /temp/dir1

sess.run([enqueue_op], feed_dict={enqueue_placeholder:"/temp/dir2/0"})
sess.run([enqueue_op], feed_dict={enqueue_placeholder:"/temp/dir2/1"})

# do stats for /temp/dir2

coord.request_stop()
coord.join(threads)

关于python - Tensorflow:批处理输入队列然后更改队列源,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36334371/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com