gpt4 book ai didi

python - TF 中的索引操作

转载 作者:行者123 更新时间:2023-12-01 08:52:18 24 4
gpt4 key购买 nike

有没有办法在 tensorflow 中索引操作?特别是,我对通过 tf.while_loop 的迭代器变量进行索引感兴趣。

更具体地说,假设我有 my_ops = [op1, op2]。我想要:

my_ops = [...]
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: my_ops[i](...)
r = tf.while_loop(c, b, [i])

不幸的是,这不起作用,因为 python 数组仅支持整数索引。

最佳答案

我相信这是不可能的。但是,您可以改为使用 tf.stack堆叠操作的输出张量,然后使用 tf.gather以获得所需的输出。

这里有一个例子:

import tensorflow as tf


def condition(i, x):
return tf.less(i, 10)


def body_1(my_ops):
def b(i, x):
stacked_results = tf.stack([op(x) for op in my_ops])
gather_idx = tf.mod(i, 2)
return [i + 1, tf.gather(stacked_results, gather_idx)]

return b


def body_2(my_ops):
def b(i, x):
nb_ops = len(my_ops)
pred_fn_pairs = [(tf.equal(tf.mod(i, nb_ops), 0), lambda: my_ops[0](x)),
(tf.equal(tf.mod(i, nb_ops), 1), lambda: my_ops[1](x))]
result = tf.case(pred_fn_pairs)
return [i + 1, result]

return b


my_ops = [lambda x: tf.Print(x + 1, [x, 1]),
lambda x: tf.Print(x + 2, [x, 2])]
i = tf.constant(0)
x = tf.constant(0)
r = tf.while_loop(condition, body_2(my_ops), [i, x]) # See the difference with body_1

with tf.Session() as sess:
i, x = sess.run(r)
print(x) # Prints 15 = 5*2 + 5*1

关于python - TF 中的索引操作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53033780/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com