gpt4 book ai didi

python - 如何将来自 tf.nn.top_k 的索引与 tf.gather_nd 一起使用?

转载 作者:太空宇宙 更新时间:2023-11-03 14:39:56 32 4
gpt4 key购买 nike

我正在尝试使用从 tf.nn.top_k 返回的索引从第二个张量中提取值。

我试过使用 numpy 类型索引,以及直接使用 tf.gather_nd,但我注意到索引是错误的。

#  temp_attention_weights of shape [I, B, 1, J]
top_values, top_indices = tf.nn.top_k(temp_attention_weights, k=top_k)

# top_indices of shape [I, B, 1, top_k], base_encoder_transformed of shape [I, B, 1, J]

# I now want to extract from base_encoder_transformed top_indices
base_encoder_transformed = tf.gather_nd(base_encoder_transformed, indices=top_indices)

# base_encoder_transformed should be of shape [I, B, 1, top_k]

我注意到 top_indices 的格式错误,但我似乎无法将其转换为在 tf.gather_nd 中使用,其中最内层的维度用于索引来自 base_encoder_transformed 的每个对应元素。有人知道将 top_indices 转换为正确格式的方法吗?

最佳答案

top_indices 只会在最后一个轴上建立索引,您也需要为其余轴添加索引。使用 tf.meshgrid 很容易:

import tensorflow as tf

# Example input data
I = 4
B = 3
J = 5
top_k = 2
x = tf.reshape(tf.range(I * B * J), (I, B, 1, J)) % 7
# Top K
top_values, top_indices = tf.nn.top_k(x, k=top_k)
# Make indices for the rest of axes
ii, jj, kk, _ = tf.meshgrid(
tf.range(I),
tf.range(B),
tf.range(1),
tf.range(top_k),
indexing='ij')
# Stack complete index
index = tf.stack([ii, jj, kk, top_indices], axis=-1)
# Get the same values again
top_values_2 = tf.gather_nd(x, index)
# Test
with tf.Session() as sess:
v1, v2 = sess.run([top_values, top_values_2])
print((v1 == v2).all())
# True

关于python - 如何将来自 tf.nn.top_k 的索引与 tf.gather_nd 一起使用?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54196149/

32 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com