gpt4 book ai didi

python - TF 对象检测 Zoo 模型没有可训练变量?

转载 作者:太空宇宙 更新时间:2023-11-04 01:57:13 27 4
gpt4 key购买 nike

TF Objection Detection Zoo 中的模型有 meta+ckpt 文件、Frozen.pb 文件和 Saved_model 文件。

我尝试使用 meta+ckpt 文件进一步训练,并为特定张量提取一些权重以用于研究目的。我看到模型没有任何可训练的变量。

vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)

上面的代码片段给出了一个 [] 列表。我也尝试使用以下内容。

vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(vars)

我再次得到一个 [] 列表。

这怎么可能?模型是否剥离了变量?还是 tf.Variable(trainable=False) ?我在哪里可以获得具有有效可训练变量的 meta+ckpt 文件。我专门看SSD+mobilnet机型

更新:

以下是我用于恢复的代码片段。它在一个类中,因为我正在为某些应用程序制作自定义工具。

def _importer(self):
sess = tf.InteractiveSession()
with sess.as_default():
reader = tf.train.import_meta_graph(self.metafile,
clear_devices=True)
reader.restore(sess, self.ckptfile)

def _read_graph(self):
sess = tf.get_default_session()
with sess.as_default():
vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(vars)

更新 2:

我还尝试了以下代码段。简约复古风格。

model_dir = 'ssd_mobilenet_v2/'

meta = glob.glob(model_dir+"*.meta")[0]
ckpt = meta.replace('.meta','').strip()

sess = tf.InteractiveSession()
graph = tf.Graph()
with graph.as_default():
with tf.Session() as sess:
reader = tf.train.import_meta_graph(meta,clear_devices=True)
reader.restore(sess,ckpt)

vari = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for var in vari:
print(var.name,"\n")

上面的代码片段还给出了[]变量列表

最佳答案

经过一些研究,您问题的最终答案是不,他们没有。这很明显,直到您意识到 saved_model 中的 variables 目录是空的。

对象检测模型zoo提供的checkpoint文件包含以下文件:

.
|-- checkpoint
|-- frozen_inference_graph.pb
|-- model.ckpt.data-00000-of-00001
|-- model.ckpt.index
|-- model.ckpt.meta
|-- pipeline.config
`-- saved_model
|-- saved_model.pb
`-- variables

pipeline.config 是保存模型的配置文件,frozen_inference_graph.pb 是现成的推理。请注意 checkpointmodel.ckpt.data-00000-of-00001model.ckpt.metamodel.ckpt。 index 都对应checkpoint。 (Here 你可以找到一个很好的解释)

所以当你想得到可训练的变量时,唯一有用的就是saved_model目录。

Use SavedModel to save and load your model—variables, the graph, and the graph's metadata. This is a language-neutral, recoverable, hermetic serialization format that enables higher-level systems and tools to produce, consume, and transform TensorFlow models.

要恢复 SavedModel,您可以使用 API tf.saved_model.loader.load() , 这个 api 包含一个名为 tags 的参数,它指定了 MetaGraphDef 的类型。所以如果你想得到可训练的变量,你需要在调用api时指定tag_constants.TRAINING

我试图调用此 api 来恢复变量,但它给了我一个错误

MetaGraphDef associated with tags 'train' could not be found in SavedModel. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: saved_model_cli

所以我执行了这个 saved_model_cli 命令来检查 SavedModel 中可用的所有标签。

#from directory saved_model
saved_model_cli show --dir . --all

输出是

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
...
signature_def['serving_default']:
...

所以在这个SavedModel中没有标签train,只有serve。因此,此处的 SavedModel 仅用于 tensorflow 服务。这意味着当这些文件在创建时未使用标记 training 指定时,无法从这些文件中恢复训练变量。

P.S.:以下代码是我用来恢复 SavedModel 的代码。设置tag_constants.TRAINING时加载无法完成,设置tag_constants.SERVING时加载成功但变量为空。

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(variables)

P.P.S:我找到了创建 SavedModel 的脚本 here .可见创建SavedModel时确实没有train标签。

关于python - TF 对象检测 Zoo 模型没有可训练变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56547313/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com