gpt4 book ai didi

deep-learning - 在 pytorch 中获取参数与损失项的梯度

转载 作者:行者123 更新时间:2023-12-03 08:15:52 26 4
gpt4 key购买 nike

我的 Pytorch 训练使用复合损失函数,定义为: enter image description here 。为了更新权重 alpha 和 beta,我需要计算三个值:enter image description here这是网络中所有权重的损失项梯度的平均值。

有没有一种有效的方法可以在pytorch中编写它?

我的训练代码如下所示:

for epoc in range(1, nb_epochs+1):
#init
optimizer_fo.zero_grad()
#get the current loss
loss_total = mynet_fo.loss(tensor_xy_dirichlet,g_boundaries_d,tensor_xy_inside,tensor_f_inter,tensor_xy_neuman,g_boundaries_n)
#compute gradients
loss_total.backward(retain_graph=True)
#optimize
optimizer_fo.step()

我的 .loss() 函数直接返回各项之和。我考虑过进行第二次前向传递并独立地向后调用每个损失项,但这会非常昂贵。

最佳答案

1-使用torch.autograd.grad

您只能通过在网络上多次反向传播来获得梯度的不同项。为了避免对输入执行多次推理,您可以使用 torch.autograd.grad实用函数而不是执行传统的向后传递向后。这意味着您不会污染来自不同项的梯度。

这是一个展示基本思想的最小示例:

>>> x = torch.rand(1, 10, requires_grad=True)
>>> lossA = x.pow(2).sum()
>>> lossB = x.mean()

然后对每个不合适的项执行一次向后传递。您必须保留除最后一次之外的所有调用的图表:

>>> gradA = torch.autograd.grad(lossA, x, retain_graph=True)
(tensor([[1.5810, 0.6684, 0.1467, 0.6618, 0.5067, 0.2368, 0.0971, 0.4533, 0.3511,
1.9858]]),)

>>> gradB = torch.autograd.grad(lossB, x)
(tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
0.1000]]),)

此方法有一些限制,因为您以元组形式接收参数的梯度,这不太方便。


2-缓存向后的结果

另一种解决方案是在每次连续的向后调用后缓存渐变:

>>> lossA = x.pow(2).sum()
>>> lossB = x.mean()

>>> lossA.backward(retain_graph=True)

存储梯度并清除.grad属性(不要忘记这样做,否则lossA的梯度会污染gradB。处理多个张量参数时,您必须使其适应一般情况:

>>> x.gradA = x.grad
>>> x.grad = None

向后传递下一个损失项:

>>> lossB.backward()
>>> x.gradB = x.grad

然后您可以在本地与每个梯度项进行交互(分别针对每个参数):

>>> x.gradA, x.gradB
(tensor([[1.5810, 0.6684, 0.1467, 0.6618, 0.5067, 0.2368, 0.0971, 0.4533, 0.3511,
1.9858]]),
tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
0.1000]]))

后一种方法似乎更实用。


这本质上可以归结为torch.autograd.gradtorch.autograd.backward不合适 vs 就地...最终取决于您的需求。您可以阅读有关这两个函数的更多信息 here .

关于deep-learning - 在 pytorch 中获取参数与损失项的梯度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69448198/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com