gpt4 book ai didi

python - 不使用 Supervisor 时 Tensorflow 会卡住

转载 作者:行者123 更新时间:2023-12-01 03:02:39 27 4
gpt4 key购买 nike

无 GPU、无队列、Tensorflow 1.1.0

这是示例 LSTM 代码:

https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py

这段代码有效。它打印训练过程信息,很酷。现在,我尝试使用freeze_graph()将经过训练的模型图写入磁盘,最终我发现这个LSTM教程使用Supervisor来训练模型,并且Supervisor 卡住图,并且卡住的图不能在 freeze_graph() 过程中使用。

我尝试从 Supervisor 切换到使用普通 session 。 唯一更改发生在 main() 过程中(除了导入一些内容之外)。现在看起来像这样(更改的部分已突出显示,并且我删除了所有与图形保存相关的内容,这不是这里的问题):

with tf.Graph().as_default():
initializer = tf.random_uniform_initializer(
-config.init_scale, config.init_scale)
with tf.name_scope("Train"):
train_input = PTBInput(
config=config, data=train_data, name="TrainInput")
with tf.variable_scope("Model", reuse=None, initializer=initializer):
m = PTBModel(
is_training=True, config=config, input_=train_input)
tf.summary.scalar("Training Loss", m.cost)
tf.summary.scalar("Learning Rate", m.lr)
with session.Session() as sess: # CHANGED
sess.run(variables.global_variables_initializer()) # CHANGED
for i in range(config.max_max_epoch):
lr_decay = config.lr_decay ** max(i +
1 - config.max_epoch, 0.0)
m.assign_lr(sess, config.learning_rate * lr_decay)
print("Epoch: %d Learning rate: %.3f" %
(i + 1, sess.run(m.lr)))
train_perplexity = run_epoch(sess, m, eval_op=m.train_op,
verbose=True)
print("Epoch: %d Train Perplexity: %.3f" %
(i + 1, train_perplexity))

在这些更改之后,整个事情开始卡住在这一行:

https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py#L300

这是模型内部的 session.run() 调用(不会对 Ctrl+C 使用react,可使用 kill -9 杀死):

vals = session.run(fetches, feed_dict)

之前的 session.run() 调用(有一些)工作得很好。

我做错了什么?看起来所有变量都初始化得很好(这是由原始代码中的 Supervisor 完成的)。有什么想法吗?

最佳答案

当您使用tf.train.Supervisor时,框架代码自动调用tf.train.start_queue_runners(sess) (以及初始化变量)在 session 开始时。如果您切换回使用原始 tf.Session ,您必须手动调用它来启动输入管道。像下面这样的更改应该有效:

# ...
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess)
# ...

关于python - 不使用 Supervisor 时 Tensorflow 会卡住,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43669417/

27 4 0
文章推荐: python - 第二个循环中存在冲突的选项字符串
文章推荐: python - 无法在 Python 3.6 中导入 PyQ : flat namespace error
文章推荐: python - 使用 Python 子进程运行 SLURM 脚本将多个长作业提交到队列并等待作业完成,然后再继续 python 脚本
文章推荐: jquery - 将事件处理程序从 <input> 复制到
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com