gpt4 book ai didi

python - Tensorflow:在类中创建图形并在外部运行

转载 作者:太空狗 更新时间:2023-10-29 20:13:42 25 4
gpt4 key购买 nike

我相信我很难理解图表在 tensorflow 中的工作原理以及如何访问它们。我的直觉是,“with graph:”下的线条会将图形形成为单个实体。因此,我决定创建一个类,该类在实例化时构建一个图形,并拥有一个运行该图形的函数,如下所示;

class Graph(object):

#To build the graph when instantiated
def __init__(self, parameters ):
self.graph = tf.Graph()
with self.graph.as_default():
...
prediction = ...
cost = ...
optimizer = ...
...
# To launch the graph
def launchG(self, inputs):
with tf.Session(graph=self.graph) as sess:
...
sess.run(optimizer, feed_dict)
loss = sess.run(cost, feed_dict)
...
return variables

接下来的步骤是创建一个主文件,该文件将组装要传递给类的参数,构建图形然后运行它;

#Main file
...
parameters_dict = { 'n_input': 28, 'learnRate': 0.001, ... }

#Building graph
G = Graph(parameters_dict)
P = G.launchG(Input)
...

这对我来说非常优雅,但它不太管用(很明显)。确实,似乎 launchG 函数无法访问图中定义的节点,这给我错误,例如;

---> 26 sess.run(optimizer, feed_dict)

NameError: name 'optimizer' is not defined

也许是我对 python(和 tensorflow)的理解太有限了,但我有一种奇怪的印象,即在创建图形(G)后,以该图形作为参数运行 session 应该可以访问节点它,而不需要我提供明确的访问权限。

有什么启示吗?

最佳答案

节点predictioncostoptimizer是在方法__init__中创建的局部变量,它们不能在 launchG 方法中访问。

最简单的解决方法是将它们声明为类 Graph 的属性:

class Graph(object):

#To build the graph when instantiated
def __init__(self, parameters ):
self.graph = tf.Graph()
with self.graph.as_default():
...
self.prediction = ...
self.cost = ...
self.optimizer = ...
...
# To launch the graph
def launchG(self, inputs):
with tf.Session(graph=self.graph) as sess:
...
sess.run(self.optimizer, feed_dict)
loss = sess.run(self.cost, feed_dict)
...
return variables

您还可以通过 graph.get_tensor_by_namegraph.get_operation_by_name 使用它们的确切名称来检索图的节点。

关于python - Tensorflow:在类中创建图形并在外部运行,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37770911/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com