gpt4 book ai didi

python - 使用 Tensorflow 2.0 中的另一个张量索引张量的第 k 维

转载 作者:行者123 更新时间:2023-12-04 09:44:37 25 4
gpt4 key购买 nike

我有一个张量 probs具有形状 (None, None, 110)代表(batch_size, sequence_length, 110)在 LSTM 中。
我有另一个张量 indices具有形状 (None, None) ,其中包含要从 probs 的第三维中选择的元素的索引.

我想用indices索引张量 probs .

Numpy 等价物:

k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]

shape[0]shape[1]probs未知, tf.meshgrid()不是一个选择。
我找到了 tf.gather , tf.gather_ndtf.batch_gather ,但他们似乎都没有做我想做的事。

有人知道怎么做这个吗?

最佳答案

你可以用 tf.gather_nd 做到这一点像这样:

indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)

顺便说一下,在 NumPy 中你可以使用 np.take_along_axis 做同样的事情:

indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]

关于python - 使用 Tensorflow 2.0 中的另一个张量索引张量的第 k 维,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62191509/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com