gpt4 book ai didi

python - 如何在 PyTorch 中平衡(过采样)不平衡数据(使用 WeightedRandomSampler)?

转载 作者:太空宇宙 更新时间:2023-11-04 04:21:34 27 4
gpt4 key购买 nike

我有一个 2 类问题,我的数据非常不平衡。我有来自一类的 232550 个样本和来自第二类的 13498。 PyTorch 文档和互联网告诉我为我的 DataLoader 使用类 WeightedRandomSampler。

我已尝试使用 WeightedRandomSampler,但我一直收到错误。

    trainratio = np.bincount(trainset.labels)
classcount = trainratio.tolist()
train_weights = 1./torch.tensor(classcount, dtype=torch.float)
train_sampleweights = train_weights[trainset.labels]
train_sampler = WeightedRandomSampler(weights=train_sampleweights,
num_samples = len(train_sampleweights))
trainloader = DataLoader(trainset, sampler=train_sampler,
shuffle=False)

我不明白为什么在初始化 WeightedRandomSampler 类时出现此错误?

我已经尝试过其他类似的解决方法,但到目前为止所有尝试都会产生一些错误。我应该如何实现它来平衡我的训练、验证和测试数据?

当前出现此错误:

train__sampleweights = train_weights[trainset.labels] ValueError: too many dimensions 'str'

最佳答案

问题出在trainset.labels的类型上要修复错误,可以将 trainset.labels 转换为 float

关于python - 如何在 PyTorch 中平衡(过采样)不平衡数据(使用 WeightedRandomSampler)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54415345/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com