gpt4 book ai didi

python - TensorFlow:将 py_func 保存到 .pb 文件

转载 作者:行者123 更新时间:2023-12-04 14:19:54 28 4
gpt4 key购买 nike

我尝试构建一个 tensorflow 模型 - 我在其中使用 tf.py_func 在普通 python 代码中创建部分代码。问题是当我将模型保存到 .pb 文件时,.pb 文件本身非常小并且不包含 py_func:0 张量。当我尝试从 .pb 文件加载和运行模型时,出现此错误:get ValueError: callback pyfunc_0 is not found。

当我不保存和加载为 .pb 文件时它工作

有人能帮忙吗?这对我来说非常重要,让我度过了几个不眠之夜。

model_version = "465555564"
tensorboard = TensorBoard(log_dir='./logs', histogram_freq = 0, write_graph = True, write_images = False)

sess = tf.Session()
K.set_session(sess)
K.set_learning_phase(0)

def my_func(x):
some_function

input = tf.placeholder(tf.float32)
y = tf.py_func(my_func, [input], tf.float32)

prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"inputs": input}, {"prediction": y})
builder = saved_model_builder.SavedModelBuilder('./'+model_version)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
sess, [tag_constants.SERVING],
signature_def_map={
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:prediction_signature,
},
legacy_init_op=legacy_init_op)

builder.save()

最佳答案

一种使用tf.py_func 保存 TF 模型的方法,但您必须使用SavedModel

TF 有 2 个级别的模型保存:检查点和 SavedModels。参见 this answer有关更多详细信息,但请在此处引用:

  • A checkpoint contains the value of (some of the) variables in a TensorFlow model. It is created by a Saver. To use a checkpoint, you need to have a compatible TensorFlow Graph, whose Variables have the same names as the Variables in the checkpoint.
  • SavedModel is much more comprehensive: It contains a set of Graphs (MetaGraphs, in fact, saving collections and such), as well as a checkpoint which is supposed to be compatible with these Graphs, and any asset files that are needed to run the model (e.g. Vocabulary files). For each MetaGraph it contains, it also stores a set of signatures. Signatures define (named) input and output tensors.

tf.py_func 操作 不能SavedModel 保存(在 this page in the docs 上注明),这是您尝试做的这里。这是有充分理由的。 SavedModel 应该完全独立于原始代码,能够以可以反序列化它的任何其他语言加载。这允许模型通过 ML Engine 之类的东西加载。 ,这可能是用 C++ 或类似语言编写的。问题是它不能序列化任意 Python 代码,所以 py_func 是行不通的。

可以通过使用检查点来解决这个问题,只要您愿意留在 Python 中即可。您将无法获得 SavedModel 提供的独立性。您可以在使用 tf.train.Saver 训练后保存检查点,然后在新的 Session 中重新构建整个图并使用该 加载它节省程序。甚至还有一种在 ML Engine 中使用该代码的方法,它过去专门用于 SavedModel。您可以使用 custom prediction routines回避对 SavedModel 的需求。

有关保存/恢复模型的更多信息,请参阅 the docs .

关于python - TensorFlow:将 py_func 保存到 .pb 文件,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56233463/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com