gpt4 book ai didi

python - pytorch 数据集中每个类的实例数

转载 作者:行者123 更新时间:2023-12-04 04:28:56 26 4
gpt4 key购买 nike

我正在尝试使用 PyTorch 制作一个简单的图像分类器。
这是我将数据加载到数据集和 dataLoader 中的方式:

batch_size = 64
validation_split = 0.2
data_dir = PROJECT_PATH+"/categorized_products"
transform = transforms.Compose([transforms.Grayscale(), CustomToTensor()])

dataset = ImageFolder(data_dir, transform=transform)

indices = list(range(len(dataset)))

train_indices = indices[:int(len(indices)*0.8)]
test_indices = indices[int(len(indices)*0.8):]

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, num_workers=16)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=test_sampler, num_workers=16)

我想分别打印出训练和测试数据中每个类(class)的图像数量,如下所示:

在火车数据中:
  • 鞋子:20
  • 衬衫:14

  • 在测试数据中:
  • 鞋子:4
  • 衬衫:3

  • 我试过这个:
    from collections import Counter
    print(dict(Counter(sample_tup[1] for sample_tup in dataset.imgs)))

    但我收到了这个错误:
    AttributeError: 'MyDataset' object has no attribute 'img'

    最佳答案

    您需要使用 .targets访问数据标签,即

    print(dict(Counter(dataset.targets)))

    它将打印如下内容(例如在 MNIST 数据集中):
    {5: 5421, 0: 5923, 4: 5842, 1: 6742, 9: 5949, 2: 5958, 3: 6131, 6: 5918, 7: 6265, 8: 5851}

    此外,您可以使用 .classes.class_to_idx获取标签 id 到类的映射:
    print(dataset.class_to_idx)
    {'0 - zero': 0,
    '1 - one': 1,
    '2 - two': 2,
    '3 - three': 3,
    '4 - four': 4,
    '5 - five': 5,
    '6 - six': 6,
    '7 - seven': 7,
    '8 - eight': 8,
    '9 - nine': 9}

    编辑:方法 1

    从评论中,为了分别获得训练集和测试集的类分布,您可以简单地迭代子集,如下所示:
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    # labels in training set
    train_classes = [label for _, label in train_dataset]
    Counter(train_classes)
    Counter({0: 4757,
    1: 5363,
    2: 4782,
    3: 4874,
    4: 4678,
    5: 4321,
    6: 4747,
    7: 5024,
    8: 4684,
    9: 4770})

    编辑 (2):方法 2

    由于您有一个大型数据集,并且正如您所说,迭代所有训练集需要相当长的时间,因此还有另一种方法:

    您可以使用 .indices子集,它指的是为子集选择的原始数据集中的索引。

    IE。
    train_classes = [dataset.targets[i] for i in train_dataset.indices]
    Counter(train_classes) # if doesn' work: Counter(i.item() for i in train_classes)

    关于python - pytorch 数据集中每个类的实例数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62319228/

    26 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com