gpt4 book ai didi

python - 如何在 Tensorflow 中用 2-D 张量索引 3-D 张量?

转载 作者:行者123 更新时间:2023-12-01 02:04:11 25 4
gpt4 key购买 nike

我正在尝试使用 2-D 张量在 Tensorflow 中索引 3-D 张量。例如,我有形状为 [2, 3, 4]x:

[[[ 0,  1,  2,  3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],

[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]]

我想用形状 [2, 3] 的另一个张量 y 对其进行索引,其中 y 的每个元素索引最后一个元素x 的维度。例如,如果我们有 y ,例如:

[[0, 2, 3],
[1, 0, 2]]

输出的形状应为[2, 3]:

[[0, 6, 11],
[13, 16, 22]]

最佳答案

使用tf.meshgrid创建索引,然后使用tf.gather_nd提取元素:

# create a list of indices for except the last axis
idx_except_last = tf.meshgrid(*[tf.range(s) for s in x.shape[:-1]], indexing='ij')

# concatenate with last axis indices
idx = tf.stack(idx_except_last + [y], axis=-1)

# gather elements based on the indices
tf.gather_nd(x, idx).eval()

# array([[ 0, 6, 11],
# [13, 16, 22]])

关于python - 如何在 Tensorflow 中用 2-D 张量索引 3-D 张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49226523/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com