gpt4 book ai didi

python - 在不写入磁盘的情况下在 session 之间重用 TensorFlow 变量的值

转载 作者:太空宇宙 更新时间:2023-11-03 11:43:12 25 4
gpt4 key购买 nike

在 sklearn 中,我习惯于拥有一个可以运行 fit 然后 predict 的模型。但是,对于 TensorFlow,我在调用 predict 时无法从 fit 加载学习到的参数。归结为我不知道如何在 session 之间重用变量的值。例如,

import tensorflow as tf

x = tf.Variable(0.0)

# fit code
with tf.Session() as sess1:
sess1.run(tf.global_variables_initializer())
sess1.run(tf.assign(x, 1.0)) # at end of training, x = 1.0

# predict code
with tf.Session() as sess2:
sess2.run(tf.global_variables_initializer())
print(sess2.run(x)) # want this to be 1.0, but is 0.0

我能想到一个解决方法,但它看起来真的很笨拙,如果我想重用多个变量会很烦人:

import tensorflow as tf

x = tf.Variable(0.0)

# fit code
with tf.Session() as sess1:
sess1.run(tf.global_variables_initializer())
sess1.run(tf.assign(x, 1.0)) # at end of training, x = 1.0
learned_x = sess1.run(x) # remember value of learned x at end of session

# predict code
with tf.Session() as sess2:
sess2.run(tf.global_variables_initializer())
sess2.run(tf.assign(x, learned_x))
print(sess2.run(x)) # prints 1.0

如何在不写入磁盘的情况下在 session 之间重用变量(即使用 tf.train.Saver)?我在上面写的解决方法是执行此操作的正确方法吗?

最佳答案

要模仿 sklearn 的模型,只需将 session 包装到一个类中,以便您可以在方法之间共享它,例如

class Model:
def __init__(self):
self.graph = self.build_graph()
self.session = tf.Session()
self.session.run(tf.global_variables_initializer())

def build_graph(self):
return {'x': tf.Variable(0.0)}

def fit(self):
self.session.run(tf.assign(self.graph['x'], 1.0))

def predict(self):
print(self.session.run(self.graph['x']))

def close(self):
tf.reset_default_graph()
self.session.close()

m = Model()
m.fit()
m.predict()
m.close()

确保手动关闭 session 并相应地处理异常。

关于python - 在不写入磁盘的情况下在 session 之间重用 TensorFlow 变量的值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45131230/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com