
python - Train-validation-test split for a custom dataset using PyTorch and TorchVision

Reposted. Author: 行者123. Updated: 2023-12-03 23:07:09

I have some image data for a binary classification task. The images are organized into 2 folders: data/model_data/class-A and data/model_data/class-B.

There are N images in total. I want a 70/20/10 train/validation/test split.
I am using PyTorch and Torchvision for this task. Here is my code so far:

import torch
from torchvision import datasets, transforms

root = "data/model_data"  # contains the class-A/ and class-B/ folders
BATCH_SIZE = 32  # placeholder value, not from the original question
NUM_WORKER = 4   # placeholder value, not from the original question

data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

model_dataset = datasets.ImageFolder(root, transform=data_transform)
total_count = len(model_dataset)  # N images in total
train_count = int(0.7 * total_count)
valid_count = int(0.2 * total_count)
test_count = total_count - train_count - valid_count  # remainder, so the three counts sum to N
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(model_dataset, (train_count, valid_count, test_count))
train_dataset_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
valid_dataset_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
test_dataset_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)
dataloaders = {'train': train_dataset_loader, 'val': valid_dataset_loader, 'test': test_dataset_loader}

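I don't think this is the right way to do it, for two reasons:
  • I am applying the same transform to all the splits. (This is obviously not what I want! The solution is most likely the answer here.)
  • Usually people first split the original data into test/train and then
    split train into train/val, whereas I split the
    original data directly into train/val/test. (Is this correct?)

  • So my question is: is what I'm doing correct? (Probably not.)
    If it isn't, how should I write the data loaders to get the desired split, so that I can apply separate transforms to each of train/val/test?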

    Best answer

    Usually people first separate the original data into test/train and then they separate train into train/val, whereas I am directly separating the original data into train/val/test. (Is this correct?)



    Yes, it is completely correct, readable, and totally fine.
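To see that the two approaches differ only in bookkeeping, here is a sketch that performs both on the same data, assuming the 70/20/10 proportions from the question and adding a fixed seed for reproducibility: a direct three-way random_split, and a two-stage split that holds out the test set first. Both give the same subset sizes; in the second stage the validation share is 0.2/0.9 of the remainder rather than 20% of it. The function names direct_split and two_stage_split are illustrative only.

```python
import torch
from torch.utils.data import random_split


def direct_split(dataset, seed=0):
    """70/20/10 in a single call, as in the question."""
    n = len(dataset)
    n_train, n_val = int(0.7 * n), int(0.2 * n)
    sizes = [n_train, n_val, n - n_train - n_val]
    return random_split(dataset, sizes, generator=torch.Generator().manual_seed(seed))


def two_stage_split(dataset, seed=0):
    """Hold out the test set first, then split the rest into train/val."""
    gen = torch.Generator().manual_seed(seed)
    n = len(dataset)
    n_test = n - int(0.7 * n) - int(0.2 * n)
    rest, test = random_split(dataset, [n - n_test, n_test], generator=gen)
    n_val = int(0.2 * n)  # 20% of the original N, i.e. 0.2/0.9 of the remainder
    train, val = random_split(rest, [len(rest) - n_val, n_val], generator=gen)
    return train, val, test
```

The two-stage form is mainly convenient when the test set should stay fixed while train/val are re-drawn; for a one-off split, the direct form in the question is simpler.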

    I am applying the same transform to all the splits. (This is not what I want to do, obviously! The solution for this is most probably the answer here.)



    Yes, that answer would work, but to be honest it is needlessly verbose. You can use the third-party library torchdata; just install it:
    pip install torchdata

    The documentation can be found here (disclaimer: I am the author).

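    It lets you easily map transformations onto any torch.utils.data.Dataset (in this case, train). Your code would look like this (only two lines change, see the comments; I also formatted your code to make it easier to follow):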
    import torch
    import torchvision

    import torchdata as td

    data_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomResizedCrop(224),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )

    # Single change, makes an instance of torchdata.Dataset
    # Works just like PyTorch's torch.utils.data.Dataset, but has
    # additional capabilities like .map, cache etc., see project's description
    model_dataset = td.datasets.WrapDataset(torchvision.datasets.ImageFolder(root))
    # Also you shouldn't use transforms here but below
    total_count = len(model_dataset)  # N images in total
    train_count = int(0.7 * total_count)
    valid_count = int(0.2 * total_count)
    test_count = total_count - train_count - valid_count
    train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
        model_dataset, (train_count, valid_count, test_count)
    )

    # Apply transformations here only for the train dataset.
    # Caveats worth checking against your torchdata version: random_split returns
    # torch.utils.data.Subset objects, and ImageFolder yields (image, label) pairs,
    # so the mapped function must end up transforming only the image element.
    train_dataset = train_dataset.map(data_transform)
    # valid_dataset and test_dataset still yield PIL images; they also need at least
    # ToTensor() and Normalize() before the default DataLoader collate can batch them.

    # Rest of the code goes the same

    train_dataset_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER
    )
    valid_dataset_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER
    )
    test_dataset_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER
    )
    dataloaders = {
        "train": train_dataset_loader,
        "val": valid_dataset_loader,
        "test": test_dataset_loader,
    }
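If you would rather not add a dependency, the same "split once, then map a transform per split" idea can be sketched in plain PyTorch with a small wrapper dataset. This is not torchdata's implementation; MapDataset and split_with_transforms are hypothetical names. The wrapper applies the function only to the image element of each (image, label) pair, and because it only needs __len__ and __getitem__, it also works on the Subset objects that random_split returns.

```python
import torch
from torch.utils.data import Dataset, random_split


class MapDataset(Dataset):
    """Wrap a dataset of (image, label) pairs and apply `fn` to the image only.

    Hypothetical plain-PyTorch stand-in for torchdata's `.map`; not a library API.
    """

    def __init__(self, dataset, fn):
        self.dataset = dataset
        self.fn = fn

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, label = self.dataset[index]
        return self.fn(image), label


def split_with_transforms(dataset, counts, train_fn, eval_fn, seed=0):
    """Split once, then give train its own transform and val/test another."""
    generator = torch.Generator().manual_seed(seed)
    train, val, test = random_split(dataset, counts, generator=generator)
    return MapDataset(train, train_fn), MapDataset(val, eval_fn), MapDataset(test, eval_fn)
```

With the question's variables this would be called on datasets.ImageFolder(root) created without a transform, for example split_with_transforms(model_dataset, (train_count, valid_count, test_count), data_transform, eval_transform), where eval_transform is an assumed deterministic pipeline such as Resize + CenterCrop + ToTensor + Normalize. If num_workers > 0 and your platform starts workers with spawn, pass regular functions or transform objects rather than lambdas so they can be pickled.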

    And yes, I agree that specifying the transform before splitting is not very clear; IMO this version is more readable.

    Source: the Stack Overflow question "Train-validation-test split for a custom dataset using PyTorch and TorchVision": https://stackoverflow.com/questions/61811946/
