gpt4 book ai didi

java - TensorFlow 2.0 和 Java API

转载 作者:行者123 更新时间:2023-12-03 19:14:28 44 4
gpt4 key购买 nike

(注意,我已经解决了我的问题并在底部发布了代码)

我正在玩 TensorFlow,后端处理必须在 Java 中进行。我从 https://developers.google.com/machine-learning/crash-course 中获取了其中一个模型并用 tf.saved_model.save(my_model,"house_price_median_income") 保存它(使用 docker 容器)。我复制了模型并将其加载到 Java 中(使用从源代码构建的 2.0 东西,因为我在 Windows 上)。
我可以加载模型并运行它:

   try (SavedModelBundle model = SavedModelBundle.load("./house_price_median_income", "serve")) {
try (Session session = model.session()) {
Session.Runner runner = session.runner();
float[][] in = new float[][]{ {2.1518f} } ;

Tensor<?> jack = Tensor.create(in);
runner.feed("serving_default_layer1_input", jack);

float[][] probabilities = runner.fetch("StatefulPartitionedCall").run().get(0).copyTo(new float[1][1]);

for (int i = 0; i < probabilities.length; ++i) {
System.out.println(String.format("-- Input #%d", i));
for (int j = 0; j < probabilities[i].length; ++j) {
System.out.println(String.format("Class %d - %f", i, probabilities[i][j]));
}
}
}
}

以上内容硬编码为输入和输出,但我希望能够读取模型并提供一些信息,以便最终用户可以选择输入和输出等。

我可以使用 python 命令获取输入和输出:saved_model_cli show --dir ./house_price_median_income --all

我想做的是通过 Java 获取输入和输出,因此我的代码不需要执行 python 脚本来获取它们。我可以通过以下方式进行操作:
 Graph graph = model.graph();
Iterator<Operation> itr = graph.operations();
while (itr.hasNext()) {
GraphOperation e = (GraphOperation)itr.next();
System.out.println(e);

这会将输入和输出都输出为“操作”但是我怎么知道它是输入和/或输出? python 工具使用 SignatureDef 但这似乎根本没有出现在 TensorFlow 2.0 java 的东西中。我是否遗漏了一些明显的东西,还是只是从 TensforFlow 2.0 Java 库中遗漏了?

注意,我已经通过下面的答案帮助对我的问题进行了分类。这是我的全部代码,以防将来有人会喜欢它。请注意,这是 TF 2.0 并使用下面提到的 SNAPSHOT。我做了一些假设,但它展示了如何提取输入和输出,然后使用它们来运行模型
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.Session.Run;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.GraphOperation;
import org.tensorflow.proto.framework.SignatureDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.proto.framework.MetaGraphDef;
import java.util.Map;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.tools.Shape;
import java.nio.FloatBuffer;
import org.tensorflow.tools.buffer.DataBuffers;
import org.tensorflow.tools.ndarray.FloatNdArray;
import org.tensorflow.tools.ndarray.StdArrays;
import org.tensorflow.proto.framework.TensorInfo;

public class v2tensor {
public static void main(String[] args) {
try (SavedModelBundle savedModel = SavedModelBundle.load("./house_price_median_income", "serve")) {
SignatureDef modelInfo = savedModel.metaGraphDef().getSignatureDefMap().get("serving_default");
TensorInfo input1 = null;
TensorInfo output1 = null;
Map<String, TensorInfo> inputs = modelInfo.getInputsMap();
for(Map.Entry<String, TensorInfo> input : inputs.entrySet()) {
if (input1 == null) {
input1 = input.getValue();
System.out.println(input1.getName());
}
System.out.println(input);
}
Map<String, TensorInfo> outputs = modelInfo.getOutputsMap();
for(Map.Entry<String, TensorInfo> output : outputs.entrySet()) {
if (output1 == null) {
output1=output.getValue();
}
System.out.println(output);
}

try (Session session = savedModel.session()) {
Session.Runner runner = session.runner();
FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{ { 2.1518f } } );

try (Tensor<TFloat32> jack = TFloat32.tensorOf(matrix) ) {
runner.feed(input1.getName(), jack);
try ( Tensor<TFloat32> rezz = runner.fetch(output1.getName()).run().get(0).expect(TFloat32.DTYPE) ) {
TFloat32 data = rezz.data();
data.scalars().forEachIndexed((i, s) -> {
System.out.println(s.getFloat());
} );
}
}
}
} catch (TensorFlowException ex) {
ex.printStackTrace();
}
}
}

最佳答案

你需要做的是阅读SavedModelBundle元数据作为 MetaGraphDef ,从那里您可以从 SignatureDef 中检索输入和输出名称,就像在 Python 中一样。

在 TF Java 1.*(即您在示例中使用的客户端)中,原型(prototype)定义在 tensorflow 中不可用。 Artifact ,你需要添加一个依赖到org.tensorflow:proto以及反序列化 SavedModelBundle.metaGraphDef() 的结果变成 MetaGraphDef原型(prototype)。

在 TF Java 2.* 中(新客户端实际上只能作为 here 的快照提供),protos 立即出现,因此您可以简单地调用此行来检索正确的 SignatureDef :

savedModel.metaGraphDef().signatureDefMap.getValue("serving_default")

关于java - TensorFlow 2.0 和 Java API,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61228372/

44 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com