gpt4 book ai didi

python - 为 tf.split() 使用 num_splits 变量

转载 作者:太空狗 更新时间:2023-10-29 20:12:14 25 4
gpt4 key购买 nike

是否可以为 tf.split() 的 num_split 参数使用占位符输入?

理想情况下,我想做这样的事情:

num_splits = tf.placeholder(tf.int32)
inputs = tf.placeholder(tf.int32, [5, None])
split_inputs = tf.split(1, num_splits, inputs)

TypeError: Expected int for argument 'num_split' not .

我的方法可能有问题。我希望枚举可变形状张量中的一个维度。谢谢!

最佳答案

核心图操作有一个“张量输入-张量输出”的一般哲学,因此如果您可以重组计算以处理可变大小的单个张量而不是可变数量的张量,它可能会简化事情。

packunpacksplit 等操作处理多个张量,但它们在执行期间编译为“tensor-in/tensor-out”操作图构建时间,这就是为什么 num_splits 需要修复的原因。 dynamic_partitiondynamic_stitchdequeue_many 等操作接管了具有变量 0 维度的单个张量的部分功能.

如果您确实需要处理可变数量的张量,典型的方法是将计算分解为多个 session.run 调用,每个 run 调用一个输入张量,并且使用队列将事物联系在一起。有一个 slice_input_producer 沿第 0 维拆分可变大小的输入并为每一行生成一个张量,因此如果你想在 myfunction 的每一行上循环评估 inputs 你可以这样做

def myfunction(vector):
result = tf.reduce_sum(vector)
print_result = tf.Print(result, [result], "myfunction called ")
return print_result

MAX_ROWS = 10

# input matrix with 2 columns and unknown number of rows (<MAX_ROWS)
inputs = tf.placeholder(tf.int32, [None, 2])
# copy of inputs, will need to have a persistent copy of it because we will
# be fetching rows in different session.run calls
data = tf.Variable(inputs, validate_shape=False)
# input producer that iterates over the rows and pushes them onto Queue
row = tf.train.slice_input_producer([data], num_epochs=1, shuffle=False)[0]
myfunction_op = myfunction(row)

# this op will save placeholder values into the variable
init_op = tf.initialize_all_variables()

# Coordinator is not necessary in this case, but you'll need it if you have
# more than one Queue in order to close all queues together
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

sess.run([init_op], feed_dict={inputs:[[0, 0], [1, 1], [2, 2]]})

try:
for i in range(MAX_ROWS):
sess.run([myfunction_op])
except tf.errors.OutOfRangeError:
print('Done iterating')
finally:
# When done, ask other threads to stop.
coord.request_stop()

如果你运行这个,你应该看到

myfunction called [0]
myfunction called [2]
myfunction called [4]
Done iterating

关于python - 为 tf.split() 使用 num_splits 变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34970582/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com