gpt4 book ai didi

pytorch - 如何在不破坏反向传播的情况下为 pytorch 变量分配新值?

转载 作者:行者123 更新时间:2023-12-03 13:33:15 35 4
gpt4 key购买 nike

我有一个 pytorch 变量,用作模型的可训练输入。在某些时候,我需要手动重新分配此变量中的所有值。
如何在不破坏与损失函数的连接的情况下做到这一点?
假设当前值为 [1.2, 3.2, 43.2]我只是想让他们成为[1,2,3] .

编辑
在我问这个问题的时候,我还没有意识到 PyTorch 没有像 Tensorflow 或 Keras 那样的静态图。
在 PyTorch 中,训练循环是手动进行的,您需要在每个训练步骤中调用所有内容。 (对于以后的馈送数据,没有占位符 + 静态图的概念)。
因此,我们不能“破坏图”,因为我们将使用新变量再次执行所有进一步的计算。我担心发生在 Keras 而不是 PyTorch 中的问题。

最佳答案

您可以使用 data用于修改值的张量属性,因为在 data 上进行了修改不影响图表。所以图表仍然是完整的,并且 data 的修改属性本身对图没有影响。 (data 上的操作和更改不会被 autograd 跟踪,因此不会出现在图表中)

由于您没有给出示例,因此此示例基于您的评论声明:“假设我想更改图层的权重。”
我在这里使用了普通张量,但这对 weight.data 的效果相同。和 bias.data图层的属性。

这是一个简短的例子:

import torch
import torch.nn.functional as F



# Test 1, random vector with CE
w1 = torch.rand(1, 3, requires_grad=True)
loss = F.cross_entropy(w1, torch.tensor([1]))
loss.backward()
print('w1.data', w1)
print('w1.grad', w1.grad)
print()

# Test 2, replacing values of w2 with w1, before CE
# to make sure that everything is exactly like in Test 1 after replacing the values
w2 = torch.zeros(1, 3, requires_grad=True)
w2.data = w1.data
loss = F.cross_entropy(w2, torch.tensor([1]))
loss.backward()
print('w2.data', w2)
print('w2.grad', w2.grad)
print()

# Test 3, replace data after computation
w3 = torch.rand(1, 3, requires_grad=True)
loss = F.cross_entropy(w3, torch.tensor([1]))
# setting values
# the graph of the previous computation is still intact as you can in the below print-outs
w3.data = w1.data
loss.backward()

# data were replaced with values from w1
print('w3.data', w3)
# gradient still shows results from computation with w3
print('w3.grad', w3.grad)

输出:
w1.data tensor([[ 0.9367,  0.6669,  0.3106]])
w1.grad tensor([[ 0.4351, -0.6678, 0.2326]])

w2.data tensor([[ 0.9367, 0.6669, 0.3106]])
w2.grad tensor([[ 0.4351, -0.6678, 0.2326]])

w3.data tensor([[ 0.9367, 0.6669, 0.3106]])
w3.grad tensor([[ 0.3179, -0.7114, 0.3935]])

这里最有趣的部分是 w3 .当时 backward被称为值替换为 w1 的值.但是梯度是基于 CE 函数计算的,其值为原始 w3 .替换的值对图表没有影响。
所以图连接没有断开,替换对图没有影响。我希望这就是你要找的!

关于pytorch - 如何在不破坏反向传播的情况下为 pytorch 变量分配新值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53819383/

35 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com