gpt4 book ai didi

tensorflow - 如何根据 tensorflow 中的某些谓词从队列中过滤张量?

转载 作者:行者123 更新时间:2023-12-04 23:18:31 24 4
gpt4 key购买 nike

如何使用谓词函数过滤存储在队列中的数据?例如,假设我们有一个存储特征和标签张量的队列,我们​​只需要满足谓词的那些。我尝试了以下实现但没有成功:

feature, label = queue.dequeue()
if (predicate(feature, label)):
enqueue_op = another_queue.enqueue(feature, label)

最佳答案

最直接的方法是出列一批,通过谓词测试运行它们,使用 tf.where 生成匹配谓词的密集向量,并使用 tf.gather 收集结果,并将该批次入队。如果你想让它自动发生,你可以在第二个队列上启动一个队列运行器 - 最简单的方法是使用 tf.train.batch :

例子:

import numpy as np
import tensorflow as tf

a = tf.constant(np.array([5, 1, 9, 4, 7, 0], dtype=np.int32))

q = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
enqueue = q.enqueue_many([a])
dequeue = q.dequeue_many(6)
predmatch = tf.less(dequeue, [5])
selected_items = tf.reshape(tf.where(predmatch), [-1])
found = tf.gather(dequeue, selected_items)

secondqueue = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
enqueue2 = secondqueue.enqueue_many([found])
dequeue2 = secondqueue.dequeue_many(3) # XXX, hardcoded

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(enqueue) # Fill the first queue
sess.run(enqueue2) # Filter, push into queue 2
print sess.run(dequeue2) # Pop items off of queue2

谓词产生一个 bool 向量; tf.where产生真值索引的密集向量,以及 tf.gather根据这些索引从原始张量中收集项目。

在这个例子中,很多东西都是硬编码的,当然,你需要在现实中进行非硬编码,但希望它显示了你正在尝试做的事情的结构(创建一个过滤管道)。在实践中,您希望 QueueRunners 在那里保持自动搅拌。使用 tf.train.batch自动处理非常有用——见 Threading and Queues了解更多详情。

关于tensorflow - 如何根据 tensorflow 中的某些谓词从队列中过滤张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33903569/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com