gpt4 book ai didi

python-3.x - 如何在当前纪元期间加载/获取下一个纪元的下一批数据?

转载 作者:行者123 更新时间:2023-12-04 13:27:55 24 4
gpt4 key购买 nike

我知道从 PyTorch 1.7.0 开始就可以 prefetch部分批次 之前 一个时代开始了。但是,这无法在一个 epoch 内的操作正在执行时和下一个 epoch 开始之前获取批次。基于 this thread ,看来应该可以用Sampler在一个纪元期间和下一个纪元开始之前加载批次。但是,我无法理解如何使用采样器来实现这一点。
任何人都可以为允许在一个时期内获取样本的 Sampler 提供代码示例吗?

最佳答案

您可以在后台线程中从迭代器预取下一批。

class _ThreadedIterator(threading.Thread):
"""
Prefetch the next queue_length items from iterator in a background thread.

Example:
>> for i in bg_iterator(range(10)):
>> print(i)
"""

class _End:
pass

def __init__(self, generator: Iterable, maxsize: int) -> None:
threading.Thread.__init__(self)
self.queue: Queue = Queue(maxsize)
self.generator = generator
self.daemon = True
self.start()

def run(self) -> None:
for item in self.generator:
self.queue.put(item)
self.queue.put(self._End)

def __iter__(self) -> Any:
return self

def __next__(self) -> Any:
next_item = self.queue.get()
if next_item == self._End:
raise StopIteration
return next_item

# Required for Python 2.7 compatibility
def next(self) -> Any:
return self.__next__()


def bg_iterator(iterable: Iterable, maxsize: int) -> Any:
return _ThreadedIterator(iterable, maxsize=maxsize)
UPD .
用法:
model = model.to(device, non_blocking=True)
for inputs, targets in bg_iterator(data_loader, maxsize=2):
inputs = inputs.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
example

关于python-3.x - 如何在当前纪元期间加载/获取下一个纪元的下一批数据?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67085517/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com