gpt4 book ai didi

python - 使用 python onnxruntime 进行预测时出错

转载 作者:行者123 更新时间:2023-12-04 10:58:00 24 4
gpt4 key购买 nike

我使用 sklearn 创建了一个非常基本的决策树。图书馆。这棵树基于 4 个特征进行训练:

feat1 INT
feat2 INT
feat3 FLOAT
feat4 FLOAT

标签/目标特征是一个 bool 值(0 或 1)。

我将树转换成 ONNX格式,现在我想使用 onnxruntime python库进行预测。我在互联网上找到了示例代码来做到这一点。问题是我不明白这段代码、函数和参数的所有部分究竟发生了什么。这导致我收到错误。我确实搜索了一些文档,但我找不到这个。

在下面的代码中,我将树模型转换为 ONNX格式。这是成功的,但部分代码我不明白。在 initial_type变量,根据我之前提到的 4 个特征列和标签/目标特征,我必须在此处输入什么?现在我输入了 FloatTensorType([None, 4]因为我有 4 个特征列和什么是 None我不知道吗。
##Convert to ONNX format

initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(treeModel, initial_types=initial_type)
with open("path", "wb") as f:
f.write(onx.SerializeToString())

在下面的代码中,我想使用 onnxruntime 进行预测库,但我收到此错误:
RuntimeError: Either type_proto was null or it was not of sequence type

这是因为我不明白下面的最后一行代码。我输入了这个 {input_name: [4, 8, 77.8, 143.45]因为这是特征列的四个值。我在这里做错了什么?
sess = rt.InferenceSession("pathToONNXModel")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: [4, 8, 77.8, 143.45]})[0]

最佳答案

你试了吗{input_name: numpy.array([4, 8, 77.8, 143.45], dtype=numpy.float32)} ? onnxruntime 需要 numpy 数组作为输入。

关于python - 使用 python onnxruntime 进行预测时出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59047288/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com