gpt4 book ai didi

python - output_graph.pb 上的 tf.GraphKeys.TRAINABLE_VARIABLES 导致空列表

转载 作者:太空狗 更新时间:2023-10-30 01:05:14 29 4
gpt4 key购买 nike

我正在尝试从保存的模型 output_graph.pb 中提取所有权重/偏差。

我读了模型:

def create_graph(modelFullPath):
"""Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')

GRAPH_DIR = r'C:\tmp\output_graph.pb'
create_graph(GRAPH_DIR)

并尝试这样做,希望我能够提取所有权重/偏差在每一层内。

with tf.Session() as sess:
all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print (len(all_vars))

但是,我得到的 len 值为 0。

最终目标是提取权重和偏差并将其保存到文本文件/np.arrays。

最佳答案

tf.import_graph_def() 函数没有足够的信息来重建 tf.GraphKeys.TRAINABLE_VARIABLES 集合(为此,您需要 MetaGraphDef ).但是,如果 output.pb 包含“卡住的”GraphDef,则所有权重都将存储在 tf.constant() 中。图中的节点。要提取它们,您可以执行以下操作:

create_graph(GRAPH_DIR)

constant_values = {}

with tf.Session() as sess:
constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
for constant_op in constant_ops:
constant_values[constant_op.name] = sess.run(constant_op.outputs[0])

请注意,constant_values 可能包含比权重更多的值,因此您可能需要通过 op.name 或其他一些标准进一步过滤。

关于python - output_graph.pb 上的 tf.GraphKeys.TRAINABLE_VARIABLES 导致空列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46696859/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com