gpt4 book ai didi

save - 如何恢复保存的 tensorflow 模型?

转载 作者:行者123 更新时间:2023-12-04 02:05:27 25 4
gpt4 key购买 nike

有两个python文件,第一个是保存tensorflow的模型。第二个用于恢复保存的模型。

问题:

  1. 当我依次运行这两个文件时,就可以了。

  2. 当我运行第一个时,重新启动编辑并运行第二个,它告诉我 w1 没有定义?

我想做的是:

  1. 保存 tensorflow 模型

  2. 恢复保存的模型

这是怎么回事?感谢您的热心帮助?

model_save.py

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, 'SR\\my-model')

model_restore.py

import tensorflow as tf

with tf.Session() as sess:
saver = tf.train.import_meta_graph('SR\\my-model.meta')
saver.restore(sess,'SR\\my-model')
print (sess.run(w1))

enter image description here

最佳答案

简而言之,你应该使用

print (sess.run(tf.get_default_graph().get_tensor_by_name('w1:0')))

而不是 print (sess.run(w1)) 在您的 model_restore.py 文件中。

model_save.py

import tensorflow as tf
w1_node = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2_node = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(w1_node.eval()) # [ 0.43350926 1.02784836]
#print(w1.eval()) # NameError: name 'w1' is not defined
saver.save(sess, 'my-model')

w1_node 只在model_save.py 中定义,model_restore.py 文件无法识别。当我们通过name 调用Tensor 变量时,我们应该使用get_tensor_by_name,因为这篇文章Tensorflow: How to get a tensor by name?建议。

model_restore.py

import tensorflow as tf

with tf.Session() as sess:
saver = tf.train.import_meta_graph('my-model.meta')
saver.restore(sess,'my-model')
print (sess.run(tf.get_default_graph().get_tensor_by_name('w1:0')))
# [ 0.43350926 1.02784836]
print(tf.global_variables()) # print tensor variables
# [<tf.Variable 'w1:0' shape=(2,) dtype=float32_ref>,
# <tf.Variable 'w2:0' shape=(5,) dtype=float32_ref>]
for op in tf.get_default_graph().get_operations():
print str(op.name) # print all the operation nodes' name

关于save - 如何恢复保存的 tensorflow 模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43830036/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com