gpt4 book ai didi

python - 在 Keras 中使用 Adam 优化器恢复训练

转载 作者:行者123 更新时间:2023-12-01 21:24:39 28 4
gpt4 key购买 nike

我的问题很简单,但我在网上找不到明确的答案(到目前为止)。

在定义的训练周期数之后,我保存了使用 adam 优化器训练的 keras 模型的权重:

callback = tf.keras.callbacks.ModelCheckpoint(filepath=path, save_weights_only=True)
model.fit(X,y,callbacks=[callback])

当我关闭 jupyter 后恢复训练时,我可以简单地使用:

model.load_weights(path)

继续训练。

由于 Adam 依赖于纪元数(例如学习率衰减的情况),我想知道在与以前相同的条件下恢复训练的最简单方法。

根据 ibarrond 的回答,我编写了一个小的自定义回调。

optim = tf.keras.optimizers.Adam()
model.compile(optimizer=optim, loss='categorical_crossentropy',metrics=['accuracy'])

weight_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1, save_best_only=False)

class optim_callback(tf.keras.callbacks.Callback):
'''Custom callback to save optimiser state'''

def on_epoch_end(self,epoch,logs=None):
optim_state = tf.keras.optimizers.Adam.get_config(optim)
with open(optim_state_pkl,'wb') as f_out:
pickle.dump(optim_state,f_out)

model.fit(X,y,callbacks=[weight_callback,optim_callback()])

当我恢复训练时:

model.load_weights(checkpoint_path)
with open(optim_state_pkl,'rb') as f_out:
optim_state = pickle.load(f_out)
tf.keras.optimizers.Adam.from_config(optim_state)

我只是想检查一下这是否正确。再次非常感谢!!

附录:进一步阅读默认值 Keras implementation亚当和 original Adam paper ,我相信默认的 Adam 不依赖于纪元数,而只依赖于迭代数。因此,这是不必要的。但是,对于任何希望跟踪其他优化器的人来说,该代码可能仍然有用。

最佳答案

为了完美捕获优化器的状态,您应该使用函数 get_config() 存储其配置。 。此函数返回一个字典(包含选项),可以使用pickle对其进行序列化并存储在文件中。 .

要重新启动该过程,只需 d = pickle.load('my_saved_tfconf.txt') 检索具有配置的字典,然后使用函数 from_config(d ) Keras Adam Optimizer.

关于python - 在 Keras 中使用 Adam 优化器恢复训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60076741/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com