gpt4 book ai didi

python - 如何从Pytorch中的高IO数据集读取,该数据集随着时间的推移而增长

转载 作者:行者123 更新时间:2023-12-03 16:15:32 25 4
gpt4 key购买 nike

我使用Tensorflow,但我是为用户编写文档,这些文档通常会在深度学习框架中有所不同。

当使用不适合本地文件系统(TB +)的数据集时,我从远程数据存储中采样数据,然后将采样示例本地写入Tensorflow标准tfrecords格式。

在训练的第一个时期,我将仅采样几个值,因此,一个局部数据的时期非常小,我对此进行了训练。在第2阶段,我重新检查采样子过程(现在有更多)产生了哪些数据文件,并在下一个阶段对扩展的本地数据文件集进行训练。每个时期重复该过程。这样,我可以建立样本的本地缓存,并在填满本地存储时可以驱逐较旧的样本。局部样本缓存大约在模型最需要方差时(朝训练的后半部分)增长。

在Python/Tensorflow中,至关重要的是,我不要在Python训练循环过程中反序列化数据,因为Python GIL无法支持数据传输速率(300-600 MB/秒,数据是原始的科学不可压缩的),因此不能保证GPU的性能当Python GIL无法快速服务于训练循环时,它会遭受苦难。

将样本从子进程(python多重处理)写入tfrecords文件中,从而允许tensorflow的 native TFRecordsDataset在Python之外进行反序列化,因此我们避开了Python GIL问题,并且我可以使具有高IO数据速率的GPU饱和。

I would like to know how I would address this issue in Pytorch. I'm writing about the sampling strategy that's being used, and want to provide specific recommendations to users of both Tensorflow and PyTorch, but I don't know the PyTorch preprocessing ecosystem well enough to write with sufficient detail.



旁注:支持这些数据传输速率的唯一纯基于Python的解决方案可能是带有System V共享内存和多处理功能的Python 3.8,但我还没有尝试过,因为对它的支持还不够(很快就可以了) )。现有的多处理解决方案是不够的,因为它们需要在训练循环过程中进行反序列化,从而在反序列化期间以高IO速率锁定GIL。

最佳答案

实际上,您可以使用torch.utils.data.DataLoader在子流程中轻松反序列化数据。通过将num_workers参数设置为1或更大的值,可以使用它们自己的python解释器和GIL生成子进程。

loader = torch.utils.data.DataLoader(your_dataset, num_workers=n, **kwargs)
for epoch in range(epochs):
for batch_idx, data in enumerate(loader):
# loader in the main process does not claim GIL at this point
Dataloader需要 torch.utils.data.Dataset才能从中获取数据。在您的情况下,实现适当的子类可能不是一件容易的事。如果您需要为每个纪元重新创建 Dataset实例,则可以执行以下操作。
for epcoh in range(epochs):
dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
for batch_idx, data in enumerate(loader):
# Do training

甚至更好
dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)

for epcoh in range(epochs):
last_batch_idx = (len(dset)-1) // loader.batch_size
for batch_idx, data in enumerate(loader):
# Prepare next loader in advance to avoid blocking
if batch_idx == last_batch_idx:
dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
# Do training

附带说明一下,请注意,在大多数情况下,受GIL影响的是受CPU约束的操作,而不是受I/O约束的操作,即 threading可用于任何纯粹的I/O繁重的操作,甚至不需要 subprocess。有关更多信息,请引用此 question和此Wikipedia article

关于python - 如何从Pytorch中的高IO数据集读取,该数据集随着时间的推移而增长,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60119934/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com