gpt4 book ai didi

python - 如何重新训练用于语言翻译的序列到序列神经网络模型?

转载 作者:行者123 更新时间:2023-12-05 07:25:28 25 4
gpt4 key购买 nike

我训练了一个 seq2seq tensorflow 模型,用于将句子从英语翻译成西类牙语。我训练了 615 700 步的模型,并成功保存了模型检查点。我的英语和西类牙语句子的训练数据大小都是 200 000。我想从 615 700 个步骤中为 10K 个新数据句子重新训练这个模型。为此,我正在使用序列对 tensoflow 模型进行排序。如何从最后一个检查点开始重新训练模型? Here是我用于翻译的链接。

我的 train 文件夹中有 3 种类型的文件:

.index
.meta
.data
and checkpoint file.

我的新训练数据集文件是 europarl_train.es-en.eneuroparl_train.es-en.es 分别用于英语和西类牙语句子。

我编写代码来加载我的模型 .meta 文件和权重

import data_utils
import seq2seq_model
import translate
import tensorflow as tf

with tf.Session() as sess:
saver = tf.train.import_meta_graph('/home/i9/L-T_Model_Training/16_NOV_MODEL/train/translate.ckpt-615700.meta')
saver.restore(sess,tf.train.latest_checkpoint('/home/i9/L-T_Model_Training/16_NOV_MODEL/train/.'))

如何开始保留此数据集?

最佳答案

保存

根据 TensorFlow version 2 doc您可以使用 tf.train.Checkpointtf.train.CheckpointManager 类来保存您的模型。考虑以下示例:

checkpoint_dir = './training_checkpoints'       # custom directory
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(model=model) # your model variable name
manager = tf.train.CheckpointManager(checkpoint=checkpoint, directory=checkpoint_dir, max_to_keep=3) # max_to_keep means how much of last checkpoints number you like to keep

现在,如果您想保存模型,请键入:manager.save()

加载

再次定义检查点和检查点管理器并运行这段代码:

if manager.latest_checkpoint:
checkpoint.restore((manager.latest_checkpoint)).assert_consumed()
print("Restored from {}".format(manager.latest_checkpoint))

如果您遇到类似 (AssertionError: Unresolved object in checkpoint (root)) 的错误,请将 assert_consumed 替换为 expect_partial。 (去这里找区别:link)

模型已从检查点加载。现在您可以加载数据并修复形状并继续训练您的模型。

关于python - 如何重新训练用于语言翻译的序列到序列神经网络模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54780352/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com