gpt4 book ai didi

c++ - iOS : "Invalid argument: Session was not created with a graph before Run()!" 上的 TensorFlow C++ 推理错误

转载 作者:可可西里 更新时间:2023-11-01 05:42:57 25 4
gpt4 key购买 nike

我正在尝试使用 TensorFlow 的 C++ API 在 iOS 上运行我的模型。型号是SavedModel保存为 .pb文件。但是,请调用 Session::Run()导致错误:

"Invalid argument: Session was not created with a graph before Run()!"



在 Python 中,我可以使用以下代码在模型上成功运行推理:
with tf.Session() as sess:
tf.saved_model.loader.load(sess, ['serve'], '/path/to/model/export')
result = sess.run(['OutputTensorA:0', 'OutputTensorB:0'], feed_dict={
'InputTensorA:0': np.array([5000.00] * 1000).reshape(1, 1000),
'InputTensorB:0': np.array([300.00] * 1000).reshape(1, 1000)
})
print(result[0])
print(result[1])

在 iOS 上的 C++ 中,我尝试模仿这个工作片段如下:
tensorflow::Input::Initializer input_a(5000.00, tensorflow::TensorShape({1, 1000}));
tensorflow::Input::Initializer input_b(300.00, tensorflow::TensorShape({1, 1000}));

tensorflow::Session* session_pointer = nullptr;

tensorflow::SessionOptions options;
tensorflow::Status session_status = tensorflow::NewSession(options, &session_pointer);

std::cout << session_status.ToString() << std::endl; // prints OK

std::unique_ptr<tensorflow::Session> session(session_pointer);

tensorflow::GraphDef model_graph;

NSString* model_path = FilePathForResourceName(@"saved_model", @"pb");
PortableReadFileToProto([model_path UTF8String], &model_graph);

tensorflow::Status session_init = session->Create(model_graph);

std::cout << session_init.ToString() << std::endl; // prints OK

std::vector<tensorflow::Tensor> outputs;
tensorflow::Status session_run = session->Run({{"InputTensorA:0", input_a.tensor}, {"InputTensorB:0", input_b.tensor}}, {"OutputTensorA:0", "OutputTensorB:0"}, {}, &outputs);

std::cout << session_run.ToString() << std::endl; // Invalid argument: Session was not created with a graph before Run()!

方法 FilePathForResourceNamePortableReadFileToProto取自 TensorFlow iOS 示例 here .

问题是什么?我注意到无论模型有多简单( see my issue report on GitHub )都会发生这种情况,这意味着问题不在于模型的细节。

最佳答案

这里的主要问题是您将图形导出到 SavedModel 在 Python 中,然后将其作为 GraphDef 读入在 C++ 中。虽然两者都有 .pb扩展和相似,它们不等价。

正在发生的事情是您正在阅读 SavedModelPortableReadFileToProto()它失败了,留下一个空指针( model_graph )到 GraphDef目的。所以在执行 PortableReadFileToProto() 之后, model_graph仍然是空的,但有效,GraphDef ,这就是为什么错误说 Session 不是在 Run() 之前用图形创建的。 session->Create()成功是因为您成功地创建了一个带有空图的 session 。

检查是否PortableReadFileToProto()的方法失败是检查其返回值。它返回一个 bool 值,如果读取图形失败,它将为 0。如果您想在此处获得描述性错误,请使用 ReadBinaryProto() .判断读取图形是否失败的另一种方法是检查 model_graph.node_size() 的值。 .如果这是 0,那么您有一个空图并且读取它失败。

虽然您可以使用 TensorFlow 的 C API 在 SavedModel 上执行推理通过使用 TF_LoadSessionFromSavedModel() TF_SessionRun() ,推荐的方法是使用 freeze_graph.py 将图形导出到卡住模型。或写信给 GraphDef使用 tf.train.write_graph() .我将演示使用 tf.train.write_graph() 导出的模型的成功推理:

在 Python 中:

# Build graph, call it g
g = tf.Graph()

with g.as_default():
input_tensor_a = tf.placeholder(dtype=tf.int32, name="InputTensorA")
input_tensor_b = tf.placeholder(dtype=tf.int32, name="InputTensorB")
output_tensor_a = tf.stack([input_tensor_a], name="OutputTensorA")
output_tensor_b = tf.stack([input_tensor_b], name="OutputTensorB")

# Save graph g
with tf.Session(graph=g) as sess:
sess.run(tf.global_variables_initializer())
tf.train.write_graph(
graph_or_graph_def=sess.graph_def,
logdir='/path/to/export',
name='saved_model.pb',
as_text=False
)

在 C++ (Xcode) 中:
using namespace tensorflow;
using namespace std;

NSMutableArray* predictions = [NSMutableArray array];

Input::Initializer input_tensor_a(1, TensorShape({1}));
Input::Initializer input_tensor_b(2, TensorShape({1}));

SessionOptions options;
Session* session_pointer = nullptr;
Status session_status = NewSession(options, &session_pointer);
unique_ptr<Session> session(session_pointer);

GraphDef model_graph;
string model_path = string([FilePathForResourceName(@"saved_model", @"pb") UTF8String]);

Status load_graph = ReadBinaryProto(Env::Default(), model_path, &model_graph);

Status session_init = session->Create(model_graph);

cout << "Session creation Status: " << session_init.ToString() << endl;
cout << "Number of nodes in model_graph: " << model_graph.node_size() << endl;
cout << "Load graph Status: " << load_graph.ToString() << endl;

vector<pair<string, Tensor>> feed_dict = {
{"InputTensorA:0", input_tensor_a.tensor},
{"InputTensorB:0", input_tensor_b.tensor}
};

vector<Tensor> outputs;
Status session_run = session->Run(feed_dict, {"OutputTensorA:0", "OutputTensorB:0"}, {}, &outputs);

[predictions addObject:outputs[0].scalar<int>()];
[predictions addObject:outputs[1].scalar<int>()];

Status session_close = session->Close();

这种通用方法可行,但您可能会遇到构建的 TensorFlow 库中缺少所需操作的问题,因此推理仍然会失败。为了解决这个问题,首先确保你已经构建了最新的 TensorFlow 1.3 通过在您的机器上克隆 repo 并运行 tensorflow/contrib/makefile/build_all_ios.sh 从根 tensorflow-1.3.0目录。如果您使用 TensorFlow-experimental ,推理不太可能适用于自定义的非固定模型。 Pod 喜欢这些例子。一旦你使用 build_all_ios.sh 构建了一个静态库,您需要在您的 .xcconfig 中链接它按照说明操作 here .

一旦您成功地将使用 makefile 构建的静态库与 Xcode 链接起来,您可能仍然会遇到阻止推理的错误。虽然您将获得的实际错误取决于您的实现,但您将收到分为两种不同形式的错误:
  • OpKernel ('op: "[operation]" device_type: "CPU"') for unknown op: [operation]

  • No OpKernel was registered to support Op '[operation]' with these attrs. Registered devices: [CPU], Registered kernels: [...]


  • 错误 #1 表示 .cc文件来自 tensorflow/core/opstensorflow/core/kernels对应的操作(或密切相关的操作)不在 tf_op_files.txt 中文件在 tensorflow/contrib/makefile .您必须找到 .cc包含 REGISTER_OP("YourOperation")并将其添加到 tf_op_files.txt .您必须通过运行 tensorflow/contrib/makefile/build_all_ios.sh 来重建再次。

    错误 #2 意味着 .cc相应操作的文件在您的 tf_op_files.txt 中文件,但是您为操作提供了它 (a) 不支持或 (b) 被剥离以减小构建大小的数据类型。

    一个“问题”是,如果您使用 tf.float64在你的模型的实现中,这被导出为 TF_DOUBLE在您的 .pb文件,大多数操作不支持此功能。使用 tf.float32代替 tf.float64然后使用 tf.train.write_graph() 重新保存您的模型.

    如果在检查为操作提供了正确的数据类型后仍然收到错误 #2,则需要删除 __ANDROID_TYPES_SLIM__在位于 tensorflow/contrib/makefile 的 makefile 中或替换为 __ANDROID_TYPES_FULL__然后重建。

    在通过错误 #1 和 #2 之后,您可能会成功推理。

    关于c++ - iOS : "Invalid argument: Session was not created with a graph before Run()!" 上的 TensorFlow C++ 推理错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46201109/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com