gpt4 book ai didi

pytorch - PyTorch 中的自定义损失函数

转载 作者:行者123 更新时间:2023-12-04 00:59:52 29 4
gpt4 key购买 nike

我有三个简单的问题。

  • 如果我的自定义损失函数不可微会发生什么? pytorch 会通过错误还是做其他事情?
  • 如果我在我的自定义函数中声明了一个损失变量来表示模型的最终损失,我应该放 requires_grad = True对于那个变量?或者没关系?如果没关系,那为什么?
  • 我看到人们有时会编写一个单独的层并计算 forward 中的损失。功能。哪种方法更可取,编写函数或层?为什么?

  • 我需要对这些问题进行清晰而漂亮的解释来解决我的困惑。请帮忙。

    最佳答案

    让我去试试。

  • 这取决于您所说的“不可微分”是什么意思。这里第一个有意义的定义是 PyTorch 不知道如何计算梯度。尽管如此,如果您尝试计算梯度,则会引发错误。两种可能的情况是:

    a) 您正在使用尚未实现渐变的自定义 PyTorch 操作,例如torch.svd() .在这种情况下,您将收到 TypeError :
    import torch
    from torch.autograd import Function
    from torch.autograd import Variable

    A = Variable(torch.randn(10,10), requires_grad=True)
    u, s, v = torch.svd(A) # raises TypeError

    b) 您已经实现了自己的操作,但没有定义 backward() .在这种情况下,你会得到一个 NotImplementedError :
    class my_function(Function): # forgot to define backward()

    def forward(self, x):
    return 2 * x

    A = Variable(torch.randn(10,10))
    B = my_function()(A)
    C = torch.sum(B)
    C.backward() # will raise NotImplementedError

    第二个有意义的定义是“数学上不可微分”。显然,数学上不可微的运算不应该有 backward()实现的方法或合理的子梯度。考虑例如 torch.abs()谁的backward()方法返回 0 处的次梯度 0:
    A = Variable(torch.Tensor([-1,0,1]),requires_grad=True)
    B = torch.abs(A)
    B.backward(torch.Tensor([1,1,1]))
    A.grad.data

    对于这些情况,您应该直接引用 PyTorch 文档并挖掘出 backward()直接进行相应操作的方法。
  • 没关系。 requires_grad的使用是为了避免对子图进行不必要的梯度计算。如果需要梯度的操作的单个输入,则其输出也需要梯度。相反,只有当所有输入都不需要梯度时,输出也不需要它。子图中从不执行反向计算,其中所有变量都不需要梯度。

    因为,很可能有一些 Variables (例如 nn.Module() 子类的参数),您的 loss变量也将自动需要渐变。但是,您应该注意到 requires_grad有效(再次参见上文),您只能更改 requires_grad无论如何,对于图形的叶变量。
  • 所有自定义 PyTorch 损失函数都是 _Loss 的子类这是 nn.Module 的子类. See here.如果你想坚持这个约定,你应该继承 _Loss在定义自定义损失函数时。除了一致性之外,一个优点是您的子类将引发 AssertionError ,如果您尚未将目标变量标记为 volatilerequires_grad = False .另一个优点是您可以将损失函数嵌套在 nn.Sequential() 中。 ,因为它是 nn.Module由于这些原因,我会推荐这种方法。
  • 关于pytorch - PyTorch 中的自定义损失函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44597523/

    29 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com