
tensorflow - How to continue training after loading a model on multiple GPUs in TensorFlow 2.0 with the Keras API?


I trained an RNN-based text classification model in TensorFlow 2.0 using the Keras API. Following the guide here, I trained this model on multiple GPUs (2) with tf.distribute.MirroredStrategy(), and saved a checkpoint after every epoch with tf.keras.callbacks.ModelCheckpoint('file_name.h5') (a sketch of this setup is included at the end of the question for reference). Now I want to continue training from the last saved checkpoint, using the same number of GPUs. However, after loading the checkpoint inside the tf.distribute.MirroredStrategy() scope like this:

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    model = tf.keras.models.load_model('file_name.h5')

it throws the following error:

File "model_with_tfsplit.py", line 94, in <module>
model =tf.keras.models.load_model('TF_model_onfull_2_03.h5') # Loading for retraining
File "/home/rishabh/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/saving/save.py", line 138, in load_model
return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
File "/home/rishabh/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py", line 187, in load_model_from_hdf5
model._make_train_function()
File "/home/rishabh/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 2015, in _make_train_function
params=self._collected_trainable_weights, loss=self.total_loss)
File "/home/rishabh/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py", line 500, in get_updates
grads = self.get_gradients(loss, params)
File "/home/rishabh/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py", line 391, in get_gradients
grads = gradients.gradients(loss, params)
File "/home/rishabh/.local/lib/python2.7/site-packages/tensorflow_core/python/ops/gradients_impl.py", line 158, in gradients
unconnected_gradients)
File "/home/rishabh/.local/lib/python2.7/site-packages/tensorflow_core/python/ops/gradients_util.py", line 541, in _GradientsHelper
for x in xs
File "/home/rishabh/.local/lib/python2.7/site-packages/tensorflow_core/python/distribute/values.py", line 716, in handle
raise ValueError("`handle` is not available outside the replica context"
ValueError: `handle` is not available outside the replica context or a `tf.distribute.Strategy.update()` call

Now I am not sure where the problem is. Also, if I do not use MirroredStrategy for multiple GPUs, training starts from the beginning, but after a few steps it reaches the same accuracy and loss values as before the model was saved. I am not sure whether that behaviour is normal.
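For reference, here is a minimal sketch of the training-and-checkpointing setup described above, assuming the standard tf.distribute.MirroredStrategy and ModelCheckpoint APIs. The model architecture, hyperparameters, and train_dataset are hypothetical placeholders, not the asker's actual code.

import tensorflow as tf

# Hypothetical stand-ins for the real text-classification model and data.
def build_model():
    return tf.keras.Sequential([
        tf.keras.layers.Embedding(20000, 128),   # vocab size / embedding dim are assumptions
        tf.keras.layers.LSTM(64),
        tf.keras.layers.Dense(5, activation='softmax'),
    ])

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    model = build_model()
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

# Save a full-model HDF5 checkpoint after every epoch, as described in the question.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint('file_name.h5')
# model.fit(train_dataset, epochs=10, callbacks=[checkpoint_cb])  # train_dataset is hypothetical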

Thanks! Rishabh Sahrawat

Best Answer

Create the model inside the distributed scope and then restore the weights with the load_weights method. In this example, get_model returns an instance of tf.keras.Model:
def get_model():
    ...
    return model

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    model = get_model()
    model.load_weights('file_name.h5')
    model.compile(...)
    model.fit(...)
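
To make this skeleton concrete, below is a minimal runnable sketch of the same pattern. The layer sizes, optimizer, and train_dataset are assumptions; in practice get_model must rebuild exactly the architecture that produced 'file_name.h5'.

import tensorflow as tf

def get_model():
    # Placeholder architecture -- replace with the exact model that was checkpointed.
    return tf.keras.Sequential([
        tf.keras.layers.Embedding(20000, 128),
        tf.keras.layers.LSTM(64),
        tf.keras.layers.Dense(5, activation='softmax'),
    ])

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    model = get_model()
    model.load_weights('file_name.h5')   # weights from the ModelCheckpoint HDF5 file
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

# Resume training (train_dataset is a hypothetical tf.data.Dataset):
# model.fit(train_dataset, epochs=5,
#           callbacks=[tf.keras.callbacks.ModelCheckpoint('file_name.h5')])

Note that load_weights restores only the layer weights from the HDF5 file, not the optimizer state, so the optimizer starts fresh when training resumes.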

Regarding "tensorflow - How to continue training after loading a model on multiple GPUs in TensorFlow 2.0 with the Keras API?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/57376403/
