gpt4 book ai didi

java - 在Java中使用TensorFlow的Python Tensor

转载 作者:行者123 更新时间:2023-12-02 02:18:05 26 4
gpt4 key购买 nike

我有一个在 Python 中运行的 Tensorflow 程序,出于某些方便的原因,我想在 Java 上运行相同的程序,因此我必须保存我的模型并将其加载到我的 Java 应用程序中。

我的问题是不知道如何保存张量对象,这是我的代码:

class Main:
def __init__(self, checkpoint):
...
self.g = tf.Graph()
self.sess = tf.Session()

self.img_placeholder = tf.placeholder(tf.float32,
shape=(1, 679, 1024, 3), name='img_placeholder')

#self.preds is an instance of Tensor
self.preds = transform(self.img_placeholder)

self.saver = tf.train.Saver()
self.saver.restore(self.sess, checkpoint)

def ffwd(...):

...
_preds = self.sess.run(self.preds, feed_dict=
{self.img_placeholder: self.X})

...

因此,由于我无法创建我的张量(变换函数在幕后创建神经网络...),我必须保存它并将其重新加载到 Java 中。我找到了保存 session 的方法,但没有找到保存张量实例的方法。

有人可以给我一些关于如何实现这一目标的见解吗?

最佳答案

Python Tensor对象是对图中操作的特定输出的符号引用。

图中的操作可以通过其字符串名称来唯一标识。该操作的特定输出由该操作的输出列表中的整数索引来标识。该指数通常为零,因为绝大多数操作都会产生单一输出。

要获取 Python 中 Tensor 对象引用的操作名称和输出索引,您可以执行以下操作:

print(preds.op.name)
print(preds.value_index) # Most likely will be 0

然后在 Java 中,您可以按名称提供/获取节点。假设 preds.op.name 返回字符串 foo,并且 preds.value_index返回整数 1,然后在 Java 中,您将执行以下操作:

session.runner().feed("img_placeholder").fetch("foo", 1)

(有关详细信息,请参阅 javadoc for org.tensorflow.Session.Runner)。

您可能会在 https://github.com/tensorflow/models/tree/master/samples/languages/java 中找到链接到的幻灯片这些幻灯片中的演讲者注释很有用。

希望有帮助。

关于java - 在Java中使用TensorFlow的Python Tensor,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48979023/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com