gpt4 book ai didi

tensorflow - 将保存的 tensorflow 模型转换为 tensorflow Lite 的正确方法是什么

转载 作者:行者123 更新时间:2023-12-03 16:59:09 28 4
gpt4 key购买 nike

我有一个与 model zoo 中的所有模型相同的已保存 tensorflow 模型.
我想将其转换为 tesorflow lite,我从 tensorflow github 中找到了以下方法(我的 tensorflw 版本是 2):

!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz 
# extract the downloaded file
!tar -xzvf ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz

!pip install tf-nightly
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model('ssd_mobilenet_v2_320x320_coco17_tpu-8/saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.experimental_new_converter = True

converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()

open("m.tflite", "wb").write(tflite_model)
但是转换模型的输出和输入形状与原始模型不匹配,请检查以下内容:
  • 原始模型输入和输出形状

  • enter image description here
  • 转换后的模型输入和输出形状

  • enter image description here
    所以这里有问题!输入/输出形状应该与原始模型相匹配!
    任何的想法?

    最佳答案

    从 Tensorflow github 问题,我使用他们的答案来解决我的问题。
    Link
    他们的做法:

    !pip install tf-nightly
    import tensorflow as tf

    ## TFLite Conversion
    model = tf.saved_model.load("saved_model")
    concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    concrete_func.inputs[0].set_shape([1, 300, 300, 3])
    tf.saved_model.save(model, "saved_model_updated", signatures={"serving_default":concrete_func})
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_updated', signature_keys=['serving_default'])

    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    tflite_model = converter.convert()

    ## TFLite Interpreter to check input shape
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Test the model on random input data.
    input_shape = input_details[0]['shape']
    print(input_shape)

    [ 1 300 300 3]


    谢谢 MeghnaNatraj

    关于tensorflow - 将保存的 tensorflow 模型转换为 tensorflow Lite 的正确方法是什么,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63325216/

    28 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com