gpt4 book ai didi

neural-network - 变量的backward() 方法中的参数retain_graph 是什么意思?

转载 作者:行者123 更新时间:2023-12-03 10:11:16 25 4
gpt4 key购买 nike

我正在通过 neural transfer pytorch tutorial并且对 retain_variable 的使用感到困惑(已弃用,现在称为 retain_graph )。代码示例显示:

class ContentLoss(nn.Module):

def __init__(self, target, weight):
super(ContentLoss, self).__init__()
self.target = target.detach() * weight
self.weight = weight
self.criterion = nn.MSELoss()

def forward(self, input):
self.loss = self.criterion(input * self.weight, self.target)
self.output = input
return self.output

def backward(self, retain_variables=True):
#Why is retain_variables True??
self.loss.backward(retain_variables=retain_variables)
return self.loss

来自 the documentation

retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.



所以通过设置 retain_graph= True ,我们不会释放为后向传递的图形分配的内存。保留这个内存有什么好处,我们为什么需要它?

最佳答案

@cleros 关于 retain_graph=True 的使用非常重要.本质上,它会保留计算某个变量所需的任何信息,以便我们可以对其进行反向传递。

一个说明性的例子

enter image description here

假设我们有一个上面显示的计算图。变量 de是输出,a是输入。例如,

import torch
from torch.autograd import Variable
a = Variable(torch.rand(1, 4), requires_grad=True)
b = a**2
c = b*2
d = c.mean()
e = c.sum()

当我们这样做时 d.backward() , 那没关系。经过这个计算,计算图的部分 d默认情况下将被释放以节省内存。所以如果我们这样做 e.backward() ,会弹出错误信息。为了做 e.backward() ,我们必须设置参数 retain_graphTrued.backward() , IE。,
d.backward(retain_graph=True)

只要您使用 retain_graph=True在您的向后方法中,您可以随时向后执行:
d.backward(retain_graph=True) # fine
e.backward(retain_graph=True) # fine
d.backward() # also fine
e.backward() # error will occur!

更多有用的讨论可以找到 here .

一个真实的用例

现在,一个真正的用例是多任务学习,其中您有多个损失,这些损失可能位于不同的层。假设您有 2 次损失: loss1loss2并且它们位于不同的层中。为了反向传播 loss1的梯度和 loss2 w.r.t 独立于网络的可学习权重。您必须使用 retain_graph=Truebackward()第一个反向传播损失的方法。
# suppose you first back-propagate loss1, then loss2 (you can also do the reverse)
loss1.backward(retain_graph=True)
loss2.backward() # now the graph is freed, and next process of batch gradient descent is ready
optimizer.step() # update the network parameters

关于neural-network - 变量的backward() 方法中的参数retain_graph 是什么意思?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46774641/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com