gpt4 book ai didi

java - Java 中的 feed_dict 等效项

转载 作者:行者123 更新时间:2023-12-01 23:08:04 24 4
gpt4 key购买 nike

我正在使用 Java 来提供通过 Python 学习的 Tensorflow 模型。该模型有两个输入。代码如下:

  def predict(float32InputShape: (Long, Long),
float32Inputs: Seq[Seq[Float]],
uint8InputShape: (Long, Long),
uint8Inputs: Seq[Seq[Byte]]
): Array[Float] = {

val float32Input = Tensor.create(
Array(float32InputShape._1, float32InputShape._2),
FloatBuffer.wrap(float32Inputs.flatten.toArray)
)
val uint8Input = Tensor.create(
classOf[UInt8],
Array(uint8InputShape._1, uint8InputShape._2),
ByteBuffer.wrap(uint8Inputs.flatten.toArray)
)

val tfResult = session
.runner()
.feed("serving_default_float32_Input", float32Input)
.feed("serving_default_uint8_Input", uint8Input)
.fetch("PartitionedCall")
.run()
.get(0)
.expect(classOf[java.lang.Float])

tfResult
}

我想要做的是通过像 Python 中的 feed_dict 一样传递输入来重构该方法,使其更加通用。也就是说,类似:

    def predict2(inputs: Map[String, Seq[Seq[Float]]]): Array[Float] = {
...
session
.runner()
.feed(inputs)
...
}

其中inputs 映射的键是输入图层的名称。除非我创建一个宏(我想避免),否则不可能使用 feed 方法来执行此操作。

有没有办法使用 Tensorflow 的 Java API 来做到这一点(我使用的是 TF 2.0)?

编辑:我找到了解决方案(感谢 @geometrikal 的回答),代码是 Scala 中的,但在 Java 中应该不会太难。

    val runnerWithInputLayers = inputs.foldLeft(session.runner()) {
case (sess, (layerName, array)) =>
val tensor = createTensor(array)
sess.feed(layerName, tensor)
}

val output = runnerWithInputLayers
.fetch(outputLayer)
.run()
.get(0)
.expect(Float.getClass)

这是可能的,因为 .feed 方法返回一个带有提供的输入层的 Session.Runner

最佳答案

您可以循环喂养每个。如果不太熟悉 java 脚本,但伪代码类似于

例如

val tfResult = session.runner()
for(key, value : inputs) {
tfResult = tfResult(key, value)
}
tfResult = tfResult.fetch("PartitionedCall")
.run()
.get(0)
.expect(classOf[java.lang.Float])

请记住,您可以在任何时候分解函数链,例如result = foo.bar().baz().qux() 可以写成 temp = foo.bar().baz();结果= temp.qux()

关于java - Java 中的 feed_dict 等效项,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58391108/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com