gpt4 book ai didi

java - Python 中的 Tensorflow Java Api `toGraphDef` 等价物是什么?

转载 作者:搜寻专家 更新时间:2023-11-01 03:32:57 26 4
gpt4 key购买 nike

我正在使用 Tensorflow Java Api 将已创建的 Tensorflow 模型加载到 JVM 中。我以此为例:tensorflow/examples/LabelImage.java

这是我的简单 scala 代码:

import java.nio.file.{Files, Path, Paths}
import org.tensorflow.{Graph, Session, Tensor}

def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path)
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"))
val g = new Graph()
g.importGraphDef(graphDef)
val session = new Session(g)
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0))

如何保存我的模型以将 session 和图形存储在同一个文件中。如上面的“PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb”中所述。

描述here它提到:

The serialized representation of the graph, often referred to as a GraphDef, can be generated by toGraphDef() and equivalents in other language APIs.

其他语言 API 中的等价物是什么?我觉得不明显

注意:我已经查看了 tensorflow_serving 下的 mnist_saved_model.py,但通过该过程保存它会得到一个 .pb 文件和一个 variables 文件夹。尝试加载该 .pb 文件时,我得到:java.lang.IllegalArgumentException: Invalid GraphDef

最佳答案

目前使用 tensorflow 的 Java API,我只找到了如何将图形保存为 graphDef(即没有其变量和元数据)。这可以通过将 Array[Byte] 写入文件来完成:

Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef)

这里 myGraph 是来自 Graph class 的 java 对象.

我建议使用 SavedModel 从 Python API 保存您的模型api在这里定义。它会将您的模型保存在一个文件夹中,其中包含 .pb 文件中的序列化图形和文件夹中的变量。请注意您使用的 tag_constants,因为您在 scala/java 代码中需要它来加载带有变量的模型。然后带有变量的图形和 session 很容易加载 SavedModelBundle来自 java api 的 java 类。它返回一个包装器,其中包含图形和包含变量值的 session :

val model = SavedModelBundle.load(modelDir, modelTag)

如果您已经尝试过此操作,也许您可​​以分享您的代码以了解它返回无效 GraphDef 的原因。

另一种选择是卡住你的图表,即将你的变量节点变成常量节点,这样一切都在 .pb 文件中是独立的。更多信息 here冷冻部分

关于java - Python 中的 Tensorflow Java Api `toGraphDef` 等价物是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43242857/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com