gpt4 book ai didi

python - 使用 TensorFlow Estimators 迁移学习/再训练

转载 作者:太空宇宙 更新时间:2023-11-04 02:39:12 25 4
gpt4 key购买 nike

我一直无法弄清楚如何在新的 TF 中使用迁移学习/最后一层再训练 Estimator API .

Estimator 需要一个 model_fn,其中包含网络架构、训练和评估操作,如 documentation 中所定义。 .使用 CNN 架构的 model_fn 示例是 here .

如果我想重新训练最后一层,例如初始架构,我不确定是否需要在此 model_fn 中指定整个模型,然后加载预经过训练的权重,或者是否有一种方法可以像“传统”方法中那样使用保存的图形(例如 here )。

这已被提出为 issue ,但仍然开放,我不清楚答案。

最佳答案

可以在模型定义期间加载元图,并使用 SessionRunHook 从 ckpt 文件加载权重。

def model(features, labels, mode, params):
# Create the graph here

return tf.estimator.EstimatorSpec(mode,
predictions,
loss,
train_op,
training_hooks=[RestoreHook()])

SessionRunHook 可以是:

class RestoreHook(tf.train.SessionRunHook):

def after_create_session(self, session, coord=None):
if session.run(tf.train.get_or_create_global_step()) == 0:
# load weights here

这样,权重在第一步加载并在模型检查点的训练期间保存。

关于python - 使用 TensorFlow Estimators 迁移学习/再训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47041724/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com