gpt4 book ai didi

python - 如何使用 WeightedRandomSampler 平衡 PyTorch 中的不平衡数据?

转载 作者:行者123 更新时间:2023-11-30 08:48:28 24 4
gpt4 key购买 nike

我有一个 2 类问题,并且我的数据不平衡。0 类有 232550 个样本,1 类有 13498 个样本。PyTorch 文档和互联网告诉我为我的 DataLoader 使用 WeightedRandomSampler 类。

我尝试使用 WeightedRandomSampler,但我不断收到错误。

    trainratio = np.bincount(trainset.labels) #trainset.labels is a list of 
float [0,1,0,0,0,...]
classcount = trainratio.tolist()
train_weights = 1./torch.tensor(classcount, dtype=torch.float)
train_sampleweights = train_weights[trainset.labels]
train_sampler = WeightedRandomSampler(weights=train_sampleweights,
num_samples=len(train_sampleweights))
trainloader = DataLoader(trainset, sampler=train_sampler,
shuffle=False)

我打印出来的一些尺寸:

train_weights = tensor([4.3002e-06, 4.3002e-06, 4.3002e-06,  ..., 
4.3002e-06, 4.3002e-06, 4.3002e-06])

train_weights shape= torch.Size([246048])

我不明白为什么会收到此错误:

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
self.weights = torch.tensor(weights, dtype=torch.double)

我已经尝试过其他类似的解决方法,但到目前为止所有尝试都会产生一些错误。我应该如何实现这个来平衡我的训练、验证和测试数据?

最佳答案

显然这是一个内部警告,而不是错误。根据 PyTorch 人员的说法,我可以继续编码,而不必担心警告消息。

关于python - 如何使用 WeightedRandomSampler 平衡 PyTorch 中的不平衡数据?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54416773/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com