gpt4 book ai didi

c++ - 从 C++ 中的 Tensorflow 的 .meta 文件加载图形以进行推理

转载 作者:搜寻专家 更新时间:2023-10-31 00:27:30 25 4
gpt4 key购买 nike

我已经使用 tensorflow 1.5.1 训练了一些模型,并且我有这些模型的检查点(包括 .ckpt 和 .meta 文件)。现在我想使用这些文件在 C++ 中进行推理。

在 python 中,我会执行以下操作来保存和加载图形和检查点。保存:

    images = tf.placeholder(...) // the input layer
//the graph def
output = tf.nn.softmax(net) // the output layer
tf.add_to_collection('images', images)
tf.add_to_collection('output', output)

为了推理,我恢复图和检查点,然后从集合中恢复输入和输出层,如下所示:

    meta_file = './models/last-100.meta'
ckpt_file = './models/last-100'
with tf.Session() as sess:
saver = tf.train.import_meta_graph(meta_file)
saver.restore(sess, ckpt_file)
images = tf.get_collection('images')
output = tf.get_collection('output')
outputTensors = sess.run(output, feed_dict={images: np.array(an_image)})

现在假设我像往常一样在 python 中进行保存,我如何使用像在 python 中一样的简单代码在 c++ 中进行推理和恢复?

我找到了示例和教程,但对于 tensorflow 版本 0.7 0.12 和相同的代码不适用于版本 1.5。我在 tensorflow 网站上没有找到使用 c++ API 恢复模型的教程。

最佳答案

为了这个thread .我会将我的评论改写为答案。

发布完整示例需要 CMake 设置或将文件放入特定目录以运行 bazel。因为我确实赞成第一种方式,它会突破这篇文章的所有限制以涵盖我想重定向到 complete implementation in C99, C++, GO without Bazel 的所有部分。我针对 TF > v1.5 进行了测试。

在 C++ 中加载图形并不比在 Python 中困难多少,鉴于您已经从源代码编译了 TensorFlow。

从创建一个 MWE 开始,它创建一个非常转储的网络图总是一个弄清楚事情是如何工作的好主意:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[1, 2], name='input')
output = tf.identity(tf.layers.dense(x, 1), name='output')

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())
saver.save(sess, './exported/my_model')

关于这部分,SO 上可能有很多答案。所以我就让它留在这里,没有进一步的解释。

在 Python 中加载

在用其他语言做事之前,我们可以尝试在 python 中正确地做——从某种意义上说:我们只需要用 C++ 重写它。即使在 python 中恢复也非常容易,例如:

import tensorflow as tf

with tf.Session() as sess:

# load the computation graph
loader = tf.train.import_meta_graph('./exported/my_model.meta')
sess.run(tf.global_variables_initializer())
loader = loader.restore(sess, './exported/my_model')

x = tf.get_default_graph().get_tensor_by_name('input:0')
output = tf.get_default_graph().get_tensor_by_name('output:0')

它没有帮助,因为大多数这些 API 端点在 C++ API 中不存在(还?)。另一个版本是

import tensorflow as tf

with tf.Session() as sess:

metaGraph = tf.train.import_meta_graph('./exported/my_model.meta')
restore_op_name = metaGraph.as_saver_def().restore_op_name
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name
sess.run(restore_op, {filename_tensor_name: './exported/my_model'})


x = tf.get_default_graph().get_tensor_by_name('input:0')
output = tf.get_default_graph().get_tensor_by_name('output:0')

等一下。您始终可以使用 print(dir(object)) 来获取 restore_op_name 等属性,...。恢复模型是 TensorFlow 中的一项操作,就像其他所有操作一样。我们只是调用此操作并提供路径(字符串张量)作为输入。我们甚至可以编写自己的恢复操作

def restore(sess, metaGraph, fn):
restore_op_name = metaGraph.as_saver_def().restore_op_name # u'save/restore_all'
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name # u'save/Const'
sess.run(restore_op, {filename_tensor_name: fn})

即使这看起来很奇怪,现在用 C++ 做同样的事情也有很大帮助。

在 C++ 中加载

从平常的事情开始

#include <tensorflow/core/public/session.h>
#include <tensorflow/core/public/session_options.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
#include <string>
#include <iostream>

typedef std::vector<std::pair<std::string, tensorflow::Tensor>> tensor_dict;

int main(int argc, char const *argv[]) {

const std::string graph_fn = "./exported/my_model.meta";
const std::string checkpoint_fn = "./exported/my_model";

// prepare session
tensorflow::Session *sess;
tensorflow::SessionOptions options;
TF_CHECK_OK(tensorflow::NewSession(options, &sess));

// here we will put our loading of the graph and weights

return 0;
}

您应该能够通过将其放入 TensorFlow 存储库并使用 bazel 或简单地按照说明进行编译 here使用 CMake。

我们需要创建这样一个由 tf.train.import_meta_graph 创建的 meta_graph。这可以通过

tensorflow::MetaGraphDef graph_def;
TF_CHECK_OK(ReadBinaryProto(tensorflow::Env::Default(), graph_fn, &graph_def));

在 C++ 中,从文件中读取图形与在 Python 中导入图形不同。我们需要通过以下方式在 session 中创建此图

TF_CHECK_OK(sess->Create(graph_def.graph_def()));

通过查看上面奇怪的 python restore 函数:

restore_op_name = metaGraph.as_saver_def().restore_op_name
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name

我们可以用 C++ 编写等效的代码

const std::string restore_op_name = graph_def.saver_def().restore_op_name()
const std::string filename_tensor_name = graph_def.saver_def().filename_tensor_name()

准备就绪后,我们只需运行操作

sess->Run(feed_dict,     // inputs
{}, // output_tensor_names (we do not need them)
{restore_op}, // target_node_names
nullptr) // outputs (there are no outputs this time)

创建 feed_dict 可能是一个独立的帖子,这个答案已经足够长了。它只涵盖最重要的内容。我想重定向到 complete implementation in C99, C++, GO without Bazel我测试了 TF > v1.5。这并不难——在 plain C version 的情况下它可能会变得很长.

关于c++ - 从 C++ 中的 Tensorflow 的 .meta 文件加载图形以进行推理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48841364/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com