gpt4 book ai didi

tensorflow - 使用经过训练的对象检测 API 模型和 TF 2 进行批量预测

转载 作者:行者123 更新时间:2023-12-04 11:27:54 24 4
gpt4 key购买 nike

我在 TPU 上使用 TF 2 的对象检测 API 成功训练了一个模型,该模型保存为 .pb(SavedModel 格式)。然后我使用 tf.saved_model.load 重新加载它并且在使用转换为形状为 (1, w, h, 3) 的张量的单个图像预测框时效果很好.

import tensorflow as tf
import numpy as np

# Load Object Detection APIs model
detect_fn = tf.saved_model.load('/path/to/saved_model/')

image = tf.io.read_file(image_path)
image_np = tf.image.decode_jpeg(image, channels=3).numpy()
input_tensor = np.expand_dims(image_np, 0)
detections = detect_fn(input_tensor) # This works fine
问题是我需要进行批量预测以将其缩放到 50 万张图像,但该模型的输入签名似乎仅限于处理形状为 (1, w, h, 3) 的数据。 .
这也意味着我不能在 Tensorflow Serving 中使用批处理。
我怎么解决这个问题?我可以只更改模型签名来处理批量数据吗?
所有工作(加载模型 + 预测)都在随对象检测 API 一起发布的官方容器内执行(来自 here)

最佳答案

我最近遇到了这个问题。当您使用 exporter_main_v2.py将检查点文件转换为 .pb文件,它将调用 exporter_lib_v2.py .我认为在文件 exporter_lib_v2.py 中( here ),TF2 硬固定了形状为 [1, None, None, 3] 的输入签名.我们必须把它改成[None, None, None, 3]需要从 1 修改该文件中的那些行( 138162170185 )至 None .然后重建 TF2 Object Detector API Repo ( link ) 并使用新构建的版本导出 .pb再次。

关于tensorflow - 使用经过训练的对象检测 API 模型和 TF 2 进行批量预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63702841/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com