python - Question about SubsetRandomSampler

Reposted | Author: 行者123 | Updated: 2023-12-05 03:42:28

I am using SubsetRandomSampler to split a classification dataset into test and validation sets. Can we split the dataset per class?

import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler

train_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])
dataset = datasets.ImageFolder('/data/images/train', transform=train_transforms)

validation_split = .2
shuffle_dataset = True
random_seed = 42
batch_size = 20

dataset_size = len(dataset)  # 4996
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))

if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)
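To see why a per-class split can matter, here is a stdlib-only sketch of the shuffle-and-slice split above, run on made-up, imbalanced labels. The `labels` list is an assumption standing in for the dataset's class indices, and Python's `random` module replaces `np.random`. The overall validation set is exactly 20% of the data, but each class's share of it only lands near 20% by chance.

```python
import random
from collections import Counter

# Hypothetical, imbalanced labels standing in for the dataset's class indices:
# 900 samples of class 0 and 100 samples of class 1.
labels = [0] * 900 + [1] * 100

validation_split = 0.2
random_seed = 42

indices = list(range(len(labels)))
# int() truncates, which equals floor for non-negative values.
split = int(validation_split * len(indices))

random.seed(random_seed)
random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Count how many validation samples each class actually received.
val_counts = Counter(labels[i] for i in val_indices)
class_counts = Counter(labels)
for cls in sorted(class_counts):
    frac = val_counts[cls] / class_counts[cls]
    print(f"class {cls}: {val_counts[cls]}/{class_counts[cls]} in validation ({frac:.1%})")
```

The printed per-class fractions depend on the shuffle, which is exactly what a per-class split removes.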

Best answer

Do you mean training and validation rather than test and validation?

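If so: SubsetRandomSampler draws its samples at random from the indices you give it. So you can randomly split the indices of each class before putting them into train_indices and val_indices.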
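This works because the sampler never produces indices of its own: on every pass it yields exactly the indices it was given, each once, in a fresh random order (PyTorch's version draws the permutation with `torch.randperm`). The `ToySubsetRandomSampler` class below is a hypothetical, torch-free imitation of that behaviour for illustration only, not the real implementation.

```python
import random
from typing import Iterator, List, Optional


class ToySubsetRandomSampler:
    """Hypothetical stand-in mimicking how SubsetRandomSampler iterates:
    each pass yields every given index exactly once, in a new random order."""

    def __init__(self, indices: List[int], seed: Optional[int] = None) -> None:
        self.indices = list(indices)
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[int]:
        # Shuffle positions, then map each position back to a dataset index.
        order = list(range(len(self.indices)))
        self.rng.shuffle(order)
        for pos in order:
            yield self.indices[pos]

    def __len__(self) -> int:
        return len(self.indices)


val_indices = [3, 8, 15, 42]
sampler = ToySubsetRandomSampler(val_indices, seed=0)
epoch1 = list(sampler)  # one "epoch" of indices
epoch2 = list(sampler)  # a second pass, reshuffled
print(epoch1, epoch2)
```

Because only the supplied indices ever come out, whatever class balance you build into `train_indices` and `val_indices` is exactly what the loaders will see.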

For example:

# One separate list per class. `[[]] * len(dataset.classes)` would not work,
# because it repeats a reference to the *same* list.
class_indices = [[] for _ in range(len(dataset.classes))]
# ImageFolder's `targets` holds the class index of every sample, so this avoids
# loading and transforming each image just to read its label.
for idx, class_idx in enumerate(dataset.targets):
    class_indices[class_idx].append(idx)

train_indices, val_indices = [], []
for cl_idx in class_indices:
    size = len(cl_idx)
    split = int(np.floor(validation_split * size))
    np.random.shuffle(cl_idx)
    train_indices.extend(cl_idx[split:])
    val_indices.extend(cl_idx[:split])

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
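The same per-class idea can be packaged as a small, torch-free helper that works on any sequence of class labels (for an `ImageFolder` you could pass `dataset.targets`). The function name `stratified_split` and the toy labels are illustrative assumptions; a seeded `random.Random` instance replaces the global NumPy seed so the split is reproducible without touching global state.

```python
import math
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def stratified_split(
    targets: Sequence[int], validation_split: float, seed: int = 42
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: each class with n_c samples sends
    floor(validation_split * n_c) of them to validation, the rest to training."""
    # Group dataset positions by class label.
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, cls in enumerate(targets):
        by_class[cls].append(idx)

    rng = random.Random(seed)
    train_indices: List[int] = []
    val_indices: List[int] = []
    for cls in sorted(by_class):
        idxs = by_class[cls]
        rng.shuffle(idxs)
        split = math.floor(validation_split * len(idxs))
        val_indices.extend(idxs[:split])
        train_indices.extend(idxs[split:])
    return train_indices, val_indices


# Hypothetical labels: 900 samples of class 0, 100 of class 1.
targets = [0] * 900 + [1] * 100
train_idx, val_idx = stratified_split(targets, validation_split=0.2)
val_per_class = {c: sum(1 for i in val_idx if targets[i] == c) for c in (0, 1)}
print(val_per_class)  # {0: 180, 1: 20}
```

The two returned lists can be handed to `SubsetRandomSampler` exactly as in the answer. If scikit-learn is available, `train_test_split` with its `stratify` argument is an off-the-shelf alternative.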

For "python - Question about SubsetRandomSampler", there is a similar question on Stack Overflow: https://stackoverflow.com/questions/67250023/