作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我尝试将 TensorflowLite 与来自 https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md 的 ssdlite_mobilenet_v2_coco 模型一起使用转换为 tflite 文件以在我的 android 应用程序 (java) 中检测相机流中的对象。我执行
interpreter.run(input, output);
其中输入是转换为 ByteBuffer 的图像,输出是 float 组 - 大小为 [1][10][4] 以匹配张量。
如何将这个 float 组转换成一些可读的输出? - 例如获取边界框坐标、对象名称、概率。
最佳答案
好吧,我明白了。首先,我在 python 中运行以下命令:
>>> import tensorflow as tf
>>> interpreter = tf.contrib.lite.Interpreter("detect.tflite")
然后加载 Tflite 模型:
>>> interpreter.allocate_tensors()
>>> input_details = interpreter.get_input_details()
>>> output_details = interpreter.get_output_details()
现在我已经详细了解了输入和输出应该是什么样子
>>> input_details
[{'name': 'normalized_input_image_tensor', 'index': 308, 'shape': array([ 1, 300, 300, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
所以输入是转换后的图像 - 形状 300 x 300
>>> output_details
[{'name': 'TFLite_Detection_PostProcess', 'index': 300, 'shape': array([ 1, 10, 4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 301, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 302, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 303, 'shape': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
现在我得到了这个模型中多个输出的规范。我需要改变
interpreter.run(input, output)
到
interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
其中“输入”是:
private Object[1] inputs;
inputs[0] = imgData; //imgData - image converted to bytebuffer
map_of_indices_to_outputs 是:
private Map<Integer, Object> output_map = new TreeMap<>();
private float[1][10][4] boxes;
private float[1][10] scores;
private float[1][10] classes;
output_map.put(0, boxes);
output_map.put(1, classes);
output_map.put(2, scores);
现在运行后我得到了盒子中 10 个对象的坐标,类中对象的索引(在 coco 标签文件中)你必须加 1 才能获得正确的键!和分数中的概率。
希望这对以后的人有帮助。
关于tensorflow - 如何在java中读取tensorflow模型的输出,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54716143/
我是一名优秀的程序员,十分优秀!