gpt4 book ai didi

python - tf.train.init_from_checkpoint 不初始化使用 tf.Variable 创建的变量

转载 作者:太空狗 更新时间:2023-10-29 18:08:54 32 4
gpt4 key购买 nike

tf.train.init_from_checkpoint 似乎初始化了通过 tf.get_variable 创建的变量,但不是 通过 tf.Variable 创建的变量

例如,让我们创建两个变量并保存它们:

import tensorflow as tf

tf.Variable(1.0, name='foo')
tf.get_variable('bar',initializer=1.0)
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.save(sess, './model', global_step=0)

如果我通过 tf.train.Saver 再次加载它们,一切正常:变量被加载回 1,即使它们在这里被初始化为零:

import tensorflow as tf

foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './model-0')
print(f'foo: {foo.eval()} bar: {bar.eval()}')
# foo: 1.0 bar: 1.0

但是如果我使用 tf.train.init_from_checkpoint 我会得到

import tensorflow as tf

foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
tf.train.init_from_checkpoint('./model-0', {'/':'/'})
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(f'foo: {foo.eval()} bar: {bar.eval()}')
# foo: 0.0 bar: 1.0

bar 按预期设置回 1,但 foo 仍为 0。

这是预期的行为吗?如果是,为什么?

最佳答案

是的,这是有意的。此行为在 _init_from_checkpoint 方法中进行了描述,该方法在加载要恢复的变量时遍历赋值映射。

 for tensor_name_in_ckpt, current_var_or_name in sorted(
six.iteritems(assignment_map)):
var = None

它首先将要恢复的变量设置为 None,如果满足几个条件之一,它将重置为当前变量名。在这种特殊情况下,循环包含语句

if "/"in current_var_or_name

因此,它将从先前创建的字典 store_vars 中加载变量。它是在 _init_from_checkpoint 检查赋值映射中的当前变量是否为 tf.Variable 之后创建的,此时为 False。

 if _is_variable(current_var_or_name) or (
isinstance(current_var_or_name, list)
and all(_is_variable(v) for v in current_var_or_name)):
var = current_var_or_name
else:
store_vars = vs._get_default_variable_store()._vars

store_vars 是由内部类 _VariableStore 创建的,更准确地说,是由它的 _get_default_variable_store() 方法创建的。此类使用 get_variable 作为变量构造函数。因为 tf.Variable 没有默认作用域,所以 tf.get_variable 首先调用 tf.get_variable_scope(),返回当前变量作用域。 'foo' 不在此范围内。此外,tf.Variable 每次调用都会创建一个新变量,并且不允许共享。

store_vars 由默认作用域成员构成,因此它仅包含“bar”变量,foo 稍后使用 tf.Variable 添加到变量集合 操作。

但是,如果 assignment_map 将包含 {'foo':foo, 'bar':bar},则上面提到的 _init_from_checkpoint 将找到这些变量并加载它们。所以在这种情况下,您的代码将输出 foo: 1.0 bar: 1.0

您可以在 https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/checkpoint_utils.py 中找到代码

https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/variable_scope.py另请参阅此答案 What is the default variable_scope in Tensorflow?

关于python - tf.train.init_from_checkpoint 不初始化使用 tf.Variable 创建的变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54905301/

32 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com