gpt4 book ai didi

python - 如何使用索引有效地获取张量中每一行的值?

转载 作者:行者123 更新时间:2023-12-01 07:47:25 24 4
gpt4 key购买 nike

我有一个名为 my_tensor 的张量,其形状为 [batch_size, seq_length],我还有另一个名为 idx 的张量,其形状为 [batch_size, 1],即由从 0 开始并以“seq_length”结束的索引组成。

我想使用 idx 中定义的索引提取 my_tensor 每一行中的值。

我尝试使用tf.gather_ndtf.gather但没有成功。

考虑以下示例:

batch_size = 3
seq_length = 5
idx = [2, 0, 4]

my_tensor = tf.random.uniform(shape=(batch_size, seq_length))

我想获取

处的值
[[0, 2],
[1, 0],
[3, 4]]

来自 my_tensor。

我必须对它们进行进一步的处理,所以我希望以一种有效的方式同时拥有它们(我不知道是否可能);但是,我想不出任何其他方法。

感谢您的帮助:)

最佳答案

诀窍是首先将您的索引集转换为 bool 掩码,然后您可以使用它来减少my_tensor,正如您使用 boolean_mask 所描述的那样操作。

您可以通过 one-hot encoding 来完成此操作idx 张量。

因此,在 idx = [2, 0, 4] 的情况下,我们可以执行 tf.one_hot(idx, seq_length) 将其转换为如下所示:

[ [0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1.] ]

然后,将它们放在一起,例如 my_tensor:

[ [0.6413697 , 0.4079175 , 0.42499018, 0.3037368 , 0.8580252 ],
[0.8698617 , 0.29096508, 0.11531639, 0.25421357, 0.5844104 ],
[0.6442119 , 0.31816053, 0.6245482 , 0.7249261 , 0.7595779 ] ]

我们可以按如下方式进行:

result = tf.boolean_mask(my_tensor, tf.one_hot(idx,seq_length))

给予:

[0.42499018, 0.8698617 , 0.7595779 ]

正如预期

关于python - 如何使用索引有效地获取张量中每一行的值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56402768/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com