- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我正在通过 neural transfer pytorch tutorial并且对 retain_variable
的使用感到困惑(已弃用,现在称为 retain_graph
)。代码示例显示:
class ContentLoss(nn.Module):
def __init__(self, target, weight):
super(ContentLoss, self).__init__()
self.target = target.detach() * weight
self.weight = weight
self.criterion = nn.MSELoss()
def forward(self, input):
self.loss = self.criterion(input * self.weight, self.target)
self.output = input
return self.output
def backward(self, retain_variables=True):
#Why is retain_variables True??
self.loss.backward(retain_variables=retain_variables)
return self.loss
retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.
retain_graph= True
,我们不会释放为后向传递的图形分配的内存。保留这个内存有什么好处,我们为什么需要它?
最佳答案
@cleros 关于 retain_graph=True
的使用非常重要.本质上,它会保留计算某个变量所需的任何信息,以便我们可以对其进行反向传递。
一个说明性的例子
假设我们有一个上面显示的计算图。变量 d
和 e
是输出,a
是输入。例如,
import torch
from torch.autograd import Variable
a = Variable(torch.rand(1, 4), requires_grad=True)
b = a**2
c = b*2
d = c.mean()
e = c.sum()
d.backward()
, 那没关系。经过这个计算,计算图的部分
d
默认情况下将被释放以节省内存。所以如果我们这样做
e.backward()
,会弹出错误信息。为了做
e.backward()
,我们必须设置参数
retain_graph
至
True
在
d.backward()
, IE。,
d.backward(retain_graph=True)
retain_graph=True
在您的向后方法中,您可以随时向后执行:
d.backward(retain_graph=True) # fine
e.backward(retain_graph=True) # fine
d.backward() # also fine
e.backward() # error will occur!
loss1
和
loss2
并且它们位于不同的层中。为了反向传播
loss1
的梯度和
loss2
w.r.t 独立于网络的可学习权重。您必须使用
retain_graph=True
在
backward()
第一个反向传播损失的方法。
# suppose you first back-propagate loss1, then loss2 (you can also do the reverse)
loss1.backward(retain_graph=True)
loss2.backward() # now the graph is freed, and next process of batch gradient descent is ready
optimizer.step() # update the network parameters
关于neural-network - 变量的backward() 方法中的参数retain_graph 是什么意思?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46774641/
我知道,在使用 loss.backward() 时,如果有多个网络和多个损失函数来分别优化每个网络,我们需要指定 retain_graph=True .但即使有(或没有)指定此参数,我也会收到错误。以
我是 Python 和 PyTorch 的学生和初学者。我有一个非常基本的神经网络,我遇到了上面提到的 RunTimeError。重现错误的代码是这样的: import torch from torc
我是一名优秀的程序员,十分优秀!