gpt4 book ai didi

python - 运行时错误 : expected stride to be a single integer value

转载 作者:太空宇宙 更新时间:2023-11-04 00:10:59 24 4
gpt4 key购买 nike

我是 Pytorch 的新手,很抱歉回答这个基本问题。该模型给我尺寸不匹配错误如何解决这个问题?也许其中不止一个问题。任何帮助将不胜感激。谢谢

class PR(nn.Module):
def __init__(self):
super(PR, self).__init__()
self.conv1 = nn.Conv2d(3,6,kernel_size=5)
self.conv2 = nn.Conv2d(6,1,kernel_size=2)
self.dens1 = nn.Linear(300, 256)
self.dens2 = nn.Linear(256, 256)
self.dens3 = nn.Linear(512, 24)
self.drop = nn.Dropout()

def forward(self, x):
out = self.conv1(x)
out = self.conv2(x)
out = self.dens1(x)
out = self.dens2(x)
out = self.dens3(x)
return out

model = PR()
input = torch.rand(28,28,3)
output = model(input)

最佳答案

请查看更正后的代码。我对进行更正的行进行了编号并在下面进行了描述。

class PR(torch.nn.Module):
def __init__(self):
super(PR, self).__init__()
self.conv1 = torch.nn.Conv2d(3,6, kernel_size=5) # (2a) in 3x28x28 out 6x24x24
self.conv2 = torch.nn.Conv2d(6,1, kernel_size=2) # (2b) in 6x24x24 out 1x23x23 (6)
self.dens1 = torch.nn.Linear(529, 256) # (3a)
self.dens2 = torch.nn.Linear(256, 256)
self.dens3 = torch.nn.Linear(256, 24) # (4)
self.drop = torch.nn.Dropout()

def forward(self, x):
out = self.conv1(x)
out = self.conv2(out) # (5)

out = out.view(-1, 529) # (3b)

out = self.dens1(out)
out = self.dens2(out)
out = self.dens3(out)
return out

model = PR()
ins = torch.rand(1, 3, 28, 28) # (1)
output = model(ins)
  1. 首先,pytorch 处理图像张量(您执行二维卷积,因此我假设这是一个图像输入)如下:[batch_size x image_depth x height width]
  2. 重要的是要了解卷积核、填充和步幅的工作原理。在你的情况下 kernel_size 是 5 并且你没有填充(和步幅 1)。这意味着特征图的尺寸减小了(如图所示)。在您的情况下,第一个转换。层采用 3x28x28 张量并产生 6x24x24 张量,第二层从 1x23x23 中取出 6x24x24。我发现在定义转换层旁边对输入和输出张量维度进行注释非常有用(参见上面的代码)

No padding, stride 1 2-d conv - https://github.com/vdumoulin/conv_arithmetic

  1. 在这里,您需要将 [batch_size x 深度 x 高度 x 宽度] 张量“展平”为 [batch_size x 全连接输入]。这可以通过 tensor.view() 来完成。

  2. 线性层输入错误

  3. forward-pass 中的每个操作都采用输入值 x,相反我认为您可能希望将每一层的结果传递给下一层

虽然这段代码现在可以运行,但这并不意味着它完全有意义。最重要的事情(对于一般的神经网络,我会说)是激活函数。这些完全没有了。

要开始使用 pytorch 中的神经网络,我强烈推荐很棒的 pytorch 教程:https://pytorch.org/tutorials/ (我将从 60 分钟 Blitz 教程开始)

希望这对您有所帮助!

关于python - 运行时错误 : expected stride to be a single integer value,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52503695/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com