- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我有一个 Keras 模型,我想将其转换为 Tensorflow protobuf(例如 saved_model.pb
)。
该模型来自 vgg-19 网络上的迁移学习,其中头部被切断并使用全连接+softmax 层进行训练,而 vgg-19 网络的其余部分被卡住
我可以在 Keras 中加载模型,然后使用 keras.backend.get_session()
在 tensorflow 中运行模型,生成正确的预测:
frame = preprocess(cv2.imread("path/to/img.jpg")
keras_model = keras.models.load_model("path/to/keras/model.h5")
keras_prediction = keras_model.predict(frame)
print(keras_prediction)
with keras.backend.get_session() as sess:
tvars = tf.trainable_variables()
output = sess.graph.get_tensor_by_name('Softmax:0')
input_tensor = sess.graph.get_tensor_by_name('input_1:0')
tf_prediction = sess.run(output, {input_tensor: frame})
print(tf_prediction) # this matches keras_prediction exactly
如果我不包含行 tvars = tf.trainable_variables()
,然后 tf_prediction
变量完全错误,与 keras_prediction
的输出不匹配根本不。事实上,输出中的所有值(具有 4 个概率值的单个数组)完全相同(~0.25,全部加 1)。这让我怀疑头部的权重只是初始化为 0 if tf.trainable_variables()
没有首先调用,这是在检查模型变量后确认的。无论如何,调用tf.trainable_variables()
使 tensorflow 预测正确。
问题是,当我尝试保存此模型时,tf.trainable_variables()
中的变量实际上并没有保存到 .pb
文件:
with keras.backend.get_session() as sess:
tvars = tf.trainable_variables()
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), ['Softmax'])
graph_io.write_graph(constant_graph, './', 'saved_model.pb', as_text=False)
我要问的是,如何使用 tf.training_variables()
将 Keras 模型保存为 Tensorflow protobuf完好无损的?
非常感谢!
最佳答案
因此,您卡住图中变量(转换为常量)的方法应该可行,但不是必需的,并且比其他方法更棘手。 (更多内容见下文)。如果您出于某种原因想要卡住图形(例如导出到移动设备),我需要更多详细信息来帮助调试,因为我不确定 Keras 在幕后对您的图形做了什么隐式操作。但是,如果您想稍后保存并加载图表,我可以解释如何做到这一点,(尽管不能保证 Keras 所做的任何事情都不会搞砸......,很乐意帮助调试)。
所以这里实际上有两种格式在起作用。一种是 GraphDef,用于检查点,因为它不包含有关输入和输出的元数据。另一个是 MetaGraphDef
,其中包含元数据和图形定义,元数据对于预测和运行 ModelServer
(来自tensorflow/serving)很有用。
无论哪种情况,您都需要做的不仅仅是调用graph_io.write_graph
,因为变量通常存储在graphdef之外。
这两个用例都有包装器库。 tf.train.Saver
主要用于保存和恢复检查点。
但是,由于您想要预测,我建议使用 tf.saved_model.builder.SavedModelBuilder
构建 SavedModel 二进制文件。我为此提供了一些样板:
from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY as DEFAULT_SIG_DEF
builder = tf.saved_model.builder.SavedModelBuilder('./mymodel')
with keras.backend.get_session() as sess:
output = sess.graph.get_tensor_by_name('Softmax:0')
input_tensor = sess.graph.get_tensor_by_name('input_1:0')
sig_def = tf.saved_model.signature_def_utils.predict_signature_def(
{'input': input_tensor},
{'output': output}
)
builder.add_meta_graph_and_variables(
sess, tf.saved_model.tag_constants.SERVING,
signature_def_map={
DEFAULT_SIG_DEF: sig_def
}
)
builder.save()
运行此代码后,您应该有一个 mymodel/saved_model.pb
文件以及一个目录 mymodel/variables/
,其中包含与变量值对应的 protobuf。
然后要再次加载模型,只需使用tf.saved_model.loader
:
# Does Keras give you the ability to start with a fresh graph?
# If not you'll need to do this in a separate program to avoid
# conflicts with the old default graph
with tf.Session(graph=tf.Graph()):
meta_graph_def = tf.saved_model.loader.load(
sess,
tf.saved_model.tag_constants.SERVING,
'./mymodel'
)
# From this point variables and graph structure are restored
sig_def = meta_graph_def.signature_def[DEFAULT_SIG_DEF]
print(sess.run(sig_def.outputs['output'], feed_dict={sig_def.inputs['input']: frame}))
显然,通过tensorflow/serving或Cloud ML Engine可以使用此代码进行更有效的预测,但这应该可行。Keras 可能正在幕后做一些事情,这也会干扰这个过程,如果是这样,我们希望听到它(并且我想确保 Keras 用户也能够卡住图表,因此,如果您想向我发送包含完整代码或其他内容的要点,也许我可以找到熟悉 Keras 的人来帮助我调试。)
编辑:您可以在这里找到一个端到端的示例:https://github.com/GoogleCloudPlatform/cloudml-samples/blob/master/census/keras/trainer/model.py#L85
关于tensorflow - 使用convert_variables_to_constants保存tf.trainable_variables(),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45779268/
正如标题所说,为什么在 tensorflow 2 中不推荐使用 convert_variables_to_constants()?获取可保存模型以加载到下游独立应用程序以进行推理的简单替代方法是什么(
我想知道是否可以使用函数 tf.graph_util.convert_variables_to_constants (为了存储图表的卡住版本)在训练/评估循环中,而我正在使用自定义估算器。例如: be
我是一名优秀的程序员,十分优秀!