gpt4 book ai didi

python - 在 PyTorch 中计算 Conv2d 的输入和输出大小以进行图像分类

转载 作者:太空宇宙 更新时间:2023-11-04 06:59:32 25 4
gpt4 key购买 nike

我正在尝试在此处运行有关 CIFAR10 图像分类的 PyTorch 教程 - http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

我做了一个小改动,我正在使用不同的数据集。我有来自 Wikiart 数据集的图像,我想按艺术家分类(标签 = 艺术家姓名)。

这是 Net 的代码 -

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

然后是我开始训练网络的这部分代码。

for epoch in range(2):
running_loss = 0.0

for i, data in enumerate(wiki_train_dataloader, 0):
inputs, labels = data['image'], data['class']
print(inputs.shape)
inputs, labels = Variable(inputs), Variable(labels)

optimizer.zero_grad()

# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

# print statistics
running_loss += loss.data[0]
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0

这一行 print(inputs.shape) 给我 torch.Size([4, 32, 32, 3]) 我的 Wikiart 数据集,而在原始示例中对于 CIFAR10,它打印 torch.Size([4, 3, 32, 32])

现在,我不确定如何更改我的网络中的 Conv2d 以与 torch.Size([4, 32, 32, 3]) 兼容。

我收到这个错误:

RuntimeError:给定输入大小:(3 x 32 x 3)。计算出的输出大小:(6 x 28 x -1)。/opt/conda/conda-bld/pytorch_1503965122592/work/torch/lib/THNN/generic/SpatialConvolutionMM.c:45 处的输出大小太小

在读取 Wikiart 数据集的图像时,我将它们的大小调整为 (32, 32),这些是 3 channel 图像。

我尝试过的事情:

1) CIFAR10 教程使用了我没有使用的转换。我无法将其合并到我的代码中。

2) 将 self.conv2 = nn.Conv2d(6, 16, 5) 更改为 self.conv2 = nn.Conv2d(3, 6, 5)。这给了我与上面相同的错误。我只是更改它以查看错误消息是否更改。

任何有关如何在 PyTorch 中计算输入和输出大小或自动 reshape 张量的资源都将不胜感激。我刚开始学习 Torch,我发现尺寸计算很复杂。

最佳答案

您必须将您的输入调整为这种格式(批处理、数字 channel 、高度、宽度)。目前您的格式为 (B,H,W,C) (4, 32, 32, 3),因此您需要交换第 4 轴和第 2 轴以使用 (B,C,H,W) 塑造数据。你可以这样做:

inputs, labels = Variable(inputs), Variable(labels)
inputs = inputs.transpose(1,3)
... the rest

关于python - 在 PyTorch 中计算 Conv2d 的输入和输出大小以进行图像分类,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47128044/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com