gpt4 book ai didi

python - Tensorflow:从图中删除节点

转载 作者:行者123 更新时间:2023-12-05 00:47:35 25 4
gpt4 key购买 nike

我正在尝试从图中删除一些节点并将其保存在 .pb 中

只有需要的节点可以添加到新的 mod_graph_def图,但图在其他节点输入中仍有一些对已删除节点的引用,但我无法修改节点的输入:

def delete_ops_from_graph():
with open(input_model_filepath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

nodes = []
for node in graph_def.node:
if 'Neg' in node.name:
print('Drop', node.name)
else:
nodes.append(node)

mod_graph_def = tf.GraphDef()
mod_graph_def.node.extend(nodes)

# The problem that graph still have some references to deleted node in other nodes inputs
for node in mod_graph_def.node:
inp_names = []
for inp in node.input:
if 'Neg' in inp:
pass
else:
inp_names.append(inp)

node.input = inp_names # TypeError: Can't set composite field

with open(output_model_filepath, 'wb') as f:
f.write(mod_graph_def.SerializeToString())

最佳答案

def delete_ops_from_graph():
with open(input_model_filepath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

# Delete nodes
nodes = []
for node in graph_def.node:
if 'Neg' in node.name:
print('Drop', node.name)
else:
nodes.append(node)

mod_graph_def = tf.GraphDef()
mod_graph_def.node.extend(nodes)

# Delete references to deleted nodes
for node in mod_graph_def.node:
inp_names = []
for inp in node.input:
if 'Neg' in inp:
pass
else:
inp_names.append(inp)

del node.input[:]
node.input.extend(inp_names)

with open(output_model_filepath, 'wb') as f:
f.write(mod_graph_def.SerializeToString())

关于python - Tensorflow:从图中删除节点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56324534/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com