gpt4 book ai didi

tensorflow - tf.estimator.Estimator.train() 是否保持 input_fn 状态

转载 作者:行者123 更新时间:2023-12-03 09:56:58 26 4
gpt4 key购买 nike

一年多来我一直在使用自己的 Estimator/Experiment 之类的代码,但我最终想加入 Dataset+Estimator 的行列。

我想做如下的事情:

for _ in range(N):
estimator.train(train_input_fn, steps=1000)
estimator.evaluate(validation_input_fn)

其中 train_input_fn 创建一个 tf.data.Dataset 永远循环训练集,而 validation_input_fn 创建一个 tf. data.Dataset 执行一次验证集。

train() 是否在调用期间保持 train_input_fn 的状态(即如果引用匹配则只调用一次)?这是人们使用 Estimator 进行训练循环的方式吗?

最佳答案

正如我在上面的评论中提到的,看起来它不会在调用 estimator.train() 时保存状态。

我正在使用的一个解决方案(可能也是预期的方法)是将评估监听器传递给 estimator.train()。例如,

class EvalCheckpointSaverListener(tf.train.CheckpointSaverListener):
def __init__(self, estimator, input_fn):
self.estimator = estimator
self.input_fn = input_fn

def after_save(self, session, global_step):
self.estimator.evaluate(self.input_fn)

estimator.train(
input_fn=lambda:_train_input_fn(...),
max_steps=N,
saving_listeners=[
EvalCheckpointSaverListener(
estimator,
lambda:_eval_input_fn(...),
),
],
)

关于tensorflow - tf.estimator.Estimator.train() 是否保持 input_fn 状态,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46925196/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com