gpt4 book ai didi

python - PyTorch:将向量的所有元素归零,除了前 k 个?

转载 作者:行者123 更新时间:2023-12-05 04:50:58 26 4
gpt4 key购买 nike

我正在尝试创建一个新的激活层,我们称它为 topk,它将按如下方式工作。它将以大小为 n 的向量 x 作为输入(将前一层输出乘以权重矩阵并添加偏差的结果)和一个正整数 k,并将输出大小为 n 的向量 topk(x),其元素为:

              x_i (if x_i is one of the top k elements of x) 
topk(x)_i =
0 (otherwise)

在计算topk(x)的梯度时,x的前k个元素的梯度应该是1,其他都是0。

我应该如何实现?任何帮助将不胜感激。

最佳答案

您可以为此使用 torch.topk:

k = 2
output = torch.randn(5)
vals, idx = output.topk(k)

topk = torch.zeros_like(output)
topk[idx] = vals
>>> topk
tensor([1.0557, 0.0000, 0.0000, 1.4562, 0.0000])

请注意,虽然 topk()'values' 是可微分的,但 'indices' are not(类似于 argmax 不可微分的方式)一个可微函数)。

关于python - PyTorch:将向量的所有元素归零,除了前 k 个?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67099961/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com