gpt4 book ai didi

python - 在函数内部构建 Tensorflow 图

转载 作者:太空狗 更新时间:2023-10-30 00:24:50 25 4
gpt4 key购买 nike

我正在学习 Tensorflow 并尝试正确构建我的代码。我(或多或少)知道如何构建图形或作为类方法,但我试图弄清楚如何最好地构建代码。我试过这个简单的例子:

def build_graph():                
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8)
b = tf.add(a, tf.constant(1, dtype=tf.int8))
return g

graph = build_graph()
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))

应该只打印出 4。但是,当我这样做时,出现错误:

Cannot interpret feed_dict key as Tensor: Tensor 
Tensor("Placeholder:0", dtype=int8) is not an element of this graph.

我很确定这是因为函数 build_graph 中的占位符是私有(private)的,但是 with tf.Session(graph=graph) 不应该小心那个?在这种情况下,是否有更好的方法来使用提要字典?

最佳答案

有几个选项。

选项 1:只传递张量的名称而不是张量本身。

with tf.Session(graph=graph) as sess:
feed = {"Placeholder:0": 3}
print(sess.run("Add:0", feed_dict=feed))

在这种情况下,最好为节点提供有意义的名称,而不是像上面那样使用默认名称:

def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8, name="a")
b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
return g

graph = build_graph()
with tf.Session(graph=graph) as sess:
feed = {"a:0": 3}
print(sess.run("b:0", feed_dict=feed))

回想一下,名为 “foo” 的操作的输出是名为 “foo:0”“foo:1” 的张量,等等。大多数操作只有一个输出。

选项 2:让您的 build_graph() 函数返回所有重要节点。

def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8)
b = tf.add(a, tf.constant(1, dtype=tf.int8))
return g, a, b

graph, a, b = build_graph()
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))

选项 3:将重要节点添加到集合中

def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8)
b = tf.add(a, tf.constant(1, dtype=tf.int8))
for node in (a, b):
g.add_to_collection("important_stuff", node)
return g

graph = build_graph()
a, b = graph.get_collection("important_stuff")
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))

选项 4:按照@pohe 的建议,您可以使用 get_tensor_by_name()

def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8, name="a")
b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
return g

graph = build_graph()
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))

我个人最常使用选项 2,它非常简单,不需要玩弄名字。当图形很大并且会存在很长时间时,我使用选项 3,因为集合与模型一起保存,这是记录真正重要内容的快速方法。我并没有真正使用选项 1,因为我更喜欢实际引用对象(不知道为什么)。当您使用其他人构建的图形时,选项 4 很有用,而且他们没有为您提供对张量的直接引用。

希望这对您有所帮助!

关于python - 在函数内部构建 Tensorflow 图,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44418442/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com