gpt4 book ai didi

tensorflow - 恢复保存的 TensorFlow 模型以在测试集上进行评估

转载 作者:行者123 更新时间:2023-12-01 09:17:10 24 4
gpt4 key购买 nike

我见过几个posts关于恢复 TF 模型和 exporting graphs 上的 Google 文档页面但我想我错过了一些东西。

我使用 Gist 中的代码将模型与此 utils 文件一起保存到其中 defines型号

现在我想恢复它并在以前看不见的测试数据中运行如下:

def evaluate(X_data, y_data):
num_examples = len(X_data)
total_accuracy = 0
total_loss = 0
sess = tf.get_default_session()
acc_steps = len(X_data) // BATCH_SIZE
for i in range(acc_steps):
batch_x, batch_y = next_batch(X_val, Y_val, BATCH_SIZE)

loss, accuracy = sess.run([loss_value, acc], feed_dict={
images_placeholder: batch_x,
labels_placeholder: batch_y,
keep_prob: 0.5
})
total_accuracy += (accuracy * len(batch_x))
total_loss += (loss * len(batch_x))
return (total_accuracy / num_examples, total_loss / num_examples)

## re-execute the code that defines the model

# Image Tensor
images_placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3], name='x')

gray = tf.image.rgb_to_grayscale(images_placeholder, name='gray')

gray /= 255.

# Label Tensor
labels_placeholder = tf.placeholder(tf.float32, shape=(None, 43), name='y')

# dropout Tensor
keep_prob = tf.placeholder(tf.float32, name='drop')

# construct model
logits = inference(gray, keep_prob)

# calculate loss
loss_value = loss(logits, labels_placeholder)

# training
train_op = training(loss_value, 0.001)

# accuracy
acc = accuracy(logits, labels_placeholder)

with tf.Session() as sess:
loader = tf.train.import_meta_graph('gtsd.meta')
loader.restore(sess, tf.train.latest_checkpoint('./'))
sess.run(tf.initialize_all_variables())
test_accuracy = evaluate(X_test, y_test)
print("Test Accuracy = {:.3f}".format(test_accuracy[0]))

我得到的测试准确率只有 3%。但是,如果我在训练模型后不关闭笔记本并立即运行测试代码,我将获得 95% 的准确度。

这让我相信我没有正确加载模型?

最佳答案

问题出在这两行:

loader.restore(sess, tf.train.latest_checkpoint('./'))
sess.run(tf.initialize_all_variables())

第一行从检查点加载保存的模型。第二行重新初始化模型中的所有变量(例如权重矩阵、卷积滤波器和偏置向量),通常为随机数,并覆盖加载值。

解决方案很简单:删除第二行 (sess.run(tf.initialize_all_variables())),然后评估将从检查点加载的训练值继续进行。


PS。此更改很有可能会给您一个关于“未初始化变量”的错误。在这种情况下,您应该执行 sess.run(tf.initialize_all_variables()) 以初始化未保存在检查点 before 执行 loader.restore(sess , tf.train.latest_checkpoint('./')).

关于tensorflow - 恢复保存的 TensorFlow 模型以在测试集上进行评估,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41287289/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com