gpt4 book ai didi

c - 通过 C API 访问 tensorflow 2.0 SavedModel 的输入和输出张量

转载 作者:行者123 更新时间:2023-12-02 00:11:14 40 4
gpt4 key购买 nike

我无法从加载了 C_API 的 tensorflow 2.0 SavedModel 运行推理,因为我无法通过名称访问输入和输出操作。

我通过 TF_LoadSessionFromSavedModel(...) 成功加载 session :

#include <tensorflow/c/c_api>

...

TF_Status* status = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
TF_Buffer* r_opts = TF_NewBufferFromString("",0);
TF_Buffer* meta_g = TF_NewBuffer();

TF_SessionOptions* opts = TF_NewSessionOptions();
const char* tags[] = {"serve"};

TF_Session* session = TF_LoadSessionFromSavedModel(opts, r_opts, "saved_model/tf2_model", tags, 1, graph, meta_g, status);

if ( TF_GetCode(status) != TF_OK ) exit(-1); //does not happen

但是,当我尝试使用以下方法设置输入和输出张量时出现错误:

TF_Operation* inputOp  = TF_GraphOperationByName(graph, "input"); //works with "serving_default_input"
TF_Operation* outputOp = TF_GraphOperationByName(graph, "prediction"); //does not work

我作为参数传递的名称被分配给已保存模型的输入和输出 keras 层,但不在加载的 graph 中。运行 saved_model_cli(按照 tf SavedModel 教程 here )显示具有这些名称的男高音存在于 SignatureDef serving_default 下,所以我猜我需要将 serving_default 实例化为一个图形(换句话说,根据签名创建一个图形),但是我找不到使用 C API 执行此操作的方法。

注意tensorflows的C_API test使用 C++ tensorflow/core/功能从元图中加载签名定义映射并使用它来查找输入和输出操作名称,但我想避免对 C++ 的依赖。

另请注意,按名称访问操作适用于卡住的 .pb 图,但这种格式已被弃用。

提前感谢您的任何想法和提示!

最佳答案

目前(截至 2020 年 5 月)Tensorflow C API 并未正式支持 SavedModel (tensorflow 2.0) 格式,尽管他们可能会发布功能 soon .

无论如何,您可以使用导出模型时定义的默认 SignatureDef,并使用 saved_model_cli 查找输入和输出张量的名称工具。

假设您使用

保存了您的模型
model.save('/path/to/model/folder')

然后你打开一个 bash 并做

cd /python/folder/bin/
saved_model_cli show --dir /path/to/model/folder --tag_set serve --signature_def serving_default

(saved_model_cli的实际位置各不相同,但使用anaconda时默认安装在bin/文件夹下)

默认情况下它会产生如下内容:

serving_default
The given SavedModel SignatureDef contains the following input(s):
inputs['graph_input'] tensor_info:
dtype: DT_DOUBLE
shape: (-1, 28, 28)
name: serving_default_graph_input:0
The given SavedModel SignatureDef contains the following output(s):
outputs['graph_output'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 10)
name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict

在这种情况下,serving_default_graph_input 是输入张量名称,StatefulPartitionedCall 是输出张量名称。然后,您可以使用 TF_GraphOperationByName() 加载它们。

有了 Tensorflow 2 的 C API 支持,您将能够 save the model with a set of defined SignatureDefs然后加载所需的 concrete_function(),而不必担心张量名称。然而,当前的方法应该仍然有效。

关于c - 通过 C API 访问 tensorflow 2.0 SavedModel 的输入和输出张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58968918/

40 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com