gpt4 book ai didi

tensorflow - 使用 TensorFlow Benchmark 对 Keras 模型进行基准测试

转载 作者:行者123 更新时间:2023-12-01 12:21:54 27 4
gpt4 key购买 nike

我正在尝试使用 TensorFlow 后端对我的 Keras 模型构建的推理阶段的性能进行基准测试。我在想 Tensorflow Benchmark工具是正确的方法。

我已经设法使用 tensorflow_inception_graph.pb 在桌面上构建和运行示例一切似乎都很好。

我似乎无法弄清楚如何将 Keras 模型保存为正确的 .pb模型。我能够从 Keras 模型中获取 TensorFlow Graph,如下所示:

import keras.backend as K
K.set_learning_phase(0)

trained_model = function_that_returns_compiled_model()
sess = K.get_session()
sess.graph # This works

# Get the input tensor name for TF Benchmark
trained_model.input
> <tf.Tensor 'input_1:0' shape=(?, 360, 480, 3) dtype=float32>

# Get the output tensor name for TF Benchmark
trained_model.output
> <tf.Tensor 'reshape_2/Reshape:0' shape=(?, 360, 480, 12) dtype=float32>

我现在一直在尝试以几种不同的方式保存模型。
import tensorflow as tf
from tensorflow.contrib.session_bundle import exporter

model = trained_model
export_path = "path/to/folder" # where to save the exported graph
export_version = 1 # version number (integer)

saver = tf.train.Saver(sharded=True)
model_exporter = exporter.Exporter(saver)
signature = exporter.classification_signature(input_tensor=model.input, scores_tensor=model.output)
model_exporter.init(sess.graph.as_graph_def(), default_graph_signature=signature)
model_exporter.export(export_path, tf.constant(export_version), sess)

这会产生一个文件夹,其中包含一些我不知道如何处理的文件。

我现在会用这样的东西运行基准工具
bazel-bin/tensorflow/tools/benchmark/benchmark_model \
--graph=tensorflow/tools/benchmark/what_file.pb \
--input_layer="input_1:0" \
--input_layer_shape="1,360,480,3" \
--input_layer_type="float" \
--output_layer="reshape_2/Reshape:0"

但无论我尝试使用哪个文件作为 what_file.pb我收到了 Error during inference: Invalid argument: Session was not created with a graph before Run()!

最佳答案

所以我得到了这个工作。只需要将 tensorflow 图中的所有变量转换为常量,然后保存图定义。

这是一个小例子:

import tensorflow as tf

from keras import backend as K
from tensorflow.python.framework import graph_util

K.set_learning_phase(0)
model = function_that_returns_your_keras_model()
sess = K.get_session()

output_node_name = "my_output_node" # Name of your output node

with sess as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
graph_def = sess.graph.as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(),
output_node_name.split(","))
tf.train.write_graph(output_graph_def,
logdir="my_dir",
name="my_model.pb",
as_text=False)

现在只需使用 my_model.pb 调用 TensorFlow Benchmark 工具如图。

关于tensorflow - 使用 TensorFlow Benchmark 对 Keras 模型进行基准测试,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43434292/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com