gpt4 book ai didi

python - KeyError : The tensor variable , 引用不存在的张量

转载 作者:太空宇宙 更新时间:2023-11-04 00:20:51 24 4
gpt4 key购买 nike

我使用 LSTMCell 训练了一个模型来生成文本。我启动了 tensorflow session 并使用 tf.global_variables_initializer() 保存所有 tensorflow 变量。

import tensorflow as tf
sess = tf.Session()
//code blocks
run_init_op = tf.global_variables_intializer()
sess.run(run_init_op)
saver = tf.train.Saver()
#varible that makes prediction
prediction = tf.nn.softmax(tf.matmul(last,weight)+bias)
#feed the inputdata into model and trained
#saved the model
#save the tensorflow model
save_path= saver.save(sess,'/tmp/text_generate_trained_model.ckpt')
print("Model saved in the path : {}".format(save_path))

模型接受训练并保存所有 session 。查看整个代码的链接 lstm_rnn.py

现在我加载了存储的模型并尝试为文档生成文本。所以,我用下面的代码恢复了模型

tf.reset_default_graph()
imported_data = tf.train.import_meta_graph('text_generate_trained_model.ckpt.meta')
with tf.Session() as sess:
imported_meta.restore(sess,tf.train.latest_checkpoint('./'))

#accessing the default graph which we restored
graph = tf.get_default_graph()

#op that we can be processed to get the output
#last is the tensor that is the prediction of the network
y_pred = graph.get_tensor_by_name("prediction:0")
#generate characters
for i in range(500):
x = np.reshape(pattern,(1,len(pattern),1))
x = x / float(n_vocab)
prediction = sess.run(y_pred,feed_dict=x)
index = np.argmax(prediction)
result = int_to_char[index]
seq_in = [int_to_char[value] for value in pattern]
sys.stdout.write(result)
patter.append(index)
pattern = pattern[1:len(pattern)]

print("\n Done...!")
sess.close()

我开始知道图中不存在预测变量。

KeyError: "The name 'prediction:0' refers to a Tensor which does not exist. The operation, 'prediction', does not exist in the graph."

完整代码可在此处获得 text_generation.py

虽然我保存了所有 tensorflow 变量,但预测张量并未保存在 tensorflow 计算图中。我的 lstm_rnn.py 文件有什么问题。

谢谢!

最佳答案

要使 graph.get_tensor_by_name("prediction:0") 正常工作,您应该在创建它时为其命名。你可以这样命名它

prediction = tf.nn.softmax(tf.matmul(last,weight)+bias, name="prediction")

如果您已经训练了模型并且无法重命名张量,您仍然可以通过默认名称获取该张量,如下所示,

y_pred = graph.get_tensor_by_name("Reshape_1:0")

如果 Reshape_1 不是张量的实际名称,您将必须查看图中的名称并弄清楚。您可以使用

检查
for op in graph.get_operations():
print(op.name)

关于python - KeyError : The tensor variable , 引用不存在的张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49166819/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com