gpt4 book ai didi

python - 在 Keras 和 Tensorflow 中为多线程设置复制模型

转载 作者:太空狗 更新时间:2023-10-30 02:41:13 27 4
gpt4 key购买 nike

我正在尝试在 Keras 和 TensorFlow 中实现 actor-critic 的异步版本。我将 Keras 用作构建网络层的前端(我直接使用 tensorflow 更新参数)。我有一个 global_model 和一个主要的 tensorflow session 。但在每个线程中,我创建了一个 local_model,它从 global_model 复制参数。我的代码看起来像这样

def main(args):
config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)
sess = tf.Session(config=config)
K.set_session(sess) # K is keras backend
global_model = ConvNetA3C(84,84,4,num_actions=3)

threads = [threading.Thread(target=a3c_thread, args=(i, sess, global_model)) for i in range(NUM_THREADS)]

for t in threads:
t.start()

def a3c_thread(i, sess, global_model):
K.set_session(sess) # registering a session for each thread (don't know if it matters)
local_model = ConvNetA3C(84,84,4,num_actions=3)
sync = local_model.get_from(global_model) # I get the error here

#in the get_from function I do tf.assign(dest.params[i], src.params[i])

我收到来自 Keras 的用户警告

UserWarning: The default TensorFlow graph is not the graph associated with the TensorFlow session currently registered with Keras, and as such Keras was not able to automatically initialize a variable. You should consider registering the proper session with Keras via K.set_session(sess)

随后是 tf.assign 操作的 tensorflow 错误,表示操作必须在同一个图上。

ValueError: Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref, device=/device:CPU:0) must be from the same graph as Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref)

我不确定到底出了什么问题。

谢谢

最佳答案

错误来自 Keras,因为 tf.get_default_graph() is sess.graph 返回 False。从 TF 文档中,我看到 tf.get_default_graph() 正在返回当前线程的默认图。当我开始一个新线程并创建一个图形时,它被构建为特定于该线程的单独图形。我可以通过执行以下操作来解决此问题,

with sess.graph.as_default():
local_model = ConvNetA3C(84,84,4,3)

关于python - 在 Keras 和 Tensorflow 中为多线程设置复制模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40154320/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com