gpt4 book ai didi

python - 了解 TensorFlow 检查点加载?

转载 作者:行者123 更新时间:2023-11-28 18:08:56 25 4
gpt4 key购买 nike

TF 检查点中包含什么?例如,估算器存储一个包含 GraphDef 原型(prototype)的单独文件,您基本上可以执行 tf.import_graph_def(),然后创建一个 tf.train.Saver( ) 并将检查点恢复到图中。现在,如果您有另一个 GraphDef 描述一个完全不同的图,它恰好共享完全相同的变量名称以及匹配的变量维度,您能否将检查点加载到该图中?换句话说,它只是一个变量名到值的映射,还是它假设了在加载过程中要检查的图形的其他内容?如果您尝试将检查点加载到作为原始图的子集的图中(即张量维度和名称匹配,但缺少某些名称)怎么办?

最佳答案

人们什么时候开始阅读文档(?): https://www.tensorflow.org/mobile/prepare_models

这些是不同的概念。只要形状匹配,您就可以只加载重量。如果出现不匹配,您将得到:

Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint.

但是,您可以调整一个非平凡的情况,其中图形完全不同:

import tensorflow as tf
import numpy as np

test_data = np.arange(4).reshape(1, 2, 2, 1)

# a simple graph and everything is fine
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
output = tf.layers.conv2d(input, 3, kernel_size=1, name='test', use_bias=False)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(output, {input: test_data}))
saver = tf.train.Saver()
save_path = saver.save(sess, "/tmp/model.ckpt")
print(tf.trainable_variables())

# reset previous elements
tf.reset_default_graph()

# a new graph
input = tf.placeholder(dtype=tf.float32, shape=[1, 2, 2, 1])
# and wait: this is complete different but same name and shape
W = tf.get_variable('test/kernel', shape=[1, 1, 1, 3])
# but the graph has different operations
output = input + W

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, "/tmp/model.ckpt")
print(sess.run(output, {input: test_data}))

在我的例子中,我得到了:

# 1st version (original graph)
[[[[-0. -0. -0. ]
[-0.08429337 -1.0156475 -0.42691123]]

[[-0.16858673 -2.031295 -0.85382247]
[-0.2528801 -3.0469427 -1.2807337 ]]]]
# 2nd version (altered graph)
[[[[-0.08429337 -1.0156475 -0.42691123]
[ 0.91570663 -0.01564753 0.57308877]]

[[ 1.9157066 0.98435247 1.5730888 ]
[ 2.9157066 1.9843525 2.5730886 ]]]]

关于python - 了解 TensorFlow 检查点加载?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51900049/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com