gpt4 book ai didi

python-3.x - Pytorch 灰度输入到 Vgg

转载 作者:行者123 更新时间:2023-12-02 00:51:27 29 4
gpt4 key购买 nike

刚接触pytorch,想用Vgg做迁移学习。我想删除全连接层并添加一些新的全连接层。我还想使用灰度输入而不是 RGB 输入。为此,我将添加输入层的权重并获得单个权重。所以三个 channel 的权重会相加。

我成功删除了全连接层,但我在处理灰度部分时遇到了问题。我将三个权重加在一起形成一个新的权重。然后我尝试更改 vgg 模型的状态指令,但这给了我错误。网络代码如下:

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
vgg=models.vgg16(pretrained = True).features[:30]

w1=vgg.state_dict()['0.weight'][:,0,:,:] #first channel of first input layer's weight
w2=vgg.state_dict()['0.weight'][:,1,:,:]
w3=vgg.state_dict()['0.weight'][:,2,:,:]
w4=w1+w2+w3 # add the three weigths of the channels
w4=w4.unsqueeze(1) # make it 4 dimensional

a=vgg.state_dict()#create a new statedict
a['0.weight']=w4 #replace the new state dict's weigt

vgg.load_state_dict(a) # this line gives the error,load the new state dict

self.vgg =nn.Sequential(vgg)
self.fc1 = nn.Linear(14*14*512, 1000)
self.fc2 = nn.Linear(1000, 2)

def forward(self, x):
x = self.vgg(x)
x = x.view(-1, 14 * 14 * 512)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x

这给出了一个错误:

RuntimeError: Error(s) in loading state_dict for Sequential: size mismatch for 0.weight: copying a param with shape torch.Size([64, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).

所以它不允许我用不同大小的重量替换重量。是否有解决此问题的方法,或者还有其他我可以尝试的方法。我想要做的就是使用 vgg 的层直到完全连接的层并更改第一层的权重。

最佳答案

您没有指定您的 VGG 类来自哪里,但我假设它来自 torchvision.models

VGG 模型是为具有 3 个 channel 的图像创建的。你可以在 make_layers method on GitHub 中看到这个.

修改 torchvision 包中的代码可能不是一个好主意,但您可以在项目中创建一个副本并使 in_channels 可设置。

关于python-3.x - Pytorch 灰度输入到 Vgg,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57296799/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com