gpt4 book ai didi

Tensorflow 索引到具有 1d 张量的 2d 张量

转载 作者:行者123 更新时间:2023-12-04 02:10:57 25 4
gpt4 key购买 nike

我有一个二维张量 A ,形状为 [batch_size, D] ,一维张量 B 形状为 [batch_size ]B 的每个元素都是 A 的列索引,对于 A 的每一行,例如。 B[i] in [0,D).

tensorflow 中获取值的最佳方法是什么 A[B]

例如:

A = tf.constant([[0,1,2],
[3,4,5]])
B = tf.constant([2,1])

具有所需的输出:

some_slice_func(A, B) -> [2,4]

还有一个限制。实际上,batch_size 实际上是 None

提前致谢!

最佳答案

我能够使用线性索引让它工作:

def vector_slice(A, B):
""" Returns values of rows i of A at column B[i]

where A is a 2D Tensor with shape [None, D]
and B is a 1D Tensor with shape [None]
with type int32 elements in [0,D)

Example:
A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4]
[3,4]]
"""
linear_index = (tf.shape(A)[1]
* tf.range(0,tf.shape(A)[0]))
linear_A = tf.reshape(A, [-1])
return tf.gather(linear_A, B + linear_index)

虽然这感觉有点老套。

如果有人知道更好(如更清晰或更快),请也留下答案! (暂时不接受自己的)

关于Tensorflow 索引到具有 1d 张量的 2d 张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38492608/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com