- Java 双重比较
- java - 比较器与 Apache BeanComparator
- Objective-C 完成 block 导致额外的方法调用?
- database - RESTful URI 是否应该公开数据库主键?
我正在使用 Tensorflow Java Api 将已创建的 Tensorflow 模型加载到 JVM 中。我以此为例:tensorflow/examples/LabelImage.java
这是我的简单 scala 代码:
import java.nio.file.{Files, Path, Paths}
import org.tensorflow.{Graph, Session, Tensor}
def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path)
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"))
val g = new Graph()
g.importGraphDef(graphDef)
val session = new Session(g)
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0))
如何保存我的模型以将 session 和图形存储在同一个文件中。如上面的“PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb”中所述。
描述here它提到:
The serialized representation of the graph, often referred to as a GraphDef, can be generated by toGraphDef() and equivalents in other language APIs.
其他语言 API 中的等价物是什么?我觉得不明显
注意:我已经查看了 tensorflow_serving 下的 mnist_saved_model.py,但通过该过程保存它会得到一个 .pb
文件和一个 variables
文件夹。尝试加载该 .pb
文件时,我得到:java.lang.IllegalArgumentException: Invalid GraphDef
最佳答案
目前使用 tensorflow 的 Java API,我只找到了如何将图形保存为 graphDef(即没有其变量和元数据)。这可以通过将 Array[Byte] 写入文件来完成:
Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef)
这里 myGraph
是来自 Graph class 的 java 对象.
我建议使用 SavedModel 从 Python API 保存您的模型api在这里定义。它会将您的模型保存在一个文件夹中,其中包含 .pb 文件中的序列化图形和文件夹中的变量。请注意您使用的 tag_constants,因为您在 scala/java 代码中需要它来加载带有变量的模型。然后带有变量的图形和 session 很容易加载 SavedModelBundle来自 java api 的 java 类。它返回一个包装器,其中包含图形和包含变量值的 session :
val model = SavedModelBundle.load(modelDir, modelTag)
如果您已经尝试过此操作,也许您可以分享您的代码以了解它返回无效 GraphDef 的原因。
另一种选择是卡住你的图表,即将你的变量节点变成常量节点,这样一切都在 .pb 文件中是独立的。更多信息 here冷冻部分
关于java - Python 中的 Tensorflow Java Api `toGraphDef` 等价物是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43242857/
我正在使用 Tensorflow Java Api 将已创建的 Tensorflow 模型加载到 JVM 中。我以此为例:tensorflow/examples/LabelImage.java 这是我
我是一名优秀的程序员,十分优秀!