gpt4 book ai didi

python - Pytorch - 无法切片 torchvision MNIST 数据集

转载 作者:行者123 更新时间:2023-12-04 13:02:04 26 4
gpt4 key购买 nike

在Pytorch中,当使用torchvision的MNIST数据集时,我们可以得到一个数字如下:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset

tsfm = transforms.Compose([transforms.Resize((16, 16)),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])

mnist_ds = datasets.MNIST(root='../../../_data/mnist',train=True,download=True,
transform=tsfm)

digit_12 = mnist_ds[12]
虽然可以在许多数据集上切片,但我们不能在这个数据集上切片:
>>> digit_12_to_14 = mnist_ds[12:15]
ValueError: Too many dimensions: 3 > 2.
这是由于 Image.fromarray()getItem() .
是否可以在不使用 Dataloader 的情况下使用 MNIST 数据集?

PS:我想避免使用 Dataloader 的原因是一次向 GPU 发送一批会减慢训练速度。我更喜欢一次将整个数据集发送到 GPU。为此,我需要访问整个转换后的数据集。

最佳答案

Dataset 接口(interface)只需要

All subclasses should override __len__, that provides the size of the dataset, and __getitem__, supporting integer indexing in range from 0 to len(self) exclusive.



显然没有提到切片 - 其他数据集的切片行为是一个额外的功能。如果想一次性获取全部数据,可以查 implementation只需使用 mnist.datamnist.targets__init__ 末尾定义的张量.

如果要转换数据,可以使用
data = [mnist_ds[i] for i in range(len(mnist_ds))]
xs = torch.stack([d[0] for d in data], dim=0)
ys = torch.stack([d[1] for d in data], dim=0)

或转换 mnist.data一次全部张量(尽管这不适用于 torchvision.transform 转换)。

关于python - Pytorch - 无法切片 torchvision MNIST 数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54251798/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com