gpt4 book ai didi

java - 如何使用 Tensorflow 1.0 Java API 创建/初始化变量

转载 作者:塔克拉玛干 更新时间:2023-11-02 19:12:08 24 4
gpt4 key购买 nike

我正在尝试移植这行 Python 代码:

my_var = tf.Variable(3, name="input_a")

到Java。我能够以这种方式使用 tf.constant 做到这一点:

graph.opBuilder("Const", name)
.setAttr("dtype", tensorVal.dataType())
.setAttr("value", tensorVal).build()
.output(0);

我尝试了一种类似的变量方法:

graph.opBuilder("Variable", name)
.setAttr("dtype", tensorVal.dataType())
.setAttr("shape", shape)
.build()
.output(0);

但是我得到这个错误:

Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value input_a
[[Node: input_a/_2 = _Send[T=DT_INT32, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_5_input_a", _device="/job:localhost/replica:0/task:0/cpu:0"](input_a)]]

我想我需要用这个值设置一个特殊的属性,或者我需要稍后初始化它。但是我找不到路。

我计划对大多数其他 tf 方法执行相同的操作(here 我目前的努力)。所以我想了解如何自己提出答案。例如,通过查看此 Python 源代码:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/variable_scope.py https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/state_ops.py

我怀疑我应该分配“initializer”属性,但 java API 或初始化方法中没有 Initializer 接口(interface)。是不是还没有执行?我是 tensorflow 和 Python 的新手。

最佳答案

和你一样的需求,我用tensorflow的assign节点给我的变量赋值。所以首先你需要按照你做的方式定义你的节点,然后你需要添加具有相应值的节点。然后我稍后在我的图表中引用这个新分配的节点,这样它就不会引发错误 java.lang.IllegalStateException: Attempting to use uninitialized value

我用 GraphBuilder 类扩展了 Graph 功能并添加了这个必需的类:

class GraphBuilder(g: Graph ) {
def variable(name: String, dataType: DataType, shape: Shape): Output = {
g.opBuilder("Variable", name)
.setAttr("dtype", dataType)
.setAttr("shape", shape)
.build()
.output(0)
}

def assign(value: Output, variable: Output): Output = {
graph.opBuilder("Assign", "Assign/" + variable.op().name()).addInput(variable).addInput(value).build().output(0)
}
}

val WValue = Array.fill(numFeatures)(Array.fill(hiddenDim)(0.0))
val W = builder.variable("W", DataType.DOUBLE, Shape.make(numFeatures, hiddenDim))
val W_init = builder.assign(builder.constant("Wval", WValue), W)

assign 节点会在每次前向传递时为您的变量分配预设值,因此它也不适合训练。但无论如何,从这篇文章看来,您似乎需要添加依赖项,因为默认情况下 JAVA API 不提供训练节点:https://github.com/tensorflow/tensorflow/issues/5518 .

关于java - 如何使用 Tensorflow 1.0 Java API 创建/初始化变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42813989/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com