gpt4 book ai didi

python-2.7 - Tensorflow saver.restore() 不恢复网络

转载 作者:行者123 更新时间:2023-12-03 09:37:59 24 4
gpt4 key购买 nike

我完全迷失了 tensorflow 保护程序方法。

我正在尝试遵循基本的 tensorflow 深度神经网络模型教程。我想弄清楚如何训练网络进行几次迭代,然后在另一个 session 中加载模型。

with tf.Session() as sess:
graph = tf.Graph()
x = tf.placeholder(tf.float32,shape=[None,784])
y_ = tf.placeholder(tf.float32, shape=[None,10])

sess.run(global_variables_initializer())

#Define the Network
#(This part is all copied from the tutorial - not copied for brevity)
#See here: https://www.tensorflow.org/versions/r0.12/tutorials/mnist/pros/

跳到训练。

    #Train the Network
train_step = tf.train.AdamOptimizer(1e-4).minimize(
cross_entropy,global_step=global_step)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

saver = tf.train.Saver()

for i in range(101):
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict=
{x:batch[0],y_:batch[1]})
print 'Step %d, training accuracy %g'%(i,train_accuracy)
train_step.run(feed_dict={x:batch[0], y_: batch[1]})
if i%100 == 0:
print 'Test accuracy %g'%accuracy.eval(feed_dict={x:
mnist.test.images, y_: mnist.test.labels})

saver.save(sess,'./mnist_model')

控制台打印出:

Step 0, training accuracy 0.16

Test accuracy 0.0719

Step 100, training accuracy 0.88

Test accuracy 0.8734

接下来我要加载模型

with tf.Session() as sess:
saver = tf.train.import_meta_graph('mnist_model.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
sess.run(tf.global_variables_initializer())

现在我想重新测试一下模型是否加载完毕

print 'Test accuracy %g'%accuracy.eval(feed_dict={x: 
mnist.test.images, y_: mnist.test.labels})

控制台打印出:

Test accuracy 0.1151

模型似乎没有保存任何数据?我做错了什么?

最佳答案

当您保存模型时,通常所有全局变量都保存在外部文件中,而局部变量则不然。你可以看看这个answer了解差异。

恢复代码中的错误是调用 tf.global_variable_initializer() after saver.restore()saver.restore文档提到,

The variables to restore do not have to have been initialized, as restoring is itself a way to initialize variables.

因此,尝试删除该行,

sess.run(tf.global_variables_initializer())

理想情况下,您应该将其替换为,

sess.run(tf.local_variables_initializer())

关于python-2.7 - Tensorflow saver.restore() 不恢复网络,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44378333/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com