gpt4 book ai didi

python - 具有不同批量大小的pytorch恢复模型

转载 作者:行者123 更新时间:2023-11-28 19:01:16 25 4
gpt4 key购买 nike

我有一个关于如何重新加载不同批量大小的 pytorch 模型的问题。在训练中,我的批量大小是 64,但在推理中,我希望批量大小是 1(一个接一个地输入数据)。这是我用来保存和恢复模型的代码:

torch.save(agent.qnetwork_local.state_dict(), './ckpt/checkpoint.pth')
saved_model = QNetwork(state_size=37, action_size=4, seed=0)
saved_model.load_state_dict(torch.load('./ckpt/checkpoint.pth'))

我在运行推理模型时遇到了这个错误:

RuntimeError: size mismatch, m1: [37 x 1], m2: [37 x 64] at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensorMath.cpp:2070

这个错误意味着模型的输入必须是 37x64,其中 37 是数据维度,64 是训练批量大小。但测试输入为 37x1,这意味着数据维度为 37,批量大小为 1。

在reload pytorch model中有不同batch size的解决方案吗?非常感谢。

最佳答案

我最终设法在 DataLoader 中使用 batch_size=1 做到了这一点

import torch
import pandas as pd
from torch.utils.data.dataloader import DataLoader

df = pd.read_csv('data.csv')
df = df.values

# Use CustomDataset class for your data
inference_dataset = CustomDataset(x=df[:1, 0:2])

inference_dataloader = DataLoader(inference_dataset, batch_size=1, shuffle=False, num_workers=4)

#
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load('./model/model'))
model.eval()

for i, x in enumerate(inference_dataloader):
x = x.float()
y_pred = model(x)
print(y_pred.value)

关于python - 具有不同批量大小的pytorch恢复模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52425824/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com