gpt4 book ai didi

python - 打印特定层的保存权重 [Tensorflow]

转载 作者:太空宇宙 更新时间:2023-11-04 02:55:10 24 4
gpt4 key购买 nike

我在 PyCharm 中使用 Tensorflow 1.0 和 python 3.5
执行后 this代码,我在每 500 次迭代时保存了模型(索引、元和 ckpt 文件)。现在要加载模型,我们需要指向哪个文件?
我写了下面的代码来加载 ckpt(weights) 文件(没有对上面的 github 代码进行任何更改)

w1 = tf.Variable(tf.zeros([5, 5, 1, 32]), name="conv1/W")
b1 = tf.Variable(tf.zeros(shape=[32]), name="conv1/B")

w2 = tf.Variable(tf.zeros([5, 5, 32, 64]), name="conv2/W")
b2 = tf.Variable(tf.zeros(shape=[64]), name="conv2/B")

w3 = tf.Variable(tf.zeros([3136, 1024]), name="fc1/W")
b3 = tf.Variable(tf.zeros(shape=[1024]), name="fc1/B")

w4 = tf.Variable(tf.zeros([1024, 10]), name="fc2/W")
b4 = tf.Variable(tf.zeros(shape=[10]), name="fc2/B")
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "D:/tmp/mnist_tutorial/model.ckpt")

出现以下错误

W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:975] Not found: Unsuccessful TensorSliceReader constructor: Failed to find any matching files for D:/tmp/mnist_tutorial/model.ckpt
Traceback (most recent call last):
File "C:\Users\Admin\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\client\session.py", line 1021, in _do_call return fn(*args)
File "C:\Users\Admin\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\client\session.py", line 1003, in _run_fn status, run_metadata)
File "C:\Users\Admin\AppData\Local\Programs\Python\Python35\lib\contextlib.py", line 66, in exit next(self.gen)
File "C:\Users\Admin\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 469, in raise_exception_on_not_ok_status

pywrap_tensorflow.TF_GetCode(status)) tensorflow.python.framework.errors_impl.NotFoundError: Unsuccessful TensorSliceReader constructor: Failed to find any matching files for D:/tmp/mnist_tutorial/model.ckpt
[[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]

在训练开始之前(在卷积层的函数定义中),我们可以通过以下方式打印各个层的权重:

 w = tf.Variable(tf.truncated_normal([5, 5, 1, 64], stddev=0.1), name="W")
b = tf.Variable(tf.constant(0.1, shape=[64]), name="B")
init = tf.global_variables_initializer()
with tf.Session()as sess:
sess.run(init)
print("weight type is ", w)
print('bias type is', b)
print("random generated weights are: ")
x = tf.Print('conv/W:0', [w],summarize=1600)
sess.run(x)
print("Generated Biases are: ")
y = tf.Print(b, [b],summarize=64)
sess.run(y)

如果有很多卷积层和全连接层,如何从 *.ckpt 文件中加载和打印任何特定层的权重和偏差,因为上述方法不起作用


更新:更改了代码并更新了错误消息

最佳答案

查看inspect_checkpoint tool在 tensorflow 存储库中读取检查点。

关于python - 打印特定层的保存权重 [Tensorflow],我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42623978/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com