gpt4 book ai didi

python - 在 TensorFlow 中索引子张量

转载 作者:太空宇宙 更新时间:2023-11-03 15:44:20 25 4
gpt4 key购买 nike

我想通过将index应用于输入张量的子张量来获得输出张量。

例如在 NumPy 中,

import numpy as np

input = np.random.random((100,5)) # matrix
index = np.randint(5, size=(100,)) # vector
output = data[np.arange(index.shape[0]), index] # vector

给我想要的输出(我想要这个的符号版本)。

与 Theano 类似,

import theano.tensor as T
import theano

input = T.matrix() # symbolic matrix
index = T.ivector() # symbolic vector
output = input[T.arange(index.shape[0]), index] # symbolic vector

给了我想要的输出

如何在 TensorFlow 中执行此操作?

import tensorflow as tf
input = tf.placeholder('float32', [None, 5])
index = tf.placeholder('int32', [None])
output = ???

与 NumPy 的示例不同,index 的长度(=input 的第一个维度)不固定。

最佳答案

您可以使用 tf.gather_nd 进行切片:

output = tf.gather_nd(input, tf.stack((tf.range(tf.shape(index)[0]), index), -1))

关于python - 在 TensorFlow 中索引子张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41892636/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com