gpt4 book ai didi

python-3.x - 无法遍历 PyTorch MNIST 数据集

转载 作者:行者123 更新时间:2023-12-01 21:55:56 24 4
gpt4 key购买 nike

我正在尝试在 Pytorch 中加载 MNIST 数据集,并使用内置数据加载器迭代训练示例。但是,在迭代器上调用 next() 时出现错误。 CIFAR10 没有这个问题。

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
dataiter = iter(dataloader)
dataiter.next() # ERROR
# RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

我正在使用 Python 3.7.3 和 PyTorch 1.1.0

最佳答案

MNIST 数据集由灰度图像组成,即每个图像只有 1 channel ,而 CIFAR10 数据集由彩色图像组成,即,每张图片都有 3 channel 。

因此,如果是 MNIST 数据集,请将 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 替换为 transforms.Normalize([0.5], [0.5]).

关于python-3.x - 无法遍历 PyTorch MNIST 数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57130264/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com