gpt4 book ai didi

Tensorflow Metagraph 基础知识

转载 作者:行者123 更新时间:2023-12-03 00:39:39 24 4
gpt4 key购买 nike

我想训练我的 Tensorflow 模型,卡住快照,然后使用新的输入数据以前馈模式运行(无需进一步训练)。问题:

  1. tf.train.export_meta_graphtf.train.import_meta_graph 是实现此目的的正确工具吗?
  2. 我是否需要在 collection_list 中包含我想要包含在快照中的所有变量的名称? (对我来说最简单的就是包含所有内容。)
  3. Tensorflow 文档说:“如果未指定 collection_list,则将导出模型中的所有集合。”这是否意味着,如果我在 collection_list 中未指定任何变量,那么模型中的所有变量都会被导出,因为它们位于默认集合中?
  4. Tensorflow 文档说:“为了将 Python 对象序列化到 MetaGraphDef 或从 MetaGraphDef 序列化,Python 类必须实现 to_proto() 和 from_proto() 方法,并使用 register_proto_function 将它们注册到系统。"这是否意味着 to_proto()from_proto() 必须仅添加到我已定义并想要导出的类中?如果我只使用标准 Python 数据类型(int、float、list、dict),那么这无关紧要吗?

提前致谢。

最佳答案

有点晚了,但我还是会尽力回答。

  1. Are tf.train.export_meta_graph and tf.train.import_meta_graph the right tools for this?

我会这么说。请注意tf.train.export_meta_graph当您通过tf.train.Saver保存模型时,会隐式调用您。要点是:

# create the model
...
saver = tf.train.Saver()
with tf.Session() as sess:
...
# save graph and variables
# if you are using global_step, the saver will automatically keep the n=5 latest checkpoints
saver.save(sess, save_path, global_step)

然后恢复:

save_path = ...
latest_checkpoint = tf.train.latest_checkpoint(save_path)
saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)

请注意,不要调用 tf.train.import_meta_graph您还可以首先调用用于创建模型的原始代码段。不过,我认为使用 import_meta_graph 更优雅。这样,即使您无权访问创建模型的代码,您也可以恢复模型。

<小时/>
  1. Do I need to include, in collection_list, the names of all variables that I want included in the snapshot? (Simplest for me would be to include everything.)

没有。然而问题有点令人困惑:collection_listexport_meta_graph并不意味着是变量列表,而是集合(即字符串键列表)。

集合非常方便,例如所有可训练变量都会自动包含在集合 tf.GraphKeys.TRAINABLE_VARIABLES 中您可以通过调用获取:

tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

tf.trainable_variables()  # defaults to the default graph

如果恢复后您需要访问可训练变量之外的其他中间结果,我发现将它们放入自定义集合中非常方便,如下所示:

...
input_ = tf.placeholder(tf.float32, shape=[64, 64])
....
tf.add_to_collection('my_custom_collection', input_)

此集合会自动存储(除非您通过在 collection_listexport_meta_graph 参数中省略此集合的名称来明确指定不存储)。所以你可以简单地检索 input_恢复后的占位符如下:

...
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
input_ = tf.get_collection_ref('my_custom_collection')[0]
<小时/>
  1. The Tensorflow docs say: "If no collection_list is specified, all collections in the model will be exported." Does that mean that if I specify no variables in collection_list then all variables in the model are exported because they are in the default collection?

是的。再次注意 collection_list 的微妙细节是集合而不是变量的列表。事实上,如果您只想保存某些变量,您可以在构造 tf.train.Saver 时指定这些变量。目的。来自 tf.train.Saver.__init__ 的文档:

 """Creates a `Saver`.

The constructor adds ops to save and restore variables.

`var_list` specifies the variables that will be saved and restored. It can
be passed as a `dict` or a list:

* A `dict` of names to variables: The keys are the names that will be
used to save or restore the variables in the checkpoint files.
* A list of variables: The variables will be keyed with their op name in
the checkpoint files.
<小时/>
  1. The Tensorflow docs say: "In order for a Python object to be serialized to and from MetaGraphDef, the Python class must implement to_proto() and from_proto() methods, and register them with the system using register_proto_function." Does that mean that to_proto() and from_proto() must be added only to classes that I have defined and want exported? If I am using only standard Python data types (int, float, list, dict) then is this irrelevant?

我从未使用过此功能,但我想说你的解释是正确的。

关于Tensorflow Metagraph 基础知识,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39801913/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com