gpt4 book ai didi

neural-network - Pytorch - 如何使用加权随机采样器进行欠采样

转载 作者:行者123 更新时间:2023-12-05 08:31:03 25 4
gpt4 key购买 nike

我有一个不平衡的数据集,想对过度代表的类别进行欠采样。我该怎么做。我想使用 weightedrandomsampler,但我也愿意接受其他建议。

到目前为止,我假设我的代码的结构必须类似于以下内容。但我不知道如何准确地做到这一点。


trainset = datasets.ImageFolder(path_train,transform=transform)
...
sampler = data.WeightedRandomSampler(weights=..., num_samples=..., replacement=...)
...
trainloader = data.DataLoader(trainset, batchsize = batchsize, sampler=sampler)

我希望有人能提供帮助。非常感谢

最佳答案

根据我的理解,pytorch WeightedRandomSampler 'weights' 参数有点类似于 numpy.random.choice 'p' 参数,它是随机选择样本的概率。 Pytorch 使用权重来随机抽样训练示例,他们在文档中声明权重不必总和为 1,所以这就是我的意思,它与 numpy 的随机选择不完全一样。权重越大,该样本被采样的可能性就越大。

当您设置 replacement=True 时,这意味着可以多次绘制训练示例,这意味着您可以在训练集中拥有用于训练模型的训练示例副本;过采样。同时,如果与其他训练样本权重相比权重较低,则情况相反,这意味着这些样本被选中进行随机抽样的机会较低;欠采样。

我不知道 num_samples 参数在与 train loader 一起使用时如何工作,但我可以警告您不要将批量大小放在那里。今天,我尝试输入批量大小,但结果很糟糕。我的同事把类(class)数*100,他的结果好多了。我所知道的是你不应该把批量大小放在那里。我还尝试将所有训练数据的大小放入 num_samples 中,结果更好,但训练时间很长。无论哪种方式,请尝试一下,看看哪种方式最适合您。我猜想安全的做法是使用训练示例的数量作为 num_samples 参数。

这是我看到其他人使用的示例,我也将其用于二进制分类。它似乎工作得很好。您取每个类别的训练示例数量的倒数,并使用该类别的各自权重设置所有训练示例。

使用您的训练集对象的简单示例

labels = np.array(trainset.samples)[:,1] # 转到数组并取所有索引为 1 的列

labels = labels.astype(int) # 改为 int

majority_weight = 1/num_of_majority_class_training_examples

minority_weight = 1/num_of_minority_class_training_examples

sample_weights = np.array([majority_weight, minority_weight]) # 这是假设你的少数类是标签对象中的整数 1。如果不是,请交换位置,使其成为 minority_weight、majority_weight。

weights = samples_weights[labels] # 这遍历每个训练示例并使用标签 0 和 1 作为 sample_weights 对象中的索引,这是您想要的该类的权重。

sampler = WeightedRandomSampler(weights=weights, num_samples=, replacement=True)

trainloader = data.DataLoader(trainset, batchsize = batchsize, sampler=sampler)

由于 pytorch 文档说权重总和不必为 1,我认为您也可以只使用不平衡类之间的比率。例如,如果您有 100 个多数类训练示例和 50 个少数类训练示例,则比例为 2:1。为了平衡这一点,我认为您可以为每个多数类训练示例使用 1.0 的权重,为所有少数类训练示例使用 2.0 的权重,因为从技术上讲,您希望少数类被选中的可能性增加 2 倍,这将平衡您的随机选择期间的类(class)。

希望对您有所帮助。抱歉草率的写作,我很匆忙,看到没有人回答。我自己也在努力解决这个问题,但也找不到任何帮助。如果它没有意义就这么说,我会重新编辑它并在我有空的时候让它更清楚。

关于neural-network - Pytorch - 如何使用加权随机采样器进行欠采样,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60320232/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com