gpt4 book ai didi

python - 如何计算从 pb 文件加载的 tensorflow 模型的触发器

转载 作者:太空宇宙 更新时间:2023-11-03 11:39:36 31 4
gpt4 key购买 nike

我有一个模型保存在 pb 文件中。我希望计算它的失败。我的示例代码如下:

import tensorflow as tf
import sys
from tensorflow.python.platform import gfile

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

pb_file = 'themodel.pb'

run_meta = tf.RunMetadata()
with tf.Session() as sess:
print("load graph")
with gfile.FastGFile(pb_path,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
options=tf.profiler.ProfileOptionBuilder.float_operation())
print("test flops:{:,}".format(flops.total_float_ops))

打印信息奇怪。我的模型有几十层,但它在打印信息中只报告了 18 个触发器。我很确定模型已正确加载,因为如果我尝试按如下方式打印每一层的名称:

print([n.name for n in tf.get_default_graph().as_graph_def().node])

打印信息显示正确的网络。

我的代码有什么问题?

谢谢!

最佳答案

我想我找到了我的问题的原因和解决方案。以下代码可以打印给定 pb 文件的触发器。

import os
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import importer

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

pb_path = 'mymodel.pb'

run_meta = tf.RunMetadata()
with tf.Graph().as_default():
output_graph_def = graph_pb2.GraphDef()
with open(pb_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = importer.import_graph_def(output_graph_def, name="")
print('model loaded!')
all_keys = sorted([n.name for n in tf.get_default_graph().as_graph_def().node])
# for k in all_keys:
# print(k)

with tf.Session() as sess:
flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
options=tf.profiler.ProfileOptionBuilder.float_operation())
print("test flops:{:,}".format(flops.total_float_ops))

之所以题中打印的flops只有18个,是因为在生成pb文件时,我将输入图像的shape设置为[None, None, 3]。如果我将其更改为 [500, 500, 3],则打印的触发器将是正确的。

关于python - 如何计算从 pb 文件加载的 tensorflow 模型的触发器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52258030/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com