gpt4 book ai didi

python - 从 Python 中的 ONNX 模型获取预测

转载 作者:行者123 更新时间:2023-12-05 03:28:13 29 4
gpt4 key购买 nike

我找不到任何人向外行人解释如何将 onnx 模型加载到 python 脚本中,然后在输入图像时使用该模型进行预测。我能找到的只有这些代码行:

sess = rt.InferenceSession("onnx_model.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred = sess.run([label_name], {input_name: X.astype(np.float32)})[0]

但是我不知道那是什么意思。我到处看,每个人似乎都已经知道他们的意思,所以没有人解释。如果我可以运行此代码,那将是一回事,但我不能。它给了我这个错误:

onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid rank for input: Input3 Got: 2 Expected: 4 Please fix either the inputs or the model.

所以我需要真正知道这些东西的含义,这样我才能弄清楚如何修复错误。请懂行的人解释一下?

最佳答案

让我们首先检查您提供的代码,让一切都清楚。

sess = ort.InferenceSession("onnx_model.onnx")

此行将模型加载到 session 对象中。这意味着模型中使用的层、函数和权重已准备好执行推理。

input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name

get_inputsget_outputs 这两个方法分别检索有关模型的一些元信息,即模型期望的输入以及它可以提供的输出。在这些行中的元信息之外,实际上只使用了第一个输入和输出,而在这些之外,只有名称被获取并保存到变量中。
对于最后一行,让我们逐个部分地处理。

pred = sess.run(...)[0]

这会对模型执行推理,之后我们将检查此方法的输入,但目前,输出是不同输出的列表。这些输出都是每个 numpy 数组。在这种情况下,仅使用此列表中的第一个输出,并将其保存到 pred 变量

([label_name], {input_name: X.astype(np.float32)})

这些是 sess.run 的输入。第一个是您希望 session 计算的输出名称列表。第二个参数是一个字典,其中每个输入的名称映射到 numpy 数组。这些数组应与模型创建期间提供的数组具有相同的维度。同样,这些数组的类型也应与创建模型期间使用的类型相匹配。

您遇到的错误似乎表明提供的数组没有预期的维度。这些预期的维度数量似乎是 4。
为了清楚地了解输入数组的确切形状和数据类型应该是什么,可以使用可视化工具,例如 Netron

关于python - 从 Python 中的 ONNX 模型获取预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71279968/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com