gpt4 book ai didi

python - `tf.train.Saver`收集变量保存在哪里?

转载 作者:行者123 更新时间:2023-12-01 02:01:53 25 4
gpt4 key购买 nike

我发现如果要保存某个图的所有变量,必须在图的最后定义tf.train.Saver,否则saver无法获取所有变量变量。

这是我的测试代码:

def how_saver_work():
g = tf.Graph()

with g.as_default():

a = tf.Variable(1, name='a')
b = tf.Variable(2, name='b')

print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
saver = tf.train.Saver()

c = tf.Variable(3, name='c')
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

with tf.Session() as sess:
tf.global_variables_initializer().run()

print(a.eval())
print(b.eval())
print(c.eval())

save_path = saver.save(sess, "./tmp/model.ckpt")

with tf.Session(graph=g) as sess:
print(g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
saver.restore(sess, save_path)

print(a.eval())
print(b.eval())
print(c.eval()) # err: uninitialized

我首先认为保存程序可能会从tf.GraphKeys.GLOBAL_VARIABLEStf.GraphKeys.SAVEABLE_OBJECTS获取变量。但这似乎是错误的。

我还想知道如何向 Saver 添加新变量。例如 var c.

最佳答案

你是对的,saver确实从tf.GraphKeys.GLOBAL_VARIABLEStf.GraphKeys.SAVEABLE_OBJECTS的联合中获取变量在其构建时(参见 Saver._build 实现或下面的引用):

if self._var_list is None:
# pylint: disable=protected-access
self._var_list = variables._all_saveable_objects()

其中_all_saveable_objectspython/ops/variables.py文件中定义

def _all_saveable_objects(scope=None):
"""Returns all variables and `SaveableObject`s that must be checkpointed.

Args:
scope: (Optional.) A string. If supplied, the resulting list is filtered
to include only items whose `name` attribute matches `scope` using
`re.match`. Items without a `name` attribute are never returned if a
scope is supplied. The choice of `re.match` means that a `scope` without
special tokens filters by prefix.

Returns:
A list of `Variable` and `SaveableObject` to be checkpointed
"""
# TODO(andreasst): make this function public once things are settled.
return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +
ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))

除了创建新的保护程序之外,无法向保护程序添加变量(参见 https://github.com/tensorflow/tensorflow/issues/2489#issuecomment-221282483 ):

When you create a tf.train.Saver with no arguments, it will implicitly use the current set of variables at the time of Saver construction when it saves and restores. If you add a new variable [...], you have to create a new tf.train.Saver to save it.

关于python - `tf.train.Saver`收集变量保存在哪里?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49528515/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com