gpt4 book ai didi

python - 在 tf.nn.top_k 中加入 torch.topk 的 dim 参数

转载 作者:太空宇宙 更新时间:2023-11-04 04:34:08 25 4
gpt4 key购买 nike

Pytorch 提供 torch.topk(input, k, dim=None, largest=True, sorted=True) 函数来计算给定 k 最大元素>input 沿给定维度 dim 的张量。

我有一个形状为 (16, 512, 4096) 的张量,我正在按以下方式使用 torch.topk-

# inputs.shape (16L, 512L, 4096L)
dist, idx = torch.topk(inputs, 64, dim=2, largest=False, sorted=False)
# dist.shape (16L, 512L, 64L), idx.shape (16L, 512L, 64L)

我发现类似的 tensorflow 实现如下 - tf.nn.top_k(input, k=1, sorted=True, name=None)

我的问题是如何在tf.nn.top_k中加入dim=2参数,从而得到与pytorch计算的形状相同的张量?

最佳答案

tf.nn.top_k 处理输入的最后一个维度。这意味着它应该像您的示例一样工作:

dist, idx = tf.nn.top_k(inputs, 64, sorted=False)

一般来说,您可以想象 Tensorflow 版本的工作方式类似于 Pytorch 版本,但带有硬编码的 dim=-1,即最后一个维度。

但是看起来您实际上想要 k 个最小元素。在这种情况下我们可以做

dist, idx = tf.nn.top_k(-1*inputs, 64, sorted=False)
dist = -1*dist

所以我们采用 k 个最大的 negative 输入,它们是 k 个最小的原始输入。然后我们反转值的负值。

关于python - 在 tf.nn.top_k 中加入 torch.topk 的 dim 参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52126579/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com