gpt4 book ai didi

tensorflow - tf.data.experimental.sample_from_datasets 的 PyTorch 替代方案

转载 作者:行者123 更新时间:2023-12-04 08:27:11 26 4
gpt4 key购买 nike

假设我有两个数据集,数据集一有 100 个项目,数据集二有 5000 个项目。
现在我希望在训练期间,我的模型从数据集 1 中看到的项目与数据集 2 中的项目一样多。
在 Tensorflow 中,我可以做到:

dataset = tf.data.experimental.sample_from_datasets(
[dataset_one, dataset_two], weights=[50,1], seed=None
)
PyTorch 中是否有替代方案可以做到这一点?
我认为通过创建自定义数据集(非工作示例)来实现这并不太难
from torch.utils.data import Dataset

class SampleDataset(Dataset):
def __init__(self, datasets, weights):
self.datasets = datasets
self.weights = weights

def __len__(self):
return sum([len(dataset) for dataset in self.datasets])

def __getitem__(self, idx):
# sample a random number and based on that sample an item

return self.datasets[dataset_idx][sample_idx]
然而,这似乎很常见。已经有这样的东西了吗?

最佳答案

我认为 PyTorch 中没有直接的等价物。
然而,有一个函数叫做 torch.utils.data.WeightedRandomSampler 它根据概率列表对索引进行采样。您可以将此与 torch.data.utils.ConcatDataset 结合使用和 torch.utils.data.DataLoader sampler选项。
我将举一个有两个数据集的例子:SetA有 500 个元素和 SetB只有 10 个。
首先,您可以使用 ConcaDataset 创建所有数据集的串联。 :

ds = ConcatDataset([SetA(), SetB()])
然后,我们需要对其进行采样。问题是,你不能只给 WeightedRandomSampler [50, 1] ,就像你在 Tensorflow 中所做的那样。作为一种解决方法,您可以创建一个与总数据集大小相同长度的概率列表。
这个例子对应的概率列表是:
dist = np.array([1/51]*500 + [50/51]*10)
本质上,前 500 个指数(即“指向”到 SetA 的指数)将有 1/51 的概率被选中,而接下来的 10 个指数(即 SetB 中的指数)将有 50/51 的概率(即更有可能被采样,因为 SetB 中的元素较少,这是想要的结果!)
我们可以从该分布创建一个采样器:
WeightedRandomSampler(dist, 10)
其中 10 是采样元素的数量。我会放置最小数据集的大小,否则您可能会在同一时期多次查看相同的数据点......
最后,我们只需要使用我们的数据集和采样器实例化数据加载器:
dl = DataLoader(ds, sampler=sampler)
总结一下:
ds = ConcatDataset([SetA(), SetB()])
dist = np.array([1/51]*500 + [50/51]*10)
sampler = WeightedRandomSampler(dist, 10)
dl = DataLoader(ds, sampler=sampler)

编辑 ,对于任意数量的数据集:
sets = [SetA(), SetB(), SetC()]
ds = ConcatDataset(sets)

dist = np.concatenate([[(len(ds) - len(s))/len(ds)]*len(s) for s in sets])
sampler = WeightedRandomSampler(weights=dist, num_samplesmin([len(s) for s in sets])
dl = DataLoader(ds, sampler=sampler)

关于tensorflow - tf.data.experimental.sample_from_datasets 的 PyTorch 替代方案,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65205801/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com