gpt4 book ai didi

python - 训练后如何保存/恢复模型?

转载 作者:IT老高 更新时间:2023-10-28 12:02:10 25 4
gpt4 key购买 nike

在 Tensorflow 中训练模型后:

  1. 如何保存训练好的模型?
  2. 你以后如何恢复这个保存的模型?

最佳答案

我正在改进我的答案,以添加更多用于保存和恢复模型的详细信息。

在(及之后)Tensorflow 0.11 版:

保存模型:

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

恢复模型:

import tensorflow as tf

sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated

这里已经很好地解释了这个和一些更高级的用例。

A quick complete tutorial to save and restore Tensorflow models

关于python - 训练后如何保存/恢复模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33759623/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com