gpt4 book ai didi

python - 用于导出 C++ 的 vgg16 keras 保存模型

转载 作者:行者123 更新时间:2023-11-30 03:23:23 25 4
gpt4 key购买 nike

我最近在研究 vgg16 以重用它的模型。

在 python 中:(Keras)

model = applications.VGG16(include_top=False, weights='imagenet')

一切都很好。

我需要通过编译导出 THiS 模型并适合导出 c++ json 文件。如何将 ThiS 模型正确导出为 h5 文件以用于导出模型?

最佳答案

一种方法是将您的 Keras 模型转换为 Python 中的 TensorFlow 模型,然后将卡住图导出到 .pb 文件。然后在 C++ 中加载它。我已经使用这段代码从 Keras 导出了一个卡住的 .pb 文件。

import tensorflow as tf

from keras import backend as K
from tensorflow.python.framework import graph_util

K.set_learning_phase(0)
model = function_that_returns_your_keras_model()
sess = K.get_session()

output_node_name = "my_output_node" # Name of your output node

with sess as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
graph_def = sess.graph.as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(),
output_node_name.split(","))
tf.train.write_graph(output_graph_def,
logdir="my_dir",
name="my_model.pb",
as_text=False)

然后您可以按照任何有关如何在 C++ 中加载 .pb 文件的教程进行操作。例如这个:https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f

Keras 在 TensorFlow 图中注入(inject) learning_phase 变量,可能还有其他仅 Keras 变量 - 如果我没记错的话,您应该确保从图中删除这些变量。

关于python - 用于导出 C++ 的 vgg16 keras 保存模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50396739/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com