gpt4 book ai didi

python - 如何获取 tensorflow 图中每个节点的输入形状?

转载 作者:太空宇宙 更新时间:2023-11-03 14:00:46 24 4
gpt4 key购买 nike

您好:现在我正在努力将 tensorflow 检查点模型转换为 caffe 模型。我已成功读取图表并提取了每个节点中的 attr 值。我在“Conv2D”节点中获得了“扩张”、“步幅”和“填充”属性的值以及“权重”节点中的形状,但我无法获得“形状”属性的值,它在 Conv2D 的输入中为空节点。但是,这些形状显示在张量板的图表中。这是我的代码:

new_saver = tf.train.import_meta_graph(meta_path)          
new_saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
graph_def = sess.graph_def
node_list = graph_def.node

# conv_node, weight_node, from_node are all in node_list
# conv_node: the conv2d node in graph_def
# weight_node: the weights node of conv2d
# from_node: the input feature map node of conv2d

weight_shape_attr = weight_node.attr['shape']
weight_shapes = [dim.size for dim in weight_shape_attr.shape.dim]

strides = [ii for ii in conv_node.attr['strides'].list.i]
dilations = [ii for ii in conv_node.attr['dilations'].list.i]

shapes = from_node.attr['shape'] # this is empty

和张量板图: tensorboard_graph

请注意,Conv2D 节点的输入形状为 ?x79x79x32,它一定已存储在模型文件中的某个位置。有哪位大神帮忙点一下吗?点击一下会有帮助,谢谢。

最佳答案

Tensorflow 图具有 as_graph_def 方法,该方法具有可选参数 add_shapes(默认情况下为 False)。如果设置为 True,它会导致节点的附加属性:_output_shapes

因此您可以尝试以这种方式获取 GraphDef:

graph_def = sess.graph.as_graph_def(add_shapes=True)

关于python - 如何获取 tensorflow 图中每个节点的输入形状?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49419960/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com