gpt4 book ai didi

java - 使用 python API 进行的训练作为 java API 中 LabelImage 模块的输入?

转载 作者:行者123 更新时间:2023-11-29 04:31:31 25 4
gpt4 key购买 nike

我对 java tensorflow API 有疑问。我使用 python tensorflow API 运行训练,生成文件 output_graph.pb 和 output_labels.txt。现在出于某种原因,我想将这些文件用作 java tensorflow API 中 LabelImage 模块的输入。我认为一切都会很好,因为该模块只需要一个 .pb 和一个 .txt。然而,当我运行该模块时,出现此错误:

2017-04-26 10:12:56.711402: W tensorflow/core/framework/op_def_util.cc:332] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
Exception in thread "main" java.lang.IllegalArgumentException: No Operation named [input] in the Graph
at org.tensorflow.Session$Runner.operationByName(Session.java:343)
at org.tensorflow.Session$Runner.feed(Session.java:137)
at org.tensorflow.Session$Runner.feed(Session.java:126)
at it.zero11.LabelImage.executeInceptionGraph(LabelImage.java:115)
at it.zero11.LabelImage.main(LabelImage.java:68)

如果您能帮我找出问题所在,我将不胜感激。此外,我想问你是否有一种方法可以从 java tensorflow API 运行训练,因为这会让事情变得更容易。

更准确地说:

事实上,我并没有使用自己编写的代码,至少相关步骤是这样的。我所做的就是用这个模块进行训练,https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py ,为它提供包含根据描述划分为子目录的图像的目录。特别是,我认为这些是生成输出的行:

output_graph_def = graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
f.write(output_graph_def.SerializeToString())
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n')

然后,我将输出(一个 some_graph.pb 和一个 some_labels.txt)作为此 java 模块的输入:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java , 替换默认输入。我得到的错误是上面报告的错误。

最佳答案

LabelImage.java 中默认使用的模型与正在重新训练的模型不同,因此输入和输出节点的名称不对齐。请注意,TensorFlow 模型是图,feed()fetch() 的参数是图中节点的名称。因此,您需要知道适合您的模型的名称。

查看 retrain.py,它似乎有一个节点将 JPEG 文件的原始内容作为输入(节点 DecodeJpeg/contents )并在节点中生成标签集final_result .

如果是这种情况,那么您将在 Java 中执行类似以下操作(并且您不需要构建图形的位来规范化图像,因为这似乎是重新训练模型的一部分,因此替换LabelImage.java:64 类似于:

try (Tensor image = Tensor.create(imageBytes);
Graph g = new Graph()) {
g.importGraphDef(graphDef);
try (Session s = new Session(g);
// Note the change to the name of the node and the fact
// that it is being provided the raw imageBytes as input
Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("final_result").run().get(0)) {
final long[] rshape = result.shape();
if (result.numDimensions() != 2 || rshape[0] != 1) {
throw new RuntimeException(
String.format(
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
Arrays.toString(rshape)));
}
int nlabels = (int) rshape[1];
float[] probabilities = result.copyTo(new float[1][nlabels])[0];
// At this point nlabels = number of classes in your retrained model
DoSomethingWith(probabilities);
}
}

希望对您有所帮助。

关于java - 使用 python API 进行的训练作为 java API 中 LabelImage 模块的输入?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43628829/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com