gpt4 book ai didi

python - Tensorflow:批量训练永远停留在 sess.run 中

转载 作者:行者123 更新时间:2023-11-28 21:41:47 25 4
gpt4 key购买 nike

我正在尝试逐批训练我的模型,因为我找不到任何示例来说明如何正确地训练它。这是我所能做的,我的任务是找到如何在 Tensorflow 中逐批训练模型。

queue=tf.FIFOQueue(capacity=50,dtypes=[tf.float32,tf.float32],shapes=[[10],[2]])
enqueue_op=queue.enqueue_many([X,Y])
dequeue_op=queue.dequeue()

qr=tf.train.QueueRunner(queue,[enqueue_op]*2)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
X_train_batch,y_train_batch=tf.train.batch(dequeue_op,batch_size=2)
coord=tf.train.Coordinator()
enqueue_threads=qr.create_threads(sess,coord,start=True)
sess.run(tf.local_variables_initializer())
for epoch in range(100):
print("inside loop1")
for iter in range(5):
print("inside loop2")
if coord.should_stop():
break
batch_x,batch_y=sess.run([X_train_batch,y_train_batch])
print("after sess.run")
print(batch_x.shape)
_=sess.run(optimizer,feed_dict={x_place:batch_x,y_place:batch_y})
coord.request_stop()
coord.join(enqueue_threads)

哪些输出,

inside loop1
inside loop2

如你所见,它在运行 batch_x,batch_y=sess.run([X_train_batch,y_train_batch]) 行时永远卡住了。我不知道该如何解决这个问题,或者这是逐批训练模型的正确方法吗?

最佳答案

经过几个小时的搜索,我自己找到了解决方案。所以,我现在在下面回答我自己的问题。队列由调用 tf.train.start_queue_runners() 时创建的后台线程填充。如果不调用此方法,后台线程将不会启动,队列将保持为空,并且训练操作将无限期地阻塞以等待输入。

修复:在训练循环之前调用 tf.train.start_queue_runners(sess)。就像我在下面所做的那样:

queue=tf.FIFOQueue(capacity=50,dtypes=[tf.float32,tf.float32],shapes=[[10],[2]])
enqueue_op=queue.enqueue_many([X,Y])
dequeue_op=queue.dequeue()

qr=tf.train.QueueRunner(queue,[enqueue_op]*2)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
X_train_batch,y_train_batch=tf.train.batch(dequeue_op,batch_size=2)
coord=tf.train.Coordinator()
enqueue_threads=qr.create_threads(sess,coord,start=True)
tf.train.start_queue_runners(sess)
for epoch in range(100):
print("inside loop1")
for iter in range(5):
print("inside loop2")
if coord.should_stop():
break
batch_x,batch_y=sess.run([X_train_batch,y_train_batch])
print("after sess.run")
print(batch_x.shape)
_=sess.run(optimizer,feed_dict={x_place:batch_x,y_place:batch_y})
coord.request_stop()
coord.join(enqueue_threads)

关于python - Tensorflow:批量训练永远停留在 sess.run 中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44407873/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com