gpt4 book ai didi

python - Torch 网络负载未正确处理

转载 作者:太空宇宙 更新时间:2023-11-03 20:08:36 24 4
gpt4 key购买 nike

我正在尝试在pytorch环境中使用3x64x64图像创建一个网络,看来我成功地训练了我的网络并保存了它。网络看起来像:

class LC_small(nn.Module):
def __init__(self,c_in,c_out = 256):
super(LC_small,self).__init__()
self.conv1 = conv(c_in,64,k=3,stride=1,pad=1)
self.conv2 = conv(64, 128, k=3, stride=2, pad=1)
self.conv3 = conv(128, 128, k=3, stride=1, pad=1)
self.conv4 = conv(128, 128, k=3, stride=2, pad=1)
self.conv5 = conv(128, 128, k=3, stride=1, pad=1)
self.conv6 = conv(128, 256, k=3, stride=2, pad=1)
self.conv7 = conv(256, 256, k=3, stride=1, pad=1)# int(h/8 x w/8 x 256)
self.flat = dense(int(w_rsz/8)*int(h_rsz/8)*256,256)
self.dense1 = dense(256,128,False)
self.dense2 = dense(128,3,False)
def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)
out = self.conv3(out)
out = self.conv4(out)
out = self.conv5(out)
out = self.conv6(out)
out = self.conv7(out)
out = out.view(out.size(0),-1)
out = self.flat(out)
out = self.dense1(out)
out = self.dense2(out)
# print(out.shape)
normal = torch.nn.functional.normalize(out, 2, 1)

return normal

我在训练时保存了我的模型:

for epoch in range(10):
# continue # 현재 Training 됐다고 가정하고
total_loss = 0
route_param = open(route_diffuse+'/netparam.txt','w')
for param in lcnet.state_dict():
route_param.write(str(param)+'\t'+str(lcnet.state_dict()[param].size())+'\n')
for i,data in enumerate(load_LC,0):
input, gtval = data[0].to(dev),data[1].to(dev)
opt.zero_grad()

output = lcnet(input)
loss = crit(output,gtval)
loss.backward()
opt.step()
total_loss +=loss.item()
if i%10 == 9:
print(epoch,i,total_loss/10)
torch.save(lcnet,route_save)
total_loss = 0

但是,当我尝试加载我创建的网络时,我看到了如下错误消息:

Traceback (most recent call last):

File "E:/DLPrj/venv/torch_practice.py", line 324, in <module>

ipl,npl = getseqi_np(sq_t,lcnet) # data : 8 x 6 x w x h

File "E:/DLPrj/venv/torch_practice.py", line 133, in getseqi_np

l1 = net_lc(torch.from_numpy(i1r))

File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__

result = self.forward(*input, **kwargs)

File "E:/DLPrj/venv/torch_practice.py", line 216, in forward

out = self.conv1(input)

File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__

result = self.forward(*input, **kwargs)

File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\container.py", line 92, in forward

input = module(input)

File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__

result = self.forward(*input, **kwargs)

File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\conv.py", line 345, in forward

return self.conv2d_forward(input, self.weight)

File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\conv.py", line 342, in conv2d_forward

self.padding, self.dilation, self.groups)

RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3 3 3, but got 3-dimensional input of size [64, 64, 3] instead

出现此错误后,pycharm 卡住,并且在重新启动 pycharm 之前我无法重新运行此代码。

当我训练网络时,我还会收到一些警告消息:

E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type LC_small. It won't be checked for correctness upon loading.

"type " + obj.__name__ + ". It won't be checked "

E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Sequential. It won't be checked for correctness upon loading.

"type " + obj.__name__ + ". It won't be checked "

E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Conv2d. It won't be checked for correctness upon loading.

"type " + obj.__name__ + ". It won't be checked "

E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type BatchNorm2d. It won't be checked for correctness upon loading.

"type " + obj.__name__ + ". It won't be checked "

E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type LeakyReLU. It won't be checked for correctness upon loading.

"type " + obj.__name__ + ". It won't be checked "

E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Linear. It won't be checked for correctness upon loading.

"type " + obj.__name__ + ". It won't be checked "

我无法理解为什么网络的输入大小突然改变,或者为什么它错误地保存了我的网络。请检查我的问题,非常感谢。

最佳答案

所以你的第一条错误消息是因为 torch.from_numpy(i1r) 的形状错误。你需要做

np.expand_dims(i1r.transpose(2,0,1), axis=0) 

然后它就会被正确处理。这是因为它需要一个批处理维度,而您没有提供批处理维度以及第一个维度而不是最后一个维度中的 channel 。

至于您的第二条错误消息,可能是因为您错误地定义了转换和密集,因此在保存模型时会出现困惑。

关于python - Torch 网络负载未正确处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58824625/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com