gpt4 book ai didi

machine-learning - 当 TensorFlow 中有图形对象时,为什么脚本中 saver 的位置很重要?

转载 作者:行者123 更新时间:2023-11-30 09:52:46 26 4
gpt4 key购买 nike

我正在训练一些模型,我注意到当我显式定义图形变量时,我的保护程序对象的创建位置很重要。首先我的代码如下所示:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("tmp_MNIST_data/", one_hot=True)

x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.truncated_normal([784, 10], mean=0.0, stddev=0.1),name='w')
b = tf.Variable(tf.constant(0.1, shape=[10]),name='b')
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) # list of booleans indicating correct predictions
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(cross_entropy)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(1001):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(fetches=train_step, feed_dict={x: batch_xs, y_: batch_ys})
if i % 100 == 0:
saver.save(sess=sess,save_path='./tmp/mdl_ckpt')
print(sess.run(fetches=accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

然后我决定将其更改为类似的内容,在我定义变量和定义保护程序的位置似乎非常敏感。例如,如果在创建图形变量之后没有准确定义它们,则会出现错误。同样,我注意到必须在一个变量之后准确定义保护程序(注意在图形定义之后不够),以便保护程序一起捕获所有变量(这并没有'对我来说这没有意义,要求它在所有变量的定义后面而不是单个变量的定义后面才有意义)。

这就是代码现在的样子(注释显示了我定义保护程序的位置):

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("tmp_MNIST_data/", one_hot=True)

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
#saver = tf.train.Saver()
x = tf.placeholder(tf.float32, [None, 784])
saver = tf.train.Saver()
y_ = tf.placeholder(tf.float32, [None, 10])
#saver = tf.train.Saver()
W = tf.Variable(tf.truncated_normal([784, 10], mean=0.0, stddev=0.1),name='w')
#saver = tf.train.Saver()
b = tf.Variable(tf.constant(0.1, shape=[10]),name='b')
y = tf.nn.softmax(tf.matmul(x, W) + b)
#saver = tf.train.Saver()
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) # list of booleans indicating correct predictions
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#saver = tf.train.Saver()
step = tf.Variable(0, name='step')
#saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
#saver = tf.train.Saver()
for i in range(1001):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(fetches=train_step, feed_dict={x: batch_xs, y_: batch_ys})
if i % 100 == 0:
step_assign = step.assign(i)
sess.run(step_assign)
saver.save(sess=sess,save_path='./tmp/mdl_ckpt')
print(step.eval())
print( [ op.name for op in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)] )
print(sess.run(fetches=accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

上面的代码应该可以工作,但是我很难理解为什么它会这样,或者为什么会发生这种情况。有人知道正确的做法是什么吗?

最佳答案

我不完全确定这里发生了什么,但我怀疑问题与变量没有进入错误的图表有关,或者 session 具有过时的图表版本。您创建一个图表,但不将其设置为默认值,然后使用该图表创建一个 session ...但是当您创建变量时,您没有指定它们应该进入哪个图表。也许 session 的创建将指定的图形设置为默认值,但这不是 tensorflow 的设计使用方式,因此如果它没有在这种情况下经过彻底测试,我不会感到惊讶。

虽然我没有解释或发生了什么,但我可以建议一个简单的解决方案:将图形构建与 session 运行分开。

graph = tf.Graph()
with graph.as_default():
build_graph()
saver = tf.train.Saver()

with tf.Session(graph=graph) as sess:
do_stuff_with(sess)
saver.save(sess, path)

关于machine-learning - 当 TensorFlow 中有图形对象时,为什么脚本中 saver 的位置很重要?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41865482/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com