
python-3.x - PyTorch model with a list of layers throws an error

Reposted. Author: 行者123. Updated: 2023-12-05 03:54:18

I designed the following torch model with two conv2d layers. It works correctly.

import torch.nn as nn
from torchsummary import summary

class mini_unet(nn.Module):
    def __init__(self):
        super(mini_unet, self).__init__()
        self.c1 = nn.Conv2d(1, 1, 3, padding=1)
        self.r1 = nn.ReLU()
        self.c2 = nn.Conv2d(1, 1, 3, padding=1)
        self.r2 = nn.ReLU()

    def forward(self, x):
        x = self.c1(x)
        x = self.r1(x)
        x = self.c2(x)
        x = self.r2(x)
        return x

a = mini_unet().cuda()

print(a)

However, suppose I have many layers and don't want to write out each one explicitly in the forward function. So I used a list to automate it, as follows.

import torch.nn as nn
from torchsummary import summary

class mini_unet2(nn.Module):
    def __init__(self):
        super(mini_unet2, self).__init__()
        self.layers = [nn.Conv2d(1, 1, 3, padding=1),
                       nn.ReLU(),
                       nn.Conv2d(1, 1, 3, padding=1),
                       nn.ReLU()]

    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x

a2 = mini_unet2().cuda()
print(a2)
summary(a2, (1,4,4))

This gives me the following strange error. I already called cuda(), so why doesn't it work?

RuntimeError                              Traceback (most recent call last)
<ipython-input-36-1d71e75b96e0> in <module>
17 a2 = mini_unet2().cuda()
18 print(a2)
---> 19 summary(a2, (1,4,4))

~/anaconda3/envs/torch/lib/python3.6/site-packages/torchsummary/torchsummary.py in summary(model, input_size, batch_size, device)
70 # make a forward pass
71 # print(x.shape)
---> 72 model(*x)
73
74 # remove these hooks

~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
487 result = self._slow_forward(*input, **kwargs)
488 else:
--> 489 result = self.forward(*input, **kwargs)
490 for hook in self._forward_hooks.values():
491 hook_result = hook(self, input, result)

<ipython-input-36-1d71e75b96e0> in forward(self, x)
12 def forward(self, x):
13 for l in self.layers:
---> 14 x = l(x)
15 return x
16

~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
487 result = self._slow_forward(*input, **kwargs)
488 else:
--> 489 result = self.forward(*input, **kwargs)
490 for hook in self._forward_hooks.values():
491 hook_result = hook(self, input, result)

~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/conv.py in forward(self, input)
318 def forward(self, input):
319 return F.conv2d(input, self.weight, self.bias, self.stride,
--> 320 self.padding, self.dilation, self.groups)
321
322

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

Best answer

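The error may be a bit counterintuitive, but it stems from your use of a Python list for the layers. nn.Module only registers submodules that are assigned directly as attributes or held in a container such as nn.ModuleList; layers stored in a plain Python list are invisible to it. As a result, .cuda() has no parameters to move for those layers, so their weights stay on the CPU (torch.FloatTensor) while the input that torchsummary creates is on the GPU (torch.cuda.FloatTensor), which is exactly the mismatch reported in the error. For the same reason, print(a2) does not list any of the layers.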

According to the documentation, you need to use torch.nn.ModuleList, not a Python list, to hold submodules.
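To see why the container matters without needing a GPU, here is a minimal CPU-only sketch (the class names PlainListNet and ModuleListNet are hypothetical). It uses .to(torch.float64) as a stand-in for .cuda(): both go through the same nn.Module conversion logic, which only touches registered submodules and parameters.

```python
import torch
import torch.nn as nn


class PlainListNet(nn.Module):
    """Hypothetical example: layers kept in a plain Python list are NOT registered."""

    def __init__(self):
        super().__init__()
        self.layers = [nn.Conv2d(1, 1, 3, padding=1), nn.ReLU()]


class ModuleListNet(nn.Module):
    """Hypothetical example: layers kept in nn.ModuleList ARE registered as submodules."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Conv2d(1, 1, 3, padding=1), nn.ReLU()])


plain = PlainListNet()
registered = ModuleListNet()

# The plain-list model exposes no parameters at all.
print(len(list(plain.parameters())))       # 0
# The Conv2d layer contributes a weight and a bias.
print(len(list(registered.parameters())))  # 2

# .to() (like .cuda()) only converts registered parameters.
plain.to(torch.float64)
registered.to(torch.float64)
print(plain.layers[0].weight.dtype)       # torch.float32 (left untouched)
print(registered.layers[0].weight.dtype)  # torch.float64
```

With .cuda() instead of .to(torch.float64), the plain-list weights are skipped in the same way, which is why they remain torch.FloatTensor on the CPU.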

So simply changing list to nn.ModuleList(list) fixes the error.

import torch.nn as nn
from torchsummary import summary

class mini_unet2(nn.Module):
    def __init__(self):
        super(mini_unet2, self).__init__()
        # nn.ModuleList registers each layer as a submodule, so .cuda() moves its weights too
        self.layers = nn.ModuleList([nn.Conv2d(1, 1, 3, padding=1),
                                     nn.ReLU(),
                                     nn.Conv2d(1, 1, 3, padding=1),
                                     nn.ReLU()])

    def forward(self, x):
        # Apply the registered layers in order
        for l in self.layers:
            x = l(x)
        return x

a2 = mini_unet2().cuda()
print(a2)
summary(a2, (1,4,4))
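Since the goal was to avoid writing out many layers by hand, one alternative (not part of the original answer) is nn.Sequential, which both registers the layers and applies them in order, so no custom forward loop is needed. The sketch below assumes a hypothetical helper make_conv_stack that repeats the Conv2d + ReLU pair num_blocks times, and it falls back to the CPU when no GPU is present.

```python
import torch
import torch.nn as nn


def make_conv_stack(num_blocks: int) -> nn.Sequential:
    """Hypothetical helper: num_blocks repetitions of Conv2d(1, 1, 3, padding=1) + ReLU."""
    layers = []
    for _ in range(num_blocks):
        layers.append(nn.Conv2d(1, 1, 3, padding=1))
        layers.append(nn.ReLU())
    # nn.Sequential registers every layer and calls them in order in its forward().
    return nn.Sequential(*layers)


# Use the GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = make_conv_stack(num_blocks=2).to(device)

x = torch.randn(1, 1, 4, 4, device=device)  # a batch of one 1x4x4 image
y = model(x)
print(y.shape)  # torch.Size([1, 1, 4, 4])
```

Because every layer is registered, a single .to(device) (or .cuda()) moves all weights, and a 3x3 kernel with padding=1 keeps the 4x4 spatial size. nn.ModuleList remains the better fit when forward needs custom control flow, such as skip connections in a U-Net.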

For more on "python-3.x - PyTorch model with a list of layers throws an error", see the similar question on Stack Overflow: https://stackoverflow.com/questions/61116039/
