gpt4 book ai didi

Tensorflow Estimator -warm_start_from 和 model_dir

转载 作者:行者123 更新时间:2023-12-04 04:28:37 26 4
gpt4 key购买 nike

使用时 tf.estimatorwarm_start_from model_dir ,以及两者 warm_start_from目录和 model_dir目录包含有效的检查点,哪个检查点将被实际恢复?

为了给出一些上下文,我的估算器代码看起来像

est = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=model_dir,
warm_start_from=warm_start_dir)

for epoch in range(num_epochs):
est.train(input_fn=train_input_fn)
est.evaluate(input_fn=eval_input_fn)

(输入函数使用一次性迭代器。)

所以在第一次迭代中,当 model_dir为空,我希望加载热启动检查点,但在下一个时期,我希望从 model_dir 中的最后一次迭代中获得中间微调检查点要加载。但至少从日志来看,它看起来像 warm_start_dir仍在加载中。

我可能会在下一次迭代中覆盖我的估算器,但我想知道它是否不应该以某种方式在估算器中构建。

最佳答案

我遇到了类似的问题,我通过提供一个在 session 开始时运行的初始化 Hook 并使用 tf.estimator.train_and_evaluate 解决了这个问题。 (虽然我不能相信整个解决方案,因为我在其他地方看到了类似的其他目的):

class InitHook(tf.train.SessionRunHook):
"""initializes model from a checkpoint_path
args:
modelPath: full path to checkpoint
"""
def __init__(self, checkpoint_dir):
self.modelPath = checkpoint_dir
self.initialized = False

def begin(self):
"""
Restore encoder parameters if a pre-trained encoder model is available and we haven't trained previously
"""
if not self.initialized:
log = logging.getLogger('tensorflow')
checkpoint = tf.train.latest_checkpoint(self.modelPath)
if checkpoint is None:
log.info('No pre-trained model is available, training from scratch.')
else:
log.info('Pre-trained model {0} found in {1} - warmstarting.'.format(checkpoint, self.modelPath))
tf.train.warm_start(checkpoint)
self.initialized = True

然后,对于训练:
initHook = InitHook(checkpoint_dir = warm_start_dir)
trainSpec = tf.estimator.TrainSpec(
input_fn = train_input_fn,
max_steps = N_STEPS,
hooks = [initHook]
)
evalSpec = tf.estimator.EvalSpec(
input_fn = eval_input_fn,
steps = None,
name = 'eval',
throttle_secs = 3600
)
tf.estimator.train_and_evaluate(estimator, trainSpec, evalSpec)

这在开始时运行一次以初始化来自 warm_start_dir 的变量.后来,当估计器中有新的检查点时 model_dir ,它从那里继续热启动。

关于Tensorflow Estimator -warm_start_from 和 model_dir,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49846207/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com