gpt4 book ai didi

python - 如果用于梯度更新的索引叶变量如何解决就地操作错误?

转载 作者:太空狗 更新时间:2023-10-30 02:53:32 24 4
gpt4 key购买 nike

当我尝试索引叶变量以使用自定义收缩函数更新梯度时,我遇到了就地操作错误。我无法解决它。非常感谢任何帮助!

import torch.nn as nn
import torch
import numpy as np
from torch.autograd import Variable, Function

# hyper parameters
batch_size = 100 # batch size of images
ld = 0.2 # sparse penalty
lr = 0.1 # learning rate

x = Variable(torch.from_numpy(np.random.normal(0,1,(batch_size,10,10))), requires_grad=False) # original

# depends on size of the dictionary, number of atoms.
D = Variable(torch.from_numpy(np.random.normal(0,1,(500,10,10))), requires_grad=True)

# hx sparse representation
ht = Variable(torch.from_numpy(np.random.normal(0,1,(batch_size,500,1,1))), requires_grad=True)

# Dictionary loss function
loss = nn.MSELoss()

# customized shrink function to update gradient
shrink_ht = lambda x: torch.stack([torch.sign(i)*torch.max(torch.abs(i)-lr*ld,0)[0] for i in x])

### sparse reprsentation optimizer_ht single image.
optimizer_ht = torch.optim.SGD([ht], lr=lr, momentum=0.9) # optimizer for sparse representation

## update for the batch
for idx in range(len(x)):
optimizer_ht.zero_grad() # clear up gradients
loss_ht = 0.5*torch.norm((x[idx]-(D*ht[idx]).sum(dim=0)),p=2)**2
loss_ht.backward() # back propogation and calculate gradients
optimizer_ht.step() # update parameters with gradients
ht[idx] = shrink_ht(ht[idx]) # customized shrink function.

RuntimeError Traceback (most recent call last) in ()
15 loss_ht.backward() # back propogation and calculate gradients
16 optimizer_ht.step() # update parameters with gradients
—> 17 ht[idx] = shrink_ht(ht[idx]) # customized shrink function.
18
19

/home/miniconda3/lib/python3.6/site-packages/torch/autograd/variable.py in setitem(self, key, value)
85 return MaskedFill.apply(self, key, value, True)
86 else:
—> 87 return SetItem.apply(self, key, value)
88
89 def deepcopy(self, memo):

RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

具体来说,下面这行代码似乎会出错,因为它同时索引和更新叶变量。

ht[idx] = shrink_ht(ht[idx])  # customized shrink function.

谢谢。

W.S.

最佳答案

我刚发现:为了更新变量,需要ht.data[idx]而不是ht[idx]。我们可以使用 .data 直接访问张量。

关于python - 如果用于梯度更新的索引叶变量如何解决就地操作错误?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49161652/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com