gpt4 book ai didi

python - Tensorflow:如何像在 numpy 中一样使用 2D 索引对张量进行索引

转载 作者:太空狗 更新时间:2023-10-30 02:08:28 25 4
gpt4 key购买 nike

我想在 Tensorflow 中执行以下 numpy 代码:

input = np.array([[1,2,3]
[4,5,6]
[7,8,9]])
index1 = [0,1,2]
index2 = [2,2,0]
output = input[index1, index2]
>> output
[3,6,7]

给定一个输入,例如:

input = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])

我尝试了以下方法,但似乎有点过头了:

index3 = tf.range(0, input.get_shape()[0])*input.get_shape()[1] + index2
output = tf.gather(tf.reshape(input, [-1]), index3)
sess = tf.Session()
sess.run(output)
>> [3,6,7]

这之所以有效,是因为我的第一个索引很方便 [0,1,2],但对于 [0,0,2] 是不可行的(除了看起来又长又丑)。

你有没有更简单的语法,更张量/pythonic 的东西?

最佳答案

您可以使用 tf.gather_nd (tf.gather_nd official doc)如下:

import tensorflow as tf
inp = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
res=tf.gather_nd(inp,list(zip([0,1,2],[2,2,0])))
sess = tf.Session()
sess.run(res)

结果是array([3, 6, 7])

关于python - Tensorflow:如何像在 numpy 中一样使用 2D 索引对张量进行索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43145200/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com