gpt4 book ai didi

python - 我可以使用 ktrain 库从检查点恢复训练吗?

转载 作者:行者123 更新时间:2023-12-02 19:23:42 25 4
gpt4 key购买 nike

ktrain 是深度学习库 TensorFlow Keras(和其他库)的轻量级包装器,可帮助构建、训练和部署神经网络和其他机器学习模型。我可以使用 ktrain 库从检查点恢复训练吗?

最佳答案

是的,可以。 ktrain 常见问题解答中对此进行了解答。我将答案复制在这里:

方法 1:使用 Predictor API(适用于任何模型)

# save model and Preprocessor instance after partially training
ktrain.get_predictor(model, preproc).save('/tmp/my_predictor')

# reload Predictor and extract model
model = ktrain.load_predictor('/tmp/my_predictor').model

# re-instantiate Learner and continue training
learner = ktrain.get_learner(model, train_data=trn, val_data=val)
learner.fit_onecycle(2e-5, 1)

请注意,这里的 preproc 是一个 Preprocessor 实例。如果使用诸如 texts_from_csvimages_from_folder 之类的数据加载函数,它将是该函数的第三个返回值。或者,如果使用 Transformer API对于文本分类,它将是调用 text.Transformer 的输出(即 preproc = text.Transformer('bert-base-uncased', ...)) .

方法 2:使用 transformers 库(如果训练 Hugging Face Transformers 模型)

如果模型是Hugging Face Transformers模型,则可以直接使用transformers:

# save model using transformers API after partially training
learner.model.save_pretrained('/tmp/my_model')

# reload the model using transformers directly
from transformers import *
model = TFAutoModelForSequenceClassification.from_pretrained('/tmp/my_model')
model.compile(loss='categorical_crossentropy',optimizer='adam', metrics=['accuracy'])

# re-instantiate Learner and continue training
learner = ktrain.get_learner(model, train_data=trn, val_data=val)
learner.fit_onecycle(2e-5, 1)

方法3:使用checkpoint_folder参数保存模型权重

checkpoint_folder 参数(例如,learner.autofit(1e-4, 4, checkpoint_folder='/tmp/saved_weights'))仅保存每个时期之后的模型。任何时期的权重都可以使用 model.load_weights 方法重新加载到模型中,就像在 tf.Keras 中通常所做的那样。您只需要先重新创建首先是模型。例如,如果训练 NER 模型,它将按如下方式工作:

# recreate model from scratch
import ktrain
from ktrain import text
model = text.sequence_tagger(...
# load checkpoint weights from 3rd epoch into model
model.load_weights('../models/checkpoints/weights-03.hdf5')
# recreate learner
learner = ktrain.get_learner(model, ...
# continue training here

最后,还有一个 learner.save_modellearner.load_model 方法,用于在单个 session 期间进行交互式训练时保存和重新加载模型。

关于python - 我可以使用 ktrain 库从检查点恢复训练吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62678035/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com