gpt4 book ai didi

mathematical-optimization - 如何在pytorch中手动应用渐变

转载 作者:行者123 更新时间:2023-12-01 01:46:11 24 4
gpt4 key购买 nike

开始学习 pytorch 并尝试做一些非常简单的事情,尝试将大小为 5 的随机初始化向量移动到值 [1,2,3,4,5] 的目标向量。

但我的距离并没有减少!!还有我的矢量 x只是疯了。不知道我错过了什么。

import torch
import numpy as np
from torch.autograd import Variable

# regress a vector to the goal vector [1,2,3,4,5]

dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

x = Variable(torch.rand(5).type(dtype), requires_grad=True)
target = Variable(torch.FloatTensor([1,2,3,4,5]).type(dtype),
requires_grad=False)
distance = torch.mean(torch.pow((x - target), 2))

for i in range(100):
distance.backward(retain_graph=True)
x_grad = x.grad
x.data.sub_(x_grad.data * 0.01)

最佳答案

您的代码中有两个错误会阻止您获得所需的结果。

第一个错误是您应该将距离计算放在循环中。因为在这种情况下,距离就是损失。所以我们必须在每次迭代中监控它的变化。

第二个错误是您应该手动将 x.grad 清零。因为 pytorch won't zero out the grad in variable by default .

以下是按预期工作的示例代码:

import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt

# regress a vector to the goal vector [1,2,3,4,5]

dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

x = Variable(torch.rand(5).type(dtype), requires_grad=True)
target = Variable(torch.FloatTensor([1,2,3,4,5]).type(dtype),
requires_grad=False)

lr = 0.01 # the learning rate

d = []
for i in range(1000):
distance = torch.mean(torch.pow((x - target), 2))
d.append(distance.data)
distance.backward(retain_graph=True)

x.data.sub_(lr * x.grad.data)
x.grad.data.zero_()

print(x.data)

fig, ax = plt.subplots()
ax.plot(d)
ax.set_xlabel("iteration")
ax.set_ylabel("distance")
plt.show()

下面是距离w.r.t迭代的图

enter image description here

我们可以看到模型在大约 600 次迭代时收敛。如果我们将学习率设置得更高(例如,lr=0.1),模型会收敛得更快(大约需要 60 次迭代,见下图)

enter image description here

现在,x 变成如下所示

0.9878 1.9749 2.9624 3.9429 4.9292



这非常接近您的目标 [1, 2, 3, 4, 5]。

关于mathematical-optimization - 如何在pytorch中手动应用渐变,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49154514/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com