gpt4 book ai didi

python - 将 .pb 文件转换为 .ckpt (tensorflow)

转载 作者:太空狗 更新时间:2023-10-29 20:31:49 25 4
gpt4 key购买 nike

我已经成功地使用这个脚本将预训练的 .ckpt 模型转换为 .pb (protobuf) 格式:

import os
import tensorflow as tf

# Get the current directory
dir_path = os.path.dirname(os.path.realpath(__file__))
print "Current directory : ", dir_path
save_dir = dir_path + '/Protobufs'

graph = tf.get_default_graph()

# Create a session for running Ops on the Graph.
sess = tf.Session()

print("Restoring the model to the default graph ...")
saver = tf.train.import_meta_graph(dir_path + '/model.ckpt.meta')
saver.restore(sess,tf.train.latest_checkpoint(dir_path))
print("Restoring Done .. ")

print "Saving the model to Protobuf format: ", save_dir

#Save the model to protobuf (pb and pbtxt) file.
tf.train.write_graph(sess.graph_def, save_dir, "Binary_Protobuf.pb", False)
tf.train.write_graph(sess.graph_def, save_dir, "Text_Protobuf.pbtxt", True)
print("Saving Done .. ")

现在,我想要的是 vice-verca 程序。如何加载 protobuf 文件并将其转换为 .ckpt(检查点)格式?

我正在尝试使用以下脚本来做到这一点,但它总是失败:

import tensorflow as tf
import argparse

# Pass the filename as an argument
parser = argparse.ArgumentParser()
parser.add_argument("--frozen_model_filename", default="/path-to-pb-file/Binary_Protobuf.pb", type=str, help="Pb model file to import")
args = parser.parse_args()

# We load the protobuf file from the disk and parse it to retrieve the
# unserialized graph_def
with tf.gfile.GFile(args.frozen_model_filename, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

#saver=tf.train.Saver()
with tf.Graph().as_default() as graph:

tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
name="prefix",
op_dict=None,
producer_op_list=None
)
sess = tf.Session(graph=graph)
saver=tf.train.Saver()
save_path = saver.save(sess, "path-to-ckpt/model.ckpt")
print("Model saved to chkp format")

我相信拥有这些转换脚本会非常有帮助。

P.S:权重已经嵌入到 .pb 文件中。

谢谢。

最佳答案

看来您在两个文件中都只得到了图形定义,而不是卡住模型。

# This two lines only save the graph as proto file; it doesn't save the variables and their values. 
tf.train.write_graph(sess.graph_def, save_dir, "Binary_Protobuf.pb", False)
tf.train.write_graph(sess.graph_def, save_dir, "Text_Protobuf.pbtxt", True)

卡住图是使用freeze_graph file获得的

关于python - 将 .pb 文件转换为 .ckpt (tensorflow),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45491654/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com