gpt4 book ai didi

python - 在 Pytorch 中加载我的模型时丢失和意外键的问题

转载 作者:行者123 更新时间:2023-11-28 18:05:13 31 4
gpt4 key购买 nike

我正在尝试使用本教程加载模型:https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference .不幸的是,我是初学者,遇到了一些问题。

我已经创建了检查点:

checkpoint = {'epoch': epochs, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),'loss': loss}
torch.save(checkpoint, 'checkpoint.pth')

然后我为我的网络编写了类,我想加载文件:

class Network(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(9216, 4096)
self.fc2 = nn.Linear(4096, 1000)
self.fc3 = nn.Linear(1000, 102)

def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
x = log(F.softmax(x, dim=1))
return x

像那样:

def load_checkpoint(filepath):
checkpoint = torch.load(filepath)
model = Network()
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model = load_checkpoint('checkpoint.pth')

我收到此错误(已编辑以显示整个通信):

RuntimeError: Error(s) in loading state_dict for Network:
Missing key(s) in state_dict: "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "fc3.weight", "fc3.bias".
Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias", "features.3.weight", "features.3.bias", "features.6.weight", "features.6.bias", "features.8.weight", "features.8.bias", "features.10.weight", "features.10.bias", "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias", "classifier.fc3.weight", "classifier.fc3.bias".

这是我的model.state_dict().keys():

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 
'features.3.bias', 'features.6.weight', 'features.6.bias',
'features.8.weight', 'features.8.bias', 'features.10.weight',
'features.10.bias', 'classifier.fc1.weight', 'classifier.fc1.bias',
'classifier.fc2.weight', 'classifier.fc2.bias', 'classifier.fc3.weight',
'classifier.fc3.bias'])

这是我的模型:

AlexNet(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): ReLU(inplace)
(2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): ReLU(inplace)
(5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU(inplace)
(8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): ReLU(inplace)
(10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace)
(12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)

((classifier): Sequential(
(fc1): Linear(in_features=9216, out_features=4096, bias=True)
(relu1): ReLU()
(fc2): Linear(in_features=4096, out_features=1000, bias=True)
(relu2): ReLU()
(fc3): Linear(in_features=1000, out_features=102, bias=True)
(output): LogSoftmax()
)
)

这是我有史以来的第一个网络,我一直在犯错。感谢您引导我走向正确的方向!

最佳答案

所以您的网络本质上是AlexNet分类器部分,您希望加载预训练的AlexNet重量。问题在于 state_dict 中的键是“完全合格的”,这意味着如果您将网络视为嵌套模块的树,则键只是每个分支中的模块列表,加入带有像 grandparent.parent.child 这样的点。你想要

  1. 仅保留名称以“分类器”开头的张量。
  2. 删除“分类器”。部分 key

那么试试

model = Network()
loaded_dict = checkpoint['model_state_dict']
prefix = 'classifier.'
n_clip = len(prefix)
adapted_dict = {k[n_clip:]: v for k, v in loaded_dict.items()
if k.startswith(prefix)}
model.load_state_dict(adapted_dict)

关于python - 在 Pytorch 中加载我的模型时丢失和意外键的问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53907073/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com