gpt4 book ai didi

java - tensorflow.keras 在 python 中保存模型并在 Java 中加载

转载 作者:行者123 更新时间:2023-11-30 05:52:02 24 4
gpt4 key购买 nike

我有一个经过微调的 vgg 模型,我使用 tensorflow.keras 功能 API 创建了该模型,并使用 tf.contrib.saved_model.save_keras_model 保存了模型。因此模型以以下结构保存:assets 文件夹,其中包含 saved_model.json 文件、saved_model.pb 文件,以及变量文件夹,其中包含 checkpoint strong>、variables.data-00000-of-00001variables.index

我可以轻松地在 python 中加载模型并使用 tf.contrib.saved_model.load_keras_model(saved_model_path) 获得预测,但我不知道如何在 JAVA 中加载模型。我用谷歌搜索了很多,发现了这个How to export Keras .h5 to tensorflow .pb?导出为 pb 文件,然后按照此链接 Loading in Java 加载它。我无法卡住图表,而且我尝试使用 simple_save,但 tensorflow.keras 不支持 simple_save (AttributeError:模块“tensorflow.contrib.saved_model”没有属性“simple_save”)。那么有人可以帮助我弄清楚在 JAVA 中加载我的模型(tensorflow.keras 功能 API 模型)需要哪些步骤。

我拥有的saved_model.pb 文件是否足够好,可以在JAVA 端加载?我需要创建输入/输出占位符吗?那我该如何导出呢?
感谢您的帮助。

最佳答案

如果您有一个以 SavedModel 格式保存的模型(看起来您确实这样做了,并且 tf.contrib.saved_model.save_keras_model 之类的东西可以帮助创建),那么在 Java 中您可以使用 SavedModelBundle.load加载并提供服务。您不需要“卡住”模型。

您可能会发现以下有用:

但基本思想是您的代码将类似于:

try (SavedModelBundle model = SavedModelBundle.load("<directory>", "serve")) {
try (Tensor<?> input = makeInputTensor();
Tensor<?> output = model.session().runner().feed("INPUT_TENSOR", input).fetch("OUTPUT_TENSOR", output).run().get(0)) {
// Use output
}
}

其中 "INPUT_TENSOR""OUTPUT_TENSOR" 是 TensorFlow 图中输入和输出节点的名称。安装 TensorFlow for Python 时安装的 saved_model_cli 命令行工具可以显示模型中这些张量的名称。

请注意,按照另一位评论者的建议,使用 TensorFlow Java API 可能比使用 TensorFlow Lite 更适合服务器/桌面应用程序。这是因为 TensorFLow Lite 运行时虽然针对小型设备进行了优化(在内存占用等方面),但尚无法导出所有模型。虽然 TensorFlow Java API 使用完全相同的运行时,因此具有与 TensorFlow for Python 完全相同的功能。

希望有帮助。

关于java - tensorflow.keras 在 python 中保存模型并在 Java 中加载,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53745373/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com