gpt4 book ai didi

python - 基本的 StopAtStepHook 和 MonitoredTrainingSession 用法

转载 作者:行者123 更新时间:2023-11-28 17:18:12 24 4
gpt4 key购买 nike

我想设置分布式 tensorflow 模型,但无法理解 MonitoredTrainingSession 和 StopAtStepHook 的交互方式。在我进行此设置之前:

for epoch in range(training_epochs):
for i in range(total_batch-1):
c, p, s = sess.run([cost, prediction, summary_op], feed_dict={x: batch_x, y: batch_y})

现在我有了这个设置(简化):

def run_nn_model(learning_rate, log_param, optimizer, batch_size, layer_config):
with tf.device(tf.train.replica_device_setter(
worker_device="/job:worker/task:%d" % mytaskid,
cluster=cluster)):

# [variables...]

hooks=[tf.train.StopAtStepHook(last_step=100)]
if myjob == "ps":
server.join()
elif myjob == "worker":
with tf.train.MonitoredTrainingSession(master = server.target,
is_chief=(mytaskid==0),
checkpoint_dir='/tmp/train_logs',
hooks=hooks
) as sess:

while not sess.should_stop():
#for epoch in range...[see above]

这是错的吗?它抛出:

RuntimeError: Run called even after should_stop requested.
Command exited with non-zero status 1

有人可以向我解释一下 tensorflow 在这里是如何协调的吗?如何使用计步器跟踪训练? (在我有这个方便的时代变量之前)

最佳答案

每次执行 sess.run 时,计数器都会递增。这里的问题是您运行的步数 (total_batch-1 x training_epochs) 多于 Hook (200) 中指定的步数。

尽管我认为这不是一个干净的语法,但您可以做的是定义 last_step = total_batch-1 x training_epochs

关于python - 基本的 StopAtStepHook 和 MonitoredTrainingSession 用法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42960304/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com