gpt4 book ai didi

python - Tensorflow:加载一个 .pb 文件,然后将其保存为卡住图问题

转载 作者:太空宇宙 更新时间:2023-11-04 09:40:43 25 4
gpt4 key购买 nike

这个问题与这个问题非常相似:How do you use freeze_graph.py in Tensorflow?但是那个问题没有得到回答,我对这个问题有不同的方法。因此,我想要一些意见。

我也在尝试加载一个 .pb 二进制文件,然后将其卡住。这是我试过的代码。

如果这对您有任何想法,请告诉我。这不会返回错误。它只是让我的 jupyter notebook 崩溃了。

import tensorflow as tf
import sys
from tensorflow.python.platform import gfile

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

with tf.Session() as sess:
model_filename ='saved_model.pb' # binary .pb file
with gfile.FastGFile(model_filename, 'rb') as f:

data = compat.as_bytes(f.read()) # reads binary
sm = saved_model_pb2.SavedModel()
print(sm)
sm.ParseFromString(data) # parses through the file
print(sm)
if 1 != len(sm.meta_graphs):
print('More than one graph found. Not sure which to write')
sys.exit(1)

g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
output_graph = "frozen_graph.pb"

# Getting all output nodes for the frozen graph
output_nodes = [n.name for n in tf.get_default_graph().as_graph_def().node]
# This not working fully
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, # The session is used to retrieve the weights
tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes
output_nodes# The output node names are used to select the usefull nodes
)

# Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
print(g_in)
LOGDIR='.'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)

这段代码应该会生成一个卡住文件,但是我对tensorflow的保存机制不是很了解。如果我从这段代码中取出卡住图形部分,我会得到 events.out。 tensorboard 可以读取的文件。

最佳答案

所以在经历了很多挫折之后,我意识到我只是在加载元图。不是带有变量的整个图。这是这样做的代码:

def frozen_graph_maker(export_dir,output_graph):
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
output_nodes = [n.name for n in tf.get_default_graph().as_graph_def().node]
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, # The session is used to retrieve the weights
sess.graph_def,
output_nodes# The output node names are used to select the usefull nodes
)
# Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
def main():
export_dir='/dir/of/pb/and/variables'
output_graph = "frozen_graph.pb"
frozen_graph_maker(export_dir,output_graph)

我意识到我只是在加载元图。如果有人能证实我对失败原因的理解,我会很高兴。使用 compat.as_bytes 我只是将它加载为元图。有没有办法在完成这种加载后整合变量,还是我应该坚持使用 tf.saved_model.loader.load() ?我的加载尝试是完全错误的,因为它甚至没有调用变量文件夹。

另一个问题:使用 [n.name for n in tf.get_default_graph().as_graph_def().node] 我将所有节点放入 output_nodes,我应该只放入最后一个节点吗?它仅适用于最后一个节点。有什么不同?

关于python - Tensorflow:加载一个 .pb 文件,然后将其保存为卡住图问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51826706/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com