gpt4 book ai didi

python - 如何改变 Pytorch 数据集的大小?

转载 作者:行者123 更新时间:2023-11-30 08:26:13 24 4
gpt4 key购买 nike

假设我正在从 torchvision.datasets.MNIST 加载 MNIST ,但我只想加载总共 10000 个图像,我将如何切片数据以将其限制为仅某些数量的数据点?据我了解,DataLoader是生成指定批量大小的数据的生成器,但是如何对数据集进行切片?

tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)
train_loader = DataLoader(tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)

最佳答案

您可以使用torch.utils.data.Subset(),例如对于前 10,000 个元素:

import torch.utils.data as data_utils

indices = torch.arange(10000)
tr_10k = data_utils.Subset(tr, indices)

关于python - 如何改变 Pytorch 数据集的大小?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44856691/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com