gpt4 book ai didi

python - 使用导出/重新加载模型 : "Input type and weight type should be the same" 进行 fastai 错误预测

转载 作者:太空宇宙 更新时间:2023-11-03 15:31:45 25 4
gpt4 key购买 nike

每当我导出一个 fastai 模型并重新加载它时,当我尝试使用重新加载的模型在新的测试集上生成预测时,我会收到这个错误(或一个非常相似的错误):

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same

下面是最小的可重现代码示例,您只需将 FILES_DIR 变量更新为 MNIST 数据存放在您系统上的位置:

from fastai import *
from fastai.vision import *

# download data for reproduceable example
untar_data(URLs.MNIST_SAMPLE)
FILES_DIR = '/home/mepstein/.fastai/data/mnist_sample' # this is where command above deposits the MNIST data for me


# Create FastAI databunch for model training
tfms = get_transforms()
tr_val_databunch = ImageDataBunch.from_folder(path=FILES_DIR, # location of downloaded data shown in log of prev command
train = 'train',
valid_pct = 0.2,
ds_tfms = tfms).normalize()

# Create Model
conv_learner = cnn_learner(tr_val_databunch,
models.resnet34,
metrics=[error_rate]).to_fp16()

# Train Model
conv_learner.fit_one_cycle(4)

# Export Model
conv_learner.export() # saves model as 'export.pkl' in path associated with the learner

# Reload Model and use it for inference on new hold-out set
reloaded_model = load_learner(path = FILES_DIR,
test = ImageList.from_folder(path = f'{FILES_DIR}/valid'))

preds = reloaded_model.get_preds(ds_type=DatasetType.Test)

输出:

"RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same"

逐个语句地执行代码,一切正常,直到最后一行 pred = ... 出现上面的 torch 错误。

相关软件版本:

python 3.7.3fastai 1.0.57
手电筒 1.2.0
torch 视觉 0.4.0

最佳答案

所以这个问题的答案最终变得相对简单:

1) 如我的评论所述,在混合精度模式下训练(设置 conv_learner to_fp16())导致导出/重新加载模型出错

2) 要在混合精度模式下训练(比常规训练更快)并启用模型的导出/重新加载而不会出错,只需在导出前将模型设置回默认精度即可。

...在代码中,只需更改上面的示例:

# Export Model
conv_learner.export()

到:

# Export Model (after converting back to default precision for safe export/reload
conv_learner = conv_learner.to_fp32()
conv_learner.export()

...现在上面的完整(可重现)代码示例运行时没有错误,包括模型重新加载后的预测。

关于python - 使用导出/重新加载模型 : "Input type and weight type should be the same" 进行 fastai 错误预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57618507/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com