gpt4 book ai didi

java - 为什么 Java Tensorflow session 似乎重置了状态,而 Python Tensorflow session 却没有?

转载 作者:塔克拉玛干 更新时间:2023-11-02 19:48:24 26 4
gpt4 key购买 nike

我正尝试在 Linux 上通过 1.4 Java API 构建和评估 TensorFlow Graphs。我注意到每次调用 Session.run() 时,Java API 似乎都会重置操作输出张量的值。这种行为似乎与 Python 中发生的情况不符。我最终的问题(详见下文)是如何避免这种明显的行为?

Python 示例

这里的示例是 Python 代码(也使用 1.4 API),它增加标量张量中的值。

>>> import tensorflow as tf
>>> x = tf.get_variable("x", [], dtype=tf.float32, initializer=tf.zeros_initializer)
>>> step = tf.constant(1.0)
>>> xUpdateOp = x.assign_add(step)
>>> s = tf.Session()
>>> s.run(x.initializer)
>>> x.eval(s)
0.0
>>> s.run(xUpdateOp)
1.0
>>> x.eval(s)
1.0
>>> s.run(xUpdateOp)
2.0
>>> x.eval(s)
2.0
>>>

请注意,正如预期的那样,评估 x 会给出其当前值,并且使用 session 运行 xUpdateOp 会使 x 变大 1。

Java 示例

这是我尝试使用 Java 构建一个递增标量张量的 Tensorflow 图。 Java API 中的初始化有所不同,因为它缺少一些 Python 的便捷方法。

public static void doCounting(){
try(Graph g = new Graph()){
try(Tensor<Float> zeroT = Tensors.create(0.0f);
Tensor<Float> stepT = Tensors.create(1.0f)){
Output<Float> zero = g.opBuilder("Const", "start")
.setAttr("dtype", zeroT.dataType())
.setAttr("value", zeroT)
.build().output(0);
Output<Float> step = g.opBuilder("Const", "step")
.setAttr("dtype", stepT.dataType())
.setAttr("value", stepT)
.build().output(0);
Output<Float> xVar = g.opBuilder("Variable", "x")
.setAttr("dtype", zero.dataType())
.setAttr("shape", zero.shape())
.build().output(0);
Output<Float> x = g.opBuilder("Assign", "init_x")
.addInput(xVar)
.addInput(zero)
.build().output(0);

Operation xUpdateOp = g.opBuilder("AssignAdd", "x_get_x_plus_step")
.addInput(x)
.addInput(step)
.build();

try(Session s = new Session(g)) {
s.runner().addTarget(xUpdateOp).run();
s.runner().addTarget(xUpdateOp).run();
s.runner().addTarget(xUpdateOp).run();

try(Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)){
System.out.println(result.floatValue());
}
}
}
}
}

以上代码片段的输出

1.0

但我预计它是 4.0,因为我在 xUpdateOp 上调用了 run() 4 次。即使我差一个 1.0 也不是我所期望的。

问题

我需要如何处理此 Java 示例才能获得与 Python 示例相同的行为?如何让 xUpdateOp 使用在之前调用 run() 时计算的 x 值?

我已经尝试过的

我已经尝试过使用 feed() 函数来输入 x 值

try(Session s = new Session(g)) {
try(Tensor<Float> x1 = s.runner().fetch(xUpdateOp.name()).run().get(0).expect(Float.class)) {
s.runner().feed(xUpdateOp.name(), 0, x1);
try (Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)) {
System.out.println(result.floatValue());
}
}
}

结果

1.0

我还尝试在没有 addTarget 或 fetch() 的情况下调用 run(),认为 addTarget 或 fetch() 是导致状态重置的原因。也许一旦 session 了解要运行的内容,它就可以运行多次。

try(Session s = new Session(g)) {
s.runner().addTarget(xUpdateOp).run();
s.runner().run();
s.runner().run();

try(Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)){
System.out.println(result.floatValue());
}
}

结果

Exception in thread "main" java.lang.IllegalArgumentException: Must specify at least one target to fetch or execute.
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)
at org.tensorflow.examples.Example.doCounting(MandelbrotExample.java:80)
at org.tensorflow.examples.Example.main(MandelbrotExample.java:50)
ERROR: Non-zero return code '1' from command: Process exited with status 1.

一些相关的问题

How to create/initialize a Variable with Tensorflow 1.0 Java API

java tensorflow reset_default_graph

Java - train loaded tensorflow model

提前感谢您的宝贵时间!

最佳答案

在您的示例中,xUpdateOpx 作为输入,x 是分配零< 的操作的输出 到变量。因此,每次运行 xUpdateOp 时,它都会首先将零分配给变量。

对您的代码稍作调整即可生成 4.0:

# Changed addInput(x) to addInput(xVar)
Operation xUpdateOp =
g.opBuilder("AssignAdd", "x_get_x_plus_step").addInput(xVar).addInput(step).build();

try (Session s = new Session(g)) {
# Initialize the variable once
s.runner().addTarget(x.op()).run();
s.runner().addTarget(xUpdateOp).run();
s.runner().addTarget(xUpdateOp).run();
s.runner().addTarget(xUpdateOp).run();

try (Tensor<Float> result =
s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)) {
System.out.println(result.floatValue());
}
}

与 Python 代码类比:上面的 Java 代码片段更像是问题中的 Python 代码。虽然问题中的 Java 代码更像是 Python 中的以下代码:

import tensorflow as tf

zero = tf.constant(0.0)
step = tf.constant(1.0)
xVar = tf.Variable(initial_value=zero, name="x")
x = tf.assign(xVar, zero)
xUpdateOp = tf.assign_add(x, step)

所以 tf.assign_add(x, step)tf.assign_add(xVar, step) 会产生很大的不同。在前者中,AssignAdd 操作应用于 Assign 操作的输出。

希望对您有所帮助。

关于java - 为什么 Java Tensorflow session 似乎重置了状态,而 Python Tensorflow session 却没有?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48118747/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com