gpt4 book ai didi

python - TensorFlow:如何将图像解码器节点添加到我的图中?

转载 作者:太空宇宙 更新时间:2023-11-03 21:06:22 27 4
gpt4 key购买 nike

我有一个 tensorflow 模型作为卡住图,它接受图像张量作为输入。但是,我想向该图中添加一个新的输入图像解码器节点,以便该模型也接受 jpg 图像的编码字节字符串,并最终自行解码图像。到目前为止我已经尝试过这种方法:

model = './frozen_graph.pb'

with tf.gfile.FastGFile(model, 'rb') as f:

# read graph
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name="")
g = tf.get_default_graph()

# fetch old input
old_input = g.get_tensor_by_name('image_tensor:0')

# define new input
new_input = graph_def.node.add()
new_input.name = 'encoded_image_string_tensor'
new_input.op = 'Substr'
# add new input attr
image = tf.image.decode_image(new_input, channels=3)

# link new input to old input
old_input.input = 'encoded_image_string_tensor' # must match with the name above

上面的代码返回此异常:

Expected string passed to parameter 'input' of op 'Substr', got name: "encoded_image_string_tensor" op: "Substr"  of type 'NodeDef' instead.

我不太确定是否可以在图表中使用tf.image.decode_image,所以也许还有另一种方法来解决这个问题。有人得到提示吗?

最佳答案

感谢jdehesa那位给了我很好的提示,我能够解决这个问题。使用 input_map 参数,我成功地将一个新图表(仅解码 jpg 图像)映射到原始图表的输入(此处:node.name='image_tensor:0' )。只需确保重命名解码器图的 name_scope(此处:decoder)。之后,您可以使用tensorflow SavedModelBuilder保存新的串联图。

这是一个对象检测网络的示例:

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants


# The export path contains the name and the version of the model
model = 'path/to/model.pb'
export_path = './output/dir/'

sigs = {}

with tf.gfile.FastGFile(model, 'rb') as f:
with tf.name_scope('decoder'):
image_str_tensor = tf.placeholder(tf.string, shape=[None], name= 'encoded_image_string_tensor')
# The CloudML Prediction API always "feeds" the Tensorflow graph with
# dynamic batch sizes e.g. (?,). decode_jpeg only processes scalar
# strings because it cannot guarantee a batch of images would have
# the same output size. We use tf.map_fn to give decode_jpeg a scalar
# string from dynamic batches.
def decode_and_resize(image_str_tensor):
"""Decodes jpeg string, resizes it and returns a uint8 tensor."""
image = tf.image.decode_jpeg(image_str_tensor, channels=3)

# do additional image manipulation here (like resize etc...)

image = tf.cast(image, dtype=tf.uint8)
return image

image = tf.map_fn(decode_and_resize, image_str_tensor, back_prop=False, dtype=tf.uint8)

with tf.name_scope('net'):
# load .pb file
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

# concatenate decoder graph and original graph
tf.import_graph_def(graph_def, name="", input_map={'image_tensor:0':image})
g = tf.get_default_graph()

with tf.Session() as sess:
# load graph into session and save to new .pb file

# define model input
inp = g.get_tensor_by_name('decoder/encoded_image_string_tensor:0')

# define model outputs
num_detections = g.get_tensor_by_name('num_detections:0')
detection_scores = g.get_tensor_by_name('detection_scores:0')
detection_boxes = g.get_tensor_by_name('detection_boxes:0')
out = {'num_detections': num_detections, 'detection_scores': detection_scores, 'detection_boxes': detection_boxes}


builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_inputs = {
'inputs': tf.saved_model.utils.build_tensor_info(inp)}
tensor_info_outputs = {}
for k, v in out.items():
tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(v)

# assign detection signature for tensorflow serving
detection_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs=tensor_info_inputs,
outputs=tensor_info_outputs,
method_name=signature_constants.PREDICT_METHOD_NAME))

# "build" graph
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
'detection_signature':
detection_signature,
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
detection_signature,
},
main_op=tf.tables_initializer()
)
# save graph
builder.save()

我用来寻找正确解决方案的来源:

Coding example

Scope names explanation

Tensorflow Github Issue #22162

另外:如果您很难找到正确的输入和输出节点,您可以运行它来显示图形:

graph_op = g.get_operations()
for i in graph_op:
print(i.node_def)

关于python - TensorFlow:如何将图像解码器节点添加到我的图中?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55373048/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com