gpt4 book ai didi

python - PyTorch - 被覆盖的变量是否保留在图中?

转载 作者:行者123 更新时间:2023-11-28 19:00:54 25 4
gpt4 key购买 nike

我想知道覆盖 Python 变量的 PyTorch 张量是否仍保留在 PyTorch 的计算图中。


所以这是一个小例子,我有一个 RNN 模型,其中隐藏状态(和一些其他变量)在每次迭代后重置,稍后调用 backward()< p>

示例:

for i in range(5):
output = rnn_model(inputs[i])
loss += criterion(output, target[i])
## hidden states are overwritten with a zero vector
rnn_model.reset_hidden_states()
loss.backward()

所以我的问题是:

  • 在调用 backward() 之前覆盖隐藏状态是否有问题?

  • 或者计算图是否将先前迭代的隐藏状态的必要信息保存在内存中以计算梯度?<​​/p>

  • 编辑:最好能有官方来源的声明。例如声明所有与 CG 相关的变量都被保留——不管是否还有其他 python 引用这个变量。我假设图表本身有一个引用阻止垃圾收集器删除它。但我想知道是否真的如此。

提前致谢!

最佳答案

我觉得先reset再倒退是可以的。该图保留了所需的信息。

class A (torch.nn.Module):
def __init__(self):
super().__init__()
self.f1 = torch.nn.Linear(10,1)
def forward(self, x):
self.x = x
return torch.nn.functional.sigmoid (self.f1(self.x))
def reset_x (self):
self.x = torch.zeros(self.x.shape)
net = A()
net.zero_grad()
X = torch.rand(10,10)
loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
loss.backward()
params = list(net.parameters())
for i in params:
print(i.grad)
net.zero_grad()

loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
net.reset_x()
print (net.x is X)
del X
loss.backward()
params = list(net.parameters())
for i in params:
print(i.grad)

在上面的代码中,我打印了带/不带重置输入 x 的梯度。梯度肯定取决于 x 并且重置它并不重要。因此,我认为图保留了信息来做反向操作。

关于python - PyTorch - 被覆盖的变量是否保留在图中?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52755805/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com