gpt4 book ai didi

python-2.7 - 如何在 Tensorflow 中使用 CheckpointReader 恢复变量

转载 作者:行者123 更新时间:2023-12-03 17:59:20 25 4
gpt4 key购买 nike

如果当前模型中有相同的变量名,我正在尝试从检查点文件中恢复一些变量。
我发现有一些方法,如 Tensorfow Github

所以我想做的是使用 has_tensor("variable.name") 检查检查点文件中的变量名称,如下所示,

...    
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
print v.name
if reader.has_tensor(v.name):
print 'has tensor'
...

但我发现 v.name 返回变量 namecolon+number。例如,我有变量名 W_ob_o 然后 v.name 返回 W_o:0, b_o:0 .

但是 reader.has_tensor() 需要 name 没有 colonnumber 作为 W_o, b_o

我的问题是:如何去掉变量名末尾的冒号数字以便读取变量?
有没有更好的方法来恢复这些变量?

最佳答案

你可以使用 string.split()获取张量名称:

...    
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
tensor_name = v.name.split(':')[0]
print tensor_name
if reader.has_tensor(tensor_name):
print 'has tensor'
...

接下来,让我用一个示例来说明如何从 .cpkt 文件中恢复每个可能的变量。首先,我们将v2v3保存在tmp.ckpt中:

import tensorflow as tf

v1 = tf.Variable(tf.ones([1]), name='v1')
v2 = tf.Variable(2 * tf.ones([1]), name='v2')
v3 = tf.Variable(3 * tf.ones([1]), name='v3')

saver = tf.train.Saver({'v2': v2, 'v3': v3})

with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
saver.save(sess, 'tmp.ckpt')

这就是我如何恢复出现在 tmp.ckpt 中的每个变量(属于新图):

with tf.Graph().as_default():
assert len(tf.trainable_variables()) == 0
v1 = tf.Variable(tf.zeros([1]), name='v1')
v2 = tf.Variable(tf.zeros([1]), name='v2')

reader = tf.train.NewCheckpointReader('tmp.ckpt')
restore_dict = dict()
for v in tf.trainable_variables():
tensor_name = v.name.split(':')[0]
if reader.has_tensor(tensor_name):
print('has tensor ', tensor_name)
restore_dict[tensor_name] = v

saver = tf.train.Saver(restore_dict)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
saver.restore(sess, 'tmp.ckpt')
print(sess.run([v1, v2])) # prints [array([ 0.], dtype=float32), array([ 2.], dtype=float32)]

此外,您可能希望确保形状和数据类型匹配。

关于python-2.7 - 如何在 Tensorflow 中使用 CheckpointReader 恢复变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39137597/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com