gpt4 book ai didi

python-3.x - Pytorch 中缺乏 L1 正则化的稀疏解决方案

转载 作者:行者123 更新时间:2023-12-01 13:20:42 27 4
gpt4 key购买 nike

我正在尝试在简单神经网络的第一层(1 个隐藏层)上实现 L1 正则化。我查看了 StackOverflow 上的其他一些帖子,这些帖子使用 Pytorch 应用 l1 正则化来弄清楚应该如何完成(引用: Adding L1/L2 regularization in PyTorch?In Pytorch, how to add L1 regularizer to activations? )。无论我将 lambda(l1 正则化强度参数)增加多高,我都不会在第一个权重矩阵中得到真正的零。为什么会这样? (代码如下)

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Network(nn.Module):
def __init__(self,nf,nh,nc):
super(Network,self).__init__()
self.lin1=nn.Linear(nf,nh)
self.lin2=nn.Linear(nh,nc)

def forward(self,x):
l1out=F.relu(self.lin1(x))
out=F.softmax(self.lin2(l1out))
return out, l1out

def l1loss(layer):
return torch.norm(layer.weight.data, p=1)

nf=10
nc=2
nh=6
learningrate=0.02
lmbda=10.
batchsize=50

net=Network(nf,nh,nc)

crit=nn.MSELoss()
optimizer=torch.optim.Adagrad(net.parameters(),lr=learningrate)


xtr=torch.Tensor(xtr)
ytr=torch.Tensor(ytr)
#ytr=torch.LongTensor(ytr)
xte=torch.Tensor(xte)
yte=torch.LongTensor(yte)
#cyte=torch.Tensor(yte)

it=200
for epoch in range(it):
per=torch.randperm(len(xtr))
for i in range(0,len(xtr),batchsize):
ind=per[i:i+batchsize]
bx,by=xtr[ind],ytr[ind]
optimizer.zero_grad()
output, l1out=net(bx)
# l1reg=l1loss(net.lin1)
loss=crit(output,by)+lmbda*l1loss(net.lin1)
loss.backward()
optimizer.step()
print('Epoch [%i/%i], Loss: %.4f' %(epoch+1,it, np.float32(loss.data.numpy())))

corr=0
tot=0
for x,y in list(zip(xte,yte)):
output,_=net(x)
_,pred=torch.max(output,-1)
tot+=1 #y.size(0)
corr+=(pred==y).sum()
print(corr)

注意:数据有 10 个特征(2 个类别和 800 个训练样本),并且只有前 2 个是相关的(根据设计),因此人们会假设真正的零应该很容易学习。

最佳答案

您对 layer.weight.data 的使用从其自动微分上下文中删除参数(它是 PyTorch 变量),使其在优化器采用梯度时成为常量。这导致零梯度并且不计算 L1 损失。

如果删除 .data ,范数是根据 PyTorch 变量计算的,梯度应该是正确的。

有关 PyTorch 自动微分机制的更多信息,请参阅此 docs article或此 tutorial .

关于python-3.x - Pytorch 中缺乏 L1 正则化的稀疏解决方案,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50054049/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com