gpt4 book ai didi

image-processing - state_dict 中缺少键

转载 作者:行者123 更新时间:2023-11-30 09:04:37 25 4
gpt4 key购买 nike

我在 Google Colab 上加载模型时遇到问题。这是代码:

我已附上以下代码

我尝试更改 statedict 的名称,但没有帮助基本上,我试图保存我的模型以供以后使用,但是,这变得非常困难,因为我无法正确保存和加载它。请帮我解决这个问题。在代码部分之后,您还会发现我在下面附加的错误。

这是代码

from zipfile import ZipFile
file_name = 'data.zip'
with ZipFile(file_name, 'r') as zip:
zip.extractall()

from zipfile import ZipFile
file_name = 'results.zip'
with ZipFile(file_name, 'r') as zip:
zip.extractall()

!pip install tensorflow-gpu

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable


batchSize = 64
imageSize = 64

transform = transforms.Compose([transforms.Resize(imageSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

dataset = dset.CIFAR10(root = './data', download = True, transform = transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size = batchSize, shuffle = True, num_workers = 2)


def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)


class G(nn.Module):

def __init__(self):
super(G, self).__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 3, 4, 2, 1, bias = False),
nn.Tanh()
)

def forward(self, input):
output = self.main(input)
return output


netG = G()
netG.load_state_dict(torch.load('generator.pth'))
netG.eval()
#netG.apply(weights_init)



class D(nn.Module):

def __init__(self):
super(D, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1, bias = False),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(64, 128, 4, 2, 1, bias = False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(128, 256, 4, 2, 1, bias = False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(256, 512, 4, 2, 1, bias = False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(512, 1, 4, 1, 0, bias = False),
nn.Sigmoid()
)

def forward(self, input):
output = self.main(input)
return output.view(-1)


netD = D()
netD.load_state_dict(torch.load('discriminator.pth'))
netD.eval()
#netD.apply(weights_init)


criterion = nn.BCELoss()
checkpoint = torch.load('discriminator.pth')
optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerD.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
errD = checkpoint['loss']
checkpoint1 = torch.load('genrator.pth')
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG.load_state_dict(checkpoint1['optimizer_state_dict'])
errG = checkpoint1['loss']
k = epoch
for j in range(k, 10):

for i, data in enumerate(dataloader, 0):


netD.zero_grad()


real, _ = data
input = Variable(real)
target = Variable(torch.ones(input.size()[0]))
output = netD(input)
errD_real = criterion(output, target)


noise = Variable(torch.randn(input.size()[0], 100, 1, 1))
fake = netG(noise)
target = Variable(torch.zeros(input.size()[0]))
output = netD(fake.detach())
errD_fake = criterion(output, target)


errD = errD_real + errD_fake
errD.backward()
optimizerD.step()



netG.zero_grad()
target = Variable(torch.ones(input.size()[0]))
output = netD(fake)
errG = criterion(output, target)
errG.backward()
optimizerG.step()



print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch+1, 10, i+1, len(dataloader), errD.data, errG.data))
if i % 100 == 0:
vutils.save_image(real, '%s/real_samples.png' % "./results", normalize = True)
fake = netG(noise)
vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch+1), normalize = True)

torch.save({
'epoch': epoch,
'model_state_dict': netD.state_dict(),
'optimizer_state_dict': optimizerD.state_dict(),
'loss': errD
}, 'discriminator.pth')
torch.save({
'epoch': epoch,
'model_state_dict': netG.state_dict(),
'optimizer_state_dict': optimizerG.state_dict(),
'loss': errG
}, 'generator.pth')

这是错误

RuntimeError                              Traceback (most recent call last)
<ipython-input-23-3e55546152c7> in <module>()
26 # Creating the generator
27 netG = G()
---> 28 netG.load_state_dict(torch.load('generator.pth'))
29 netG.eval()
30 #netG.apply(weights_init)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
767 if len(error_msgs) > 0:
768 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 769 self.__class__.__name__, "\n\t".join(error_msgs)))
770
771 def _named_members(self, get_members_fn, prefix='', recurse=True):

RuntimeError: Error(s) in loading state_dict for G:
Missing key(s) in state_dict: "main.0.weight", "main.1.weight", "main.1.bias", "main.1.running_mean", "main.1.running_var", "main.3.weight", "main.4.weight", "main.4.bias", "main.4.running_mean", "main.4.running_var", "main.6.weight", "main.7.weight", "main.7.bias", "main.7.running_mean", "main.7.running_var", "main.9.weight", "main.10.weight", "main.10.bias", "main.10.running_mean", "main.10.running_var", "main.12.weight".
Unexpected key(s) in state_dict: "epoch", "model_state_dict", "optimizer_state_dict", "loss".

最佳答案

您需要访问已加载检查点内的'model_state_dict'键。
尝试:

netG.load_state_dict(torch.load('generator.pth')['model_state_dict'])

您可能还需要对鉴别器应用相同的修复。

关于image-processing - state_dict 中缺少键,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55744941/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com