gpt4 book ai didi

python - TensorFlow 获取特定列的每一行的元素

转载 作者:太空狗 更新时间:2023-10-29 18:29:28 31 4
gpt4 key购买 nike

如果 A 是像这样的 TensorFlow 变量

A = tf.Variable([[1, 2], [3, 4]])

index是另一个变量

index = tf.Variable([0, 1])

我想使用这个索引来选择每行中的列。在这种情况下,第一行的项目 0 和第二行的项目 1。

如果 A 是一个 Numpy 数组,那么要获取索引中提到的相应行的列,我们可以这样做

x = A[np.arange(A.shape[0]), index]

结果是

[1, 4]

TensorFlow 的等效操作是什么?我知道 TensorFlow 不支持很多索引操作。如果不能直接完成,有什么解决办法?

最佳答案

您可以使用行索引扩展列索引,然后使用 gather_nd:

import tensorflow as tf

A = tf.constant([[1, 2], [3, 4]])
indices = tf.constant([1, 0])

# prepare row indices
row_indices = tf.range(tf.shape(indices)[0])

# zip row indices with column indices
full_indices = tf.stack([row_indices, indices], axis=1)

# retrieve values by indices
S = tf.gather_nd(A, full_indices)

session = tf.InteractiveSession()
session.run(S)

关于python - TensorFlow 获取特定列的每一行的元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39684415/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com