gpt4 book ai didi

pytorch - 如何使用 PyTorch 的 DataLoader 确保批处理包含来自所有工作人员的样本?

转载 作者:行者123 更新时间:2023-12-03 17:13:19 26 4
gpt4 key购买 nike

我想知道怎么用torch.utils.data.DataLoader在 PyTorch 中,尤其是在多 worker 情况下。

我发现来自 DataLoader 的一批输出总是来自一个 worker 。
我预计 DataLoader 中有一个队列,用于存储来自所有工作人员的数据,而 DataLoader 将它们在队列中打乱以输出随机批处理数据。我认为这是tf.data.Dataset的方式在 tensorflow 中。
我们可以在 PyTorch 中实现类似的功能吗?我想通过使用多个 worker 从大序列化文件(如 Tfrecord )加载数据集。在这种情况下,在一批中混合源文件,这意味着混合工作器的源,很重要。

请引用以下代码:

import random
import time

import torch


class MyDataset(torch.utils.data.Dataset):
def __len__(self):
return 50

def __getitem__(self, idx):
info = torch.utils.data.get_worker_info()

time.sleep(random.uniform(0, 1))
print("[{}]:{}".format(info.id, idx))
return idx, info.id


if __name__ == '__main__':
dataset = MyDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False, num_workers=2)
for batch in dataloader:
print(batch)

输出:
[0]:0
[1]:5
[0]:1
[1]:6
[0]:2
[0]:3
[1]:7
[0]:4
[tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]
[1]:8
[1]:9
[tensor([5, 6, 7, 8, 9]), tensor([1, 1, 1, 1, 1])]
[0]:10
[0]:11
[1]:15
[1]:16
[0]:12
[1]:17
...

在这里, [0, 1, 2, 3, 4][0, 0, 0, 0, 0][tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]表示该批次包括来自 worker ID 0 的索引 0 到 4 数据.
请注意 shuffle=True不能解决这个问题,它只会改变数据的索引。

在这种情况下,我想得到一个批次: [tensor([0, 5, 1, 6, 2]), tensor([0, 1, 0, 1, 0])] .

最佳答案

我已经实现了一些简单的方法来解决类似的问题,我将大型视频文件作为训练数据,每个工作人员负责加载和预处理单个文件,然后从中生成样本。问题在于,正如 OP 所描述的那样,使用 Pytorch 的默认数据加载机制,每个批次仅包含来自单个视频文件的样本。

首先,让我们回顾一下问题。在这个简化的代码示例中,每个工作人员产生一个包含其零索引工作人员 ID 的张量。批量大小为 32 和 4 个 worker 时,我们希望每个批次包含 8 个零、8 个 1、8 个二和 8 个三。

from collections import defaultdict

import torch as T
import torch.utils.data as tdata


class Dataset(tdata.IterableDataset):
def __init__(self, batch_size: int):
self._bs = batch_size

def __iter__(self):
worker_info = tdata.get_worker_info()
if not worker_info:
raise NotImplementedError('Not implemented for num_workers=0')
for _ in range(self._bs):
yield T.tensor([worker_info.id])


batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
loader = tdata.DataLoader(dataset,
batch_size=batch_size,
num_workers=num_workers)


for batch in loader:
counts = defaultdict(int)
for n in batch.numpy().flatten():
counts[n] += 1
print(dict(counts))

相反,代码打印:
{0: 32}
{1: 32}
{2: 32}
{3: 32}

这意味着第一批只包含 worker 0 的样本,第二批只包含 worker 1 的样本,等等。为了解决这个问题,我们将在 DataLoader 中设置批大小。至 batch_size // num_workers并在 DataLoader 上使用一个简单的包装器为我们的批次汇集每个 worker 的样本:

def pooled_batches(loader):
loader_it = iter(loader)
while True:
samples = []
for _ in range(loader.num_workers):
try:
samples.append(next(loader_it))
except StopIteration:
pass
if len(samples) == 0:
break
else:
yield T.cat(samples, dim=0)


batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
per_worker = batch_size // num_workers
loader = tdata.DataLoader(dataset,
batch_size=per_worker,
num_workers=num_workers)

for batch in pooled_batches(loader):
counts = defaultdict(int)
for n in batch.numpy().flatten():
counts[n] += 1
print(dict(counts))

代码现在打印
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}

正如预期的那样。

关于pytorch - 如何使用 PyTorch 的 DataLoader 确保批处理包含来自所有工作人员的样本?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57729279/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com