gpt4 book ai didi

python - 如何将 Keras .h5 导出到 tensorflow .pb?

转载 作者:IT老高 更新时间:2023-10-28 21:48:24 24 4
gpt4 key购买 nike

我已经使用新数据集对初始模型进行了微调,并将其保存为 Keras 中的“.h5”模型。现在我的目标是在仅接受“.pb”扩展名的 android Tensorflow 上运行我的模型。问题是 Keras 或 tensorflow 中是否有任何库可以进行这种转换?到目前为止,我已经看到了这篇文章:https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html但还不能弄清楚。

最佳答案

Keras 本身不包含任何将 TensorFlow 图导出为 Protocol Buffer 文件的方法,但您可以使用常规 TensorFlow 实用程序来完成。 Here是一篇博客文章,解释了如何使用实用脚本 freeze_graph.py包含在 TensorFlow 中,这是它的“典型”完成方式。

但是,我个人觉得必须创建一个检查点然后运行外部脚本来获取模型很麻烦,我更喜欢使用我自己的 Python 代码来完成,所以我使用这样的函数:

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
Freezes the state of a session into a pruned computation graph.

Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
pruned so subgraphs that are not necessary to compute the requested
outputs are removed.
@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = tf.graph_util.convert_variables_to_constants(
session, input_graph_def, output_names, freeze_var_names)
return frozen_graph

这在 freeze_graph.py 的实现中受到启发。参数也与脚本类似。 session 是 TensorFlow session 对象。 keep_var_names 仅在您想保持某些变量不卡住时才需要(例如,对于有状态模型),因此通常不需要。 output_names 是一个列表,其中包含产生所需输出的操作的名称。 clear_devices 只是删除任何设备指令以使图形更便携。因此,对于具有一个输出的典型 Keras model,您可以执行以下操作:

from keras import backend as K

# Create, compile and train model...

frozen_graph = freeze_session(K.get_session(),
output_names=[out.op.name for out in model.outputs])

然后您可以像往常一样使用 tf.train.write_graph 将图形写入文件。 :

tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False)

关于python - 如何将 Keras .h5 导出到 tensorflow .pb?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45466020/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com