gpt4 book ai didi

pytorch - 如何保存 PyTorch 的 DataLoader 实例?

转载 作者:行者123 更新时间:2023-12-04 02:37:30 33 4
gpt4 key购买 nike

我想保存 PyTorch 的 torch.utils.data.dataloader.DataLoader例如,这样我就可以从中断的地方继续训练(保持随机种子、状态和所有内容)。

最佳答案

您需要采样器的自定义实现。
可以从以下地址轻松使用:https://gist.github.com/usamec/1b3b4dcbafad2d58faa71a9633eea6a5
您可以像这样保存和恢复:

sampler = ResumableRandomSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler, pin_memory=True)

for x in loader:
print(x)
break

sampler2 = ResumableRandomSampler(dataset)
torch.save(sampler.get_state(), "test_samp.pth")
sampler2.set_state(torch.load("test_samp.pth"))
loader2 = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler2, pin_memory=True)

for x in loader2:
print(x)

关于pytorch - 如何保存 PyTorch 的 DataLoader 实例?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60993677/

33 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com