gpt4 book ai didi

tensorflow - 如何使用TF2将加载的h5模型正确保存到pb

转载 作者:行者123 更新时间:2023-12-03 20:15:26 30 4
gpt4 key购买 nike

我加载了一个已保存的 h5 模型并希望将该模型保存为 pb。
该模型在训练期间使用 tf.keras.callbacks.ModelCheckpoint 保存回调函数。

TF版本:2.0.0a
编辑 :2.0.0-beta1 也有同样的问题

我保存 pb 的步骤:

  • 我第一套K.set_learning_phase(0)
  • 然后我用 tf.keras.models.load_model 加载模型
  • 然后,我定义了 freeze_session()功能。
  • (可选我编译模型)
  • 然后使用 freeze_session()功能与 tf.keras.backend.get_session

  • 错误 我得到,编译和不编译:

    AttributeError: module 'tensorflow.python.keras.api._v2.keras.backend' has no attribute 'get_session'



    我的问题:
  • TF2 没有get_session了?
    (我知道 tf.contrib.saved_model.save_keras_model 不存在了,我也试过 tf.saved_model.save 没有真正奏效)
  • 还是get_session仅在我实际训练模型时才起作用,而仅加载 h5 不起作用
    编辑 :对于新训练的 session ,也没有可用的 get_session。
  • 如果是这样,我将如何在未经培训的情况下将 h5 转换为 pb?有好的教程吗?

  • 感谢您的帮助

    更新 :

    自 TF2.x 正式发布以来,图形/ session 概念发生了变化。 savedmodel应该使用api。
    您可以使用 tf.compat.v1.disable_eager_execution()使用 TF2.x,它将生成一个 pb 文件。但是,我不确定它是哪种 pb 文件类型,因为保存的模型组合从 TF1 更改为 TF2。我会继续挖掘。

    最佳答案

    我确实将模型保存到 pb来自 h5模型:

    import logging
    import tensorflow as tf
    from tensorflow.compat.v1 import graph_util
    from tensorflow.python.keras import backend as K
    from tensorflow import keras

    # necessary !!!
    tf.compat.v1.disable_eager_execution()

    h5_path = '/path/to/model.h5'
    model = keras.models.load_model(h5_path)
    model.summary()
    # save pb
    with K.get_session() as sess:
    output_names = [out.op.name for out in model.outputs]
    input_graph_def = sess.graph.as_graph_def()
    for node in input_graph_def.node:
    node.device = ""
    graph = graph_util.remove_training_nodes(input_graph_def)
    graph_frozen = graph_util.convert_variables_to_constants(sess, graph, output_names)
    tf.io.write_graph(graph_frozen, '/path/to/pb/model.pb', as_text=False)
    logging.info("save pb successfully!")

    我使用 TF2 来转换模型,如:
  • keras.callbacks.ModelCheckpoint(save_weights_only=True)model.fit并保存 checkpoint在训练时;
  • 培训结束后,self.model.load_weights(self.checkpoint_path)负载 checkpoint ;
  • self.model.save(h5_path, overwrite=True, include_optimizer=False)另存为 h5 ;
  • 转换 h5pb就像上面一样;
  • 关于tensorflow - 如何使用TF2将加载的h5模型正确保存到pb,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56646940/

    30 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com