gpt4 book ai didi

graph - tensorflow : how to insert custom input to existing graph?

转载 作者:行者123 更新时间:2023-12-04 19:00:25 46 4
gpt4 key购买 nike

我已经下载了一个实现 VGG16 ConvNet 的 tensorflow GraphDef,我用它来做这个:

Pl['images'] = tf.placeholder(tf.float32, 
[None, 448, 448, 3],
name="images") #batch x width x height x channels
with open("tensorflow-vgg16/vgg16.tfmodel", mode='rb') as f:
fileContent = f.read()

graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
tf.import_graph_def(graph_def, input_map={"images": Pl['images']})

此外,我的图像特征与 "import/pool5/" 的输出是同质的。 .

如何告诉我的图表不想使用他的输入 "images" , 但张量 "import/pool5/"作为输入?

谢谢 !

编辑

好的,我意识到我还不是很清楚。情况如下:

我正在尝试使用 this implementation ROI 池化,使用预先训练的 VGG16,我有 GraphDef 格式。所以这就是我所做的:

首先,我加载模型:
tf.reset_default_graph()
with open("tensorflow-vgg16/vgg16.tfmodel",
mode='rb') as f:
fileContent = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
graph = tf.get_default_graph()

然后,我创建我的占位符
images = tf.placeholder(tf.float32, 
[None, 448, 448, 3],
name="images") #batch x width x height x channels
boxes = tf.placeholder(tf.float32,
[None,5], # 5 = [batch_id,x1,y1,x2,y2]
name = "boxes")

我将图形第一部分的输出定义为 conv5_3/Relu
tf.import_graph_def(graph_def, 
input_map={'images':images})
out_tensor = graph.get_tensor_by_name("import/conv5_3/Relu:0")

所以, out_tensor形状为 [None,14,14,512]
然后,我进行 ROI 池化:
[out_pool,argmax] = module.roi_pool(out_tensor,
boxes,
7,7,1.0/1)

out_pool.shape = N_Boxes_in_batch x 7 x 7 x 512 , 与 pool5 齐次.然后我想喂 out_pool作为 pool5 之后的操作的输入,所以看起来像
tf.import_graph_def(graph.as_graph_def(),
input_map={'import/pool5':out_pool})

但它不起作用,我有这个错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-89-527398d7344b> in <module>()
5
6 tf.import_graph_def(graph.as_graph_def(),
----> 7 input_map={'import/pool5':out_pool})
8
9 final_out = graph.get_tensor_by_name("import/Relu_1:0")

/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/importer.py in import_graph_def(graph_def, input_map, return_elements, name, op_dict)
333 # NOTE(mrry): If the graph contains a cycle, the full shape information
334 # may not be available for this op's inputs.
--> 335 ops.set_shapes_for_outputs(op)
336
337 # Apply device functions for this op.

/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/ops.py in set_shapes_for_outputs(op)
1610 raise RuntimeError("No shape function registered for standard op: %s"
1611 % op.type)
-> 1612 shapes = shape_func(op)
1613 if len(op.outputs) != len(shapes):
1614 raise RuntimeError(

/home/hbenyounes/vqa/roi_pooling_op_grad.py in _roi_pool_shape(op)
13 channels = dims_data[3]
14 print(op.inputs[1].name, op.inputs[1].get_shape())
---> 15 dims_rois = op.inputs[1].get_shape().as_list()
16 num_rois = dims_rois[0]
17

/usr/local/lib/python3.4/dist-packages/tensorflow/python/framework/tensor_shape.py in as_list(self)
745 A list of integers or None for each dimension.
746 """
--> 747 return [dim.value for dim in self._dims]
748
749 def as_proto(self):

TypeError: 'NoneType' object is not iterable

有什么线索吗?

最佳答案

通常使用 tf.train.export_meta_graph 非常方便存储整个 MetaGraph。然后,在恢复时,您可以使用 tf.train.import_meta_graph , 因为事实证明,它将所有附加参数传递给底层 import_scoped_meta_graph其中有 input_map参数并在它自己调用 import_graph_def 时使用它.

它没有记录在案,我花了太多时间才找到它,但它确实有效!

关于graph - tensorflow : how to insert custom input to existing graph?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38618960/

46 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com