gpt4 book ai didi

python - 保存重新训练的 tensorflow 模型时出现问题

转载 作者:行者123 更新时间:2023-11-30 08:48:52 26 4
gpt4 key购买 nike

我正在尝试加载模型(之前保存的),并在重新训练后保存它。加载效果很好,但我在保存时遇到问题,如下所示:

sess=tf.Session()
sess.run(init)
loader = tf.train.import_meta_graph(self.model_path+'.meta')
loader.restore(sess,self.model_path)#tf.train.latest_checkpoint('./'))
print('Model restored')
#retrain
saver=tf.train.Saver()
saver.save(sess, self.model_path)

我第一次保存时没有遇到任何类似的问题,如下所示:

saver=tf.train.Saver()
sess=tf.Session()
sess.run(init)
#train
saver.save(sess, self.model_path)

我遇到的错误是:

File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1139, in __init__
self.build()
File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1170, in build
restore_sequentially=self._restore_sequentially)
File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 673, in build
saveables = self._ValidateAndSliceInputs(names_to_saveables)
File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 557, in _ValidateAndSliceInputs
names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables)
File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 535, in OpListToDict
name)
ValueError: At least two variables have the same name: Variable_15/Adam

最佳答案

您看到此消息是因为作用域中有两个同名的变量。 tf.train.import_meta_graph 从文件中读取图形,并将所有操作和张量添加到当前现有图形中。令我惊讶的是 import_meta_graph 一开始甚至没有引发这样的异常。

查看完整示例以重现此行为:

import tensorflow as tf

# tiny graph
x = tf.placeholder(tf.float32, shape=[1, 2], name='input')
output = tf.identity(tf.layers.dense(x, 1), name='output')
cost = tf.reduce_sum(x * output)
# create first time u'beta1_power:0', u'beta2_power:0'
train_op = tf.train.AdamOptimizer().minimize(cost)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())
saver.save(sess, './adam/my_model')

print([v.name for v in tf.global_variables()])

# create second time u'beta1_power:0', u'beta2_power:0'
meta_graph = tf.train.import_meta_graph('./adam/my_model.meta')
meta_graph.restore(sess, './adam/my_model')

print([v.name for v in tf.global_variables()])

saver = tf.train.Saver(tf.global_variables())
# exception as there are now two times: u'beta1_power:0', u'beta2_power:0'
saver.save(sess, './adam/my_model2')

解决方案是:

  • tf.trainimport_meta_graph 之前使用 tf.reset_default_graph() 清除图表
  • tf.train.import_meta_graph 使用新 session
  • 只需使用 tf.train.Saver().restore(sess, '/tmp/model/my_model') 加载权重

关于python - 保存重新训练的 tensorflow 模型时出现问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48857159/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com