gpt4 book ai didi

python - 如何从 Tensorflow 检查点将权重加载到 Keras 模型

转载 作者:太空狗 更新时间:2023-10-30 01:36:11 24 4
gpt4 key购买 nike

我有一些 python 代码可以使用 Tensorflow 的 TFRecords 和数据集 API 来训练网络。我已经使用 tf.Keras.layers 构建了网络,这可以说是最简单和最快的方法。方便的函数 model_to_estimator()

modelTF = tf.keras.estimator.model_to_estimator(
keras_model=model,
custom_objects=None,
config=run_config,
model_dir=checkPointDirectory
)

将 Keras 模型转换为估算器,这使我们能够很好地利用数据集 API,并在训练期间和训练完成后自动将检查点保存到 checkPointDirectory。估算器 API 提供了一些非常宝贵的功能,例如在多个 GPU 上自动分配工作负载,例如

distribution = tf.contrib.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(train_distribute=distribution)

现在对于大型模型和大量数据,在使用某种形式的已保存模型进行训练后执行预测通常很有用。似乎从 Tensorflow 1.10(参见 https://github.com/tensorflow/tensorflow/issues/19295)开始,tf.keras.model 对象支持来自 Tensorflow 检查点的 load_weights()。这在 Tensorflow 文档中有简要提及,但在 Keras 文档中没有提及,我找不到任何人展示这方面的示例。在一些新的 .py 中再次定义模型层后,我尝试了

checkPointPath = os.path.join('.', 'tfCheckPoints', 'keras_model.ckpt.index')
model.load_weights(filepath=checkPointPath, by_name=False)

但这给出了一个 NotImplementedError:

Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.

2018-10-01 14:24:49.912087:
Traceback (most recent call last):
File "C:/Users/User/PycharmProjects/python/mercury.classifier reductions/V3.2/wikiTestv3.2/modelEvaluation3.2.py", line 141, in <module>
model.load_weights(filepath=checkPointPath, by_name=False)
File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1526, in load_weights
checkpointable_utils.streaming_restore(status=status, session=session)
File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\training\checkpointable\util.py", line 880, in streaming_restore
"Streaming restore not supported from name-based checkpoints. File a "
NotImplementedError: Streaming restore not supported from name-based checkpoints. File a feature request if this limitation bothers you.

我想按照警告的建议去做,并改为使用“基于对象的保护程序”,但我还没有找到通过传递给 estimator.train() 的 RunConfig 来执行此操作的方法。

那么有没有更好的方法将保存的权重返回到估计器中以用于预测? github 线程似乎表明这已经实现(虽然基于错误,可能与我在上面尝试的方式不同)。有没有人在 TF 检查点上成功使用过 load_weights()?我还没有找到任何关于如何完成此操作的教程/示例,因此不胜感激。

最佳答案

我不确定,但也许您可以将 keras_model.ckpt.index 更改为 keras_model.ckpt 进行测试。

关于python - 如何从 Tensorflow 检查点将权重加载到 Keras 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52597523/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com