gpt4 book ai didi

python - sess.run() 导致训练缓慢

转载 作者:行者123 更新时间:2023-11-30 08:53:54 27 4
gpt4 key购买 nike

我正在训练 CNN,我相信我使用了 sess.run()导致我的训练速度非常慢。

本质上,我使用 mnist数据集...

from tensorflow.examples.tutorials.mnist import input_data
...
...
features = input_data.read_data_sets("/tmp/data/", one_hot=True)

问题是,CNN的第一层必须接受[batch_size, 28, 28, 1]形式的图像。 ,这意味着我必须先转换每个图像,然后再将其输入 CNN。

我用我的脚本执行以下操作...

x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.float32, [None, 10])
...
...
with tf.Session() as sess:

for epoch in range(25):

total_batch = int(features.train.num_examples/500)

avg_cost = 0

for i in range(total_batch):

batch_xs, batch_ys = features.train.next_batch(10)

# Notice this line.
_, c = sess.run([train_op, loss], feed_dict={x:sess.run(tf.reshape(batch_xs, [10, 28, 28, 1])), y:batch_ys})

avg_cost += c / total_batch

if (epoch + 1) % 1 == 0:
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))

注意注释行。我正在从训练集中获取第一批,并将其 reshape 为正确的格式 [batch_size, 28, 28, 1] 。我得打电话sess.run()每一次,我相信这就是训练如此缓慢的原因。

如何防止这种情况发生。我尝试使用 numpy 在另一个脚本中重新格式化数据,但它仍然给我带来了问题,因为我无法喂养 numpy不运行的数组 sess.run() 。有人可以告诉我如何在培训类(class)之外格式化数据吗?也许我可以在另一个脚本中格式化数据并将其加载到包含我的 CNN 的脚本中?

最佳答案

您绝对不应该在每次迭代时对新操作使用内部 sess.run() (尽管我不确定它确实会减慢您的速度)。您应该执行以下操作之一:

  • 有一个与您的输入形状相同的占位符,例如[None, 28*28*1],后跟 tf.reshape([None, 28, 28, 1]),位于网络的开头(而不是您的 tf.placeholder([None, 28, 28, 1]))

或者

  • 保留您的神经网络,并使用 numpy reshape 而不是 tensorflow 重新格式化:_, c = sess.run([train_op, loss], feed_dict={x:batch_xs.reshape( [-1, 28, 28 , 1]), y:batch_ys})

如果你只写 _, c = sess.run([train_op, loss], feed_dict={x:tf.reshape(batch_xs, [10, 28, 28, 1]), 它也可能有效y:batch_ys}) 但您不应该这样做,因为它会在每次迭代时在您的图中创建一个新操作。

关于python - sess.run() 导致训练缓慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46970039/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com