gpt4 book ai didi

python - 如何在 tensorflow 中收集带有索引的元素

转载 作者:行者123 更新时间:2023-12-01 01:42:58 25 4
gpt4 key购买 nike

例如,

import tensorflow as tf

index = tf.constant([[1],[1]])
values = tf.constant([[0.2, 0.8],[0.4, 0.6]])

如果我使用extract = tf.gather_nd(values, index)返回的是

[[0.4 0.6]
[0.4 0.6]]

但是,我想要的结果是

[[0.8], [0.6]]

其中索引沿 axis = 1,但是 tf.gather_nd 中没有轴参数设置。

我该怎么办?谢谢!

最佳答案

将范围连接到索引:

index = tf.stack([tf.range(index.shape[0])[:, None], index], axis=2)
result = tf.gather_nd(values, index)

result.eval(session=tf.Session())
array([[0.8],
[0.6]], dtype=float32)

关于python - 如何在 tensorflow 中收集带有索引的元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51690095/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com