gpt4 book ai didi

tensorflow - 如何在java中读取tensorflow模型的输出

转载 作者:行者123 更新时间:2023-12-05 01:41:17 24 4
gpt4 key购买 nike

我尝试将 TensorflowLite 与来自 https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md 的 ssdlite_mobilenet_v2_coco 模型一起使用转换为 tflite 文件以在我的 android 应用程序 (java) 中检测相机流中的对象。我执行

    interpreter.run(input, output);

其中输入是转换为 ByteBuffer 的图像,输出是 float 组 - 大小为 [1][10][4] 以匹配张量。

如何将这个 float 组转换成一些可读的输出? - 例如获取边界框坐标、对象名称、概率。

最佳答案

好吧,我明白了。首先,我在 python 中运行以下命令:

>>> import tensorflow as tf
>>> interpreter = tf.contrib.lite.Interpreter("detect.tflite")

然后加载 Tflite 模型:

>>> interpreter.allocate_tensors()
>>> input_details = interpreter.get_input_details()
>>> output_details = interpreter.get_output_details()

现在我已经详细了解了输入和输出应该是什么样子

>>> input_details
[{'name': 'normalized_input_image_tensor', 'index': 308, 'shape': array([ 1, 300, 300, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

所以输入是转换后的图像 - 形状 300 x 300

>>> output_details
[{'name': 'TFLite_Detection_PostProcess', 'index': 300, 'shape': array([ 1, 10, 4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 301, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 302, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 303, 'shape': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

现在我得到了这个模型中多个输出的规范。我需要改变

interpreter.run(input, output) 

interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);

其中“输入”是:

private Object[1] inputs;
inputs[0] = imgData; //imgData - image converted to bytebuffer

map_of_indices_to_outputs 是:

private Map<Integer, Object> output_map = new TreeMap<>();
private float[1][10][4] boxes;
private float[1][10] scores;
private float[1][10] classes;
output_map.put(0, boxes);
output_map.put(1, classes);
output_map.put(2, scores);

现在运行后我得到了盒子中 10 个对象的坐标,类中对象的索引(在 coco 标签文件中)你必须加 1 才能获得正确的键!和分数中的概率。

希望这对以后的人有帮助。

关于tensorflow - 如何在java中读取tensorflow模型的输出,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54716143/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com