gpt4 book ai didi

java - 使用 Java API 的维度 0 的切片索引 0 超出范围

转载 作者:行者123 更新时间:2023-12-02 08:41:00 25 4
gpt4 key购买 nike

我已经生成了一个 SavedModel,我可以将其与以下 Python 代码一起使用

import base64
import numpy as np
import tensorflow as tf


fn_load_image = lambda filename: np.array([base64.urlsafe_b64encode(open(filename, "rb").read())])
filename='test.jpg'
with tf.Session() as sess:
loaded = tf.saved_model.loader.load(sess, ['serve'], 'tools/base64_model/1')
image = fn_load_image(filename)
p = sess.run('predictions:0', feed_dict={"input:0": image})
print(p)

这给了我我期望的值。

在同一模型上使用下面的 Java 代码时

    // load the model Bundle
try (SavedModelBundle b = SavedModelBundle.load("tools/base64_model/1",
"serve")) {

// create the session from the Bundle
Session sess = b.session();

// base64 representation of JPG
byte[] content = IOUtils.toByteArray(new FileInputStream(new File((args[0]))));

String encodedString = Base64.getUrlEncoder().encodeToString(content);

Tensor t = Tensors.create(encodedString);

// run the model and get the classification
final List<Tensor<?>> result = sess.runner().feed("input", 0, t).fetch("predictions", 0).run();

// print out the result.
System.out.println(result);
}

这应该是等价的,即我将图像的 Base64 表示发送到模型,但出现异常

Exception in thread "main" java.lang.IllegalArgumentException: slice index 0 of dimension 0 out of bounds. [[{{node map/strided_slice}}]] at org.tensorflow.Session.run(Native Method) at org.tensorflow.Session.access$100(Session.java:48) at org.tensorflow.Session$Runner.runHelper(Session.java:326) at org.tensorflow.Session$Runner.run(Session.java:276) at com.stolencamerafinder.storm.crawler.bolt.enrichments.HelloTensorFlow.main(HelloTensorFlow.java:35)

张量应该有不同的内容吗?以下是 saved_model_cli 告诉我的关于我的模型的信息。

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['inputs'] tensor_info:
dtype: DT_STRING
shape: (-1)
name: input:0
The given SavedModel SignatureDef contains the following output(s):
outputs['outputs'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 2)
name: predictions:0
Method name is: tensorflow/serving/predict

最佳答案

您的模型期望输入等级为 1 的张量,而您提供等级为 0 的张量。

该行生成一个可变长度的标量张量(即 DT_STRING)。

Tensor t = Tensors.create(encodedString);

但是,预期张量的等级为 1,正如您可以通过此处的形状 (-1) 看到的那样,这意味着它需要一个包含不同数量元素的 vector 。

The given SavedModel SignatureDef contains the following input(s):
inputs['inputs'] tensor_info:
dtype: DT_STRING
shape: (-1)
name: input:0

因此,您的问题可能会通过传递字符串数组来解决。仅当您将字符串作为字节数组传递时,才可以使用 Tensors 工厂,如下所示:

// base64 representation of JPG
byte[] content = IOUtils.toByteArray(new FileInputStream(new File((args[0]))));
byte[] encodedBytes = Base64.getUrlEncoder().encode(content);
Tensor t = Tensors.create(new byte[][]{ encodedBytes });
...

关于java - 使用 Java API 的维度 0 的切片索引 0 超出范围,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61389869/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com