gpt4 book ai didi

python - mobilenetv2 tflite 不是 python3 的预期输出大小

转载 作者:太空宇宙 更新时间:2023-11-03 21:13:44 25 4
gpt4 key购买 nike

我的 mobilenetV2 SSD 型号遇到问题。我使用详细步骤 here 对其进行了转换,除了我使用 CLI 工具 tflite_convert 来执行相关步骤。

这工作正常,我能够执行推理,但输出大小不是我预期的。

以下Python代码行

interpreter.get_output_details()

告诉我要取回 10 个检测框:

[{'shape': array([ 1, 10,  4], dtype=int32), 'index': 252, 'name': 'TFLite_Detection_PostProcess', 'quantization': (0.0, 0), 'dtype': <class 'numpy.float32'>}, {'shape': array([ 1, 10], dtype=int32), 'index': 253, 'name': 'TFLite_Detection_PostProcess:1', 'quantization': (0.0, 0), 'dtype': <class 'numpy.float32'>}, {'shape': array([ 1, 10], dtype=int32), 'index': 254, 'name': 'TFLite_Detection_PostProcess:2', 'quantization': (0.0, 0), 'dtype': <class 'numpy.float32'>}, {'shape': array([1], dtype=int32), 'index': 255, 'name': 'TFLite_Detection_PostProcess:3', 'quantization': (0.0, 0), 'dtype': <class 'numpy.float32'>}]

到目前为止一切顺利,但在我的 pipeline.config 文件中,我指定了以下 post_processing 设置

post_processing {
batch_non_max_suppression {
score_threshold: 9.99999993922529e-09
iou_threshold: 0.6000000238418579
max_detections_per_class: 100
max_total_detections: 100
}
score_converter: SIGMOID
}

因此,考虑到在经典 tensorflow 中运行相同的模型会给我 100 个框,因此我预计检测的输出数量为 100。

有没有办法改变输出张量的大小?是在转换时还是运行时?

我在经典 tensorflow 中添加了张量输出详细信息:

[<tf.Tensor 'prefix/detection_boxes:0' shape=<unknown> dtype=float32>, <tf.Tensor 'prefix/detection_scores:0' shape=<unknown> dtype=float32>, <tf.Tensor 'prefix/detection_classes:0' shape=<unknown> dtype=float32>, <tf.Tensor 'prefix/num_detections:0' shape=<unknown> dtype=float32>]

形状未知,这是有道理的,因为我们可以有 100 个或更少的盒子......

对此的任何说明都将非常感激。

如果已经有人问过类似的问题,但我显然没有找到,请原谅。谢谢。

最佳答案

重新读取 export_tflite_ssd_graph.py 脚本后,似乎有一个选项可以设置保留的最大检测数量。

将其设置为 100 解决了我的问题。我感觉很糟糕。

对于那些感兴趣的人,我将导出命令从

python3 object_detection/export_tflite_ssd_graph.py \                                            
--pipeline_config_path=$model_dir/pipeline.config \
--trained_checkpoint_prefix=$model_dir/model.ckpt \
--output_directory=$output_dir \
--add_post_processing_op=true

python3 object_detection/export_tflite_ssd_graph.py \                                            
--pipeline_config_path=$model_dir/pipeline.config \
--trained_checkpoint_prefix=$model_dir/model.ckpt \
--output_directory=$output_dir \
--add_post_processing_op=true \
--max_detections=100

关于python - mobilenetv2 tflite 不是 python3 的预期输出大小,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54864071/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com