gpt4 book ai didi

python - 如何在pytorch MNIST数据集中选择特定标签

转载 作者:行者123 更新时间:2023-12-04 14:18:20 43 4
gpt4 key购买 nike

我正在尝试仅使用 PyTorch Mnist 数据集中的特定数字创建数据加载器

我已经尝试创建自己的采样器,但它不起作用,我不确定我是否正确使用了 mask 。

class YourSampler(torch.utils.data.sampler.Sampler):

def __init__(self, mask):

self.mask = mask


def __iter__(self):

return (self.indices[i] for i in torch.nonzero(self.mask))


def __len__(self):

return len(self.mask)


mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform = transform)

mask = [True if mnist[i][1] == 5 else False for i in range(len(mnist))]

mask = torch.tensor(mask)

sampler = YourSampler(mask)

trainloader = torch.utils.data.DataLoader(mnist, batch_size=4, sampler = sampler, shuffle=False, num_workers=2)


到目前为止,我遇到了许多不同类型的错误。对于这个实现,它是“停止迭代”。
我觉得这很容易/愚蠢,但我找不到一种简单的方法来做到这一点。
感谢您的帮助!

最佳答案

我能想到的最简单的选择是就地减少数据集:

indices = dataset.targets == 5 # if you want to keep images with the label 5
dataset.data, dataset.targets = dataset.data[indices], dataset.targets[indices]

关于python - 如何在pytorch MNIST数据集中选择特定标签,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57913825/

43 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com