
java - Loading a TensorFlow model in Java

Reposted. Author: 太空宇宙. Updated: 2023-11-03 21:18:50

I am trying to load a TensorFlow model in Java.

tf.saved_model.simple_save(
    sess,
    "/tmp/model/" + timestamp,
    inputs={"input_x": cnn.input_x},
    outputs={"input_y": cnn.input_y})

This is how I save the TensorFlow model in Python.

import java.io.IOException;
import java.nio.IntBuffer;
import java.util.Random;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

public class Use_model {
    public static void main(String[] args) throws IOException {
        // good idea to print the version number (1.12.0 in the log below)
        System.out.println(TensorFlow.version());
        final int NUM_PREDICTIONS = 1;
        Random r = new Random();
        // one input row of 56 int32 values, matching input_x's shape (-1, 56)
        long[] shape = new long[] {1, 56};
        IntBuffer buf = IntBuffer.allocate(1 * 56);
        for (int i = 0; i < 56; i++) {
            buf.put(r.nextInt());
        }
        buf.flip();

        // load the model bundle
        try (SavedModelBundle b = SavedModelBundle.load("/tmp/model/1549001254", "serve")) {

            Session sess = b.session();

            // run the model and fetch the result
            try (Tensor x = Tensor.create(shape, buf)) {
                float[] result = sess.runner()
                        .feed("input_x", x)
                        .fetch("input_y")
                        .run()
                        .get(0)
                        .copyTo(new float[1][2])[0];

                // print out the result.
                System.out.println(result[0]);
            }
        }
    }
}

This is how I load it in Java.

The given SavedModel SignatureDef contains the following input(s):
  inputs['input_x'] tensor_info:
      dtype: DT_INT32
      shape: (-1, 56)
      name: input_x:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['input_y'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 2)
      name: input_y:0
Method name is: tensorflow/serving/predict

The inputs and outputs were saved correctly.
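The `name:` entries in the SignatureDef dump are tensor names of the form `op_name:output_index`; the signature keys (`input_x`, `input_y`) are separate labels that only happen to match the op names in this model. The TensorFlow 1.x Java `Session.Runner.feed`/`fetch` methods take a graph operation name and, as far as I know, also accept this `op_name:index` form. The sketch below is a hypothetical helper (not part of the TensorFlow API) that splits such a name, just to make the mapping visible: `input_y:0` refers to output 0 of the op `input_y`, which the exception below reports as a Placeholder.

```java
// Hypothetical helper (not part of the TensorFlow API): splits a tensor name
// such as "input_x:0", as printed in the SignatureDef, into an op name and
// an output index. A name without a numeric ":<index>" suffix maps to output 0.
public class TensorNames {

    static final class Parsed {
        final String op;
        final int index;

        Parsed(String op, int index) {
            this.op = op;
            this.index = index;
        }

        @Override
        public String toString() {
            return op + " (output " + index + ")";
        }
    }

    static Parsed parse(String name) {
        int colon = name.lastIndexOf(':');
        // No colon, or a trailing colon: the whole string is the op name.
        if (colon == -1 || colon == name.length() - 1) {
            return new Parsed(name, 0);
        }
        try {
            int index = Integer.parseInt(name.substring(colon + 1));
            return new Parsed(name.substring(0, colon), index);
        } catch (NumberFormatException e) {
            // Suffix is not a number: treat the whole string as the op name.
            return new Parsed(name, 0);
        }
    }

    public static void main(String[] args) {
        // Tensor names taken from the SignatureDef dump above.
        System.out.println(parse("input_x:0")); // input_x (output 0)
        System.out.println(parse("input_y:0")); // input_y (output 0)
    }
}
```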

1.12.0
2019-02-01 15:58:59.065677: I tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: /tmp/model/1549001254
2019-02-01 15:58:59.072601: I tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { serve }
2019-02-01 15:58:59.085912: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2
2019-02-01 15:58:59.132271: I tensorflow/cc/saved_model/loader.cc:162] Restoring SavedModel bundle.
2019-02-01 15:58:59.199331: I tensorflow/cc/saved_model/loader.cc:138] Running MainOp with key legacy_init_op on SavedModel bundle.
2019-02-01 15:58:59.199435: I tensorflow/cc/saved_model/loader.cc:259] SavedModel load for tags { serve }; Status: success. Took 133774 microseconds.
Exception in thread "main" java.lang.IllegalArgumentException: You must feed a value for placeholder tensor 'input_y' with dtype float and shape [?,2]
[[{{node input_y}} = Placeholder[_output_shapes=[[?,2]], dtype=DT_FLOAT, shape=[?,2], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:314)
at org.tensorflow.Session$Runner.run(Session.java:264)
at Use_model.main(Use_model.java:38)

But the model doesn't work: it loads (the log reports success), but running it fails with the error above.

I don't know what the problem is or how to fix it.

Best answer

There is some confusion about input_y in your code. The exception says:

You must feed a value for placeholder tensor 'input_y' with dtype float and shape [?,2]

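This means that in your Python code, input_y is defined as a placeholder. I'd guess it is the placeholder that holds the labels for the input_x items. input_y should then be used in the loss function to compare the last layer of the CNN (let's call it cnn.output) with the actual labels (cnn.input_y), for example: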

loss = tf.square(cnn.input_y - cnn.output)

Your Python code should then save cnn.output in the outputs dictionary instead of cnn.input_y:

tf.saved_model.simple_save(
    sess,
    "/tmp/model/" + timestamp,
    inputs={"input_x": cnn.input_x},
    outputs={"output": cnn.output})

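In your Java code you should then fetch "output". Note that fetch() expects a graph operation name, not a signature key, so this only works if the op behind cnn.output is actually named output; check the name: entry listed for outputs['output'] in the SignatureDef dump and fetch that name instead if it differs: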

float[] result = sess.runner()
        .feed("input_x", x)
        .fetch("output")
        .run()
        .get(0)
        .copyTo(new float[1][2])[0];
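Once the fetch succeeds, `copyTo(new float[1][2])[0]` yields the two values for the single input row. Assuming, as the answer suggests, that cnn.output is a float tensor of shape (-1, 2) with one score per class (an assumption; the actual layer is not shown), the predicted class is the index of the larger score. The helper below is a minimal, hypothetical sketch of that step, and the example values are made up.

```java
// Sketch assuming the fetched tensor holds one score per class, shaped [1][2];
// argmax picks the predicted class index for the single input row.
public class Prediction {

    // Returns the index of the largest value (the first one on ties).
    static int argmax(float[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("scores must not be empty");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Stand-in for copyTo(new float[1][2])[0]; these values are made up.
        float[] result = {0.2f, 1.3f};
        System.out.println("predicted class: " + argmax(result)); // predicted class: 1
    }
}
```

If the graph already computes the argmax itself (for example an int64 predictions tensor), the dtype and shape in the SignatureDef would differ, and the Java side would need a matching array type in copyTo instead of this helper.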

A similar question about loading a TensorFlow model in Java can be found on Stack Overflow: https://stackoverflow.com/questions/54474618/
