gpt4 book ai didi

parallel-processing - Pytorch softmax沿着不同的掩码没有for循环

转载 作者:行者123 更新时间:2023-12-02 04:26:17 26 4
gpt4 key购买 nike

假设我有一个向量 a ,其索引向量 b 的长度相同。索引范围为0~N-1,对应N组。如何在没有 for 循环的情况下对每个组执行 softmax?

我在这里进行某种注意力操作。每个组的数字都不相同,所以我无法将 a reshape 为矩阵并在标准 Softmax() 中使用 dim API。

玩具示例:

a = torch.rand(10)
a: tensor([0.3376, 0.0557, 0.3016, 0.5550, 0.5814, 0.1306, 0.2697, 0.9989, 0.4917,
0.6306])
b = torch.randint(0,3,(1,10), dtype=torch.int64)
b: tensor([[1, 2, 0, 2, 2, 0, 1, 1, 1, 1]])

我想像这样做softmax

for index in range(3):
softmax(a[b == index])

但没有 for 循环以节省时间。

最佳答案

也许这个答案必须根据对我的评论的潜在回应略有改变,但我只是继续前进并投入我对 Softmax 的两分钱。

一般来说,softmax 的公式在 PyTorch documentation 中有很好的解释。 ,我们可以看到这是当前值的指数,除以所有类的总和。
这样做的原因是基于概率论,可能有点超出我的舒适范围,但本质上它可以帮助您维护一个相当简单的反向传播导数,当它与一种称为“交叉熵损失”的流行损失策略结合使用时(CE)(参见 PyTorch 中的相应函数 here

此外,您还可以在 CE 的描述中看到它自动组合了两个函数,即 softmax 函数的(数值稳定的)版本,以及负对数似然损失(NLLL ).

现在,回到您最初的问题,并希望能解决您的问题:
为了这个问题——以及你提出问题的方式——你似乎在玩流行的 MNIST handrwitten 数字数据集,我们想在其中预测你当前输入图像的一些值。

我还假设您的输出 a 将在某个时候成为神经网络层的输出。这是否被压缩到特定范围并不重要(例如,通过应用某种形式的激活函数),因为 softmax 基本上是一个归一化。具体来说,如前所述,它将为我们提供所有预测值的某种分布形式,并且所有类别的总和为 1。为此,我们可以简单地应用类似

soft_a = softmax(a, dim=0) # otherwise throws error if we don't specify axis
print(torch.sum(soft_a)) # should return "Tensor(1.)"

现在,如果我们假设您想做“经典”MNIST 示例,您可以使用 argmax() 函数来预测您的系统认为哪个值是正确答案,并计算基于此的错误,例如,使用 nn.NLLLoss() 函数。

如果您确实要预测单个输出中每个位置的值,则您必须对此略有不同。
首先,softmax() 在这里不再有意义,因为您正在计算概率分布跨多个输出,除非您相当确定它们的分布取决于以一种非常具体的方式相互交流,我认为这里不是这种情况。

此外,请记住,您随后要计算成对损失,即输出的每个索引的损失。为此特定目的想到的功能是 nn.BCELoss() ,计算交叉熵的二值化(逐元素)版本。为此,您可以简单地“支持”您的原始预测张量 a,以及您的基本事实张量 b。一个最小的例子如下所示:

bce = torch.nn.BCELoss(reduction="none") # to keep losses for each element separate
loss = bce(a,b) # returns tensor with respective pairwise loss

如果您对单个损失感兴趣,您显然可以使用 BCELoss 和不同的 reduction 参数,如文档中所述。如果我可以为您澄清部分答案,请告诉我。

编辑:这里还有一点要记住:BCELoss() 要求您输入可能接近您想要预测的值的值。如果您首先将值输入激活函数(例如 sigmoid 或 tanh),这尤其是一个问题,然后激活函数永远不会达到您想要预测的值,因为它们受区间限制!

关于parallel-processing - Pytorch softmax沿着不同的掩码没有for循环,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54284077/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com