gpt4 book ai didi

python - 如何将参数从全局模型复制到线程特定模型

转载 作者:行者123 更新时间:2023-11-28 17:25:43 24 4
gpt4 key购买 nike

上下文:

我是 TensorFlow 的新手,我正在尝试实现 this paper 中的一些算法这需要偶尔从全局共享模型复制到本地线程特定模型。

问题:

完成上述任务的最佳/正确方法是什么?我在下面提供了一个虚拟示例,说明我目前正在执行此操作的方式以及我遇到的错误。有人可以解释为什么会发生错误吗?

import tensorflow as tf
import threading

class ExampleModel(object):
def __init__(self, graph):
with graph.as_default():
self.w = tf.Variable(tf.constant(1, shape=[1,2]))

sess = tf.InteractiveSession()
graph = tf.get_default_graph()
global_network = ExampleModel(graph)
sess.run(tf.initialize_all_variables())

def example(i):
global global_network, graph
local_network = ExampleModel(graph)
sess.run(local_network.w.assign(global_network.w))

threads = []
for i in range(5):
t = threading.Thread(target=example, args=(i,))
threads.append(t)

for t in threads:
t.start()

错误:

Exception in thread Thread-3:
Traceback (most recent call last):
File "/Users/kennyhsu5/anaconda/lib/python2.7/threading.py", line 801, in __bootstrap_inner
self.run()
File "/Users/kennyhsu5/anaconda/lib/python2.7/threading.py", line 754, in run
self.__target(*self.__args, **self.__kwargs)
File "tmp.py", line 16, in example
local_network = ExampleModel(graph)
File "tmp.py", line 7, in __init__
self.w = tf.Variable(tf.constant(1, shape=[1,2]))
File "/Users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 211, in __init__
dtype=dtype)
File "/Users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 319, in _init_from_args
self._snapshot = array_ops.identity(self._variable, name="read")
File "/Users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2976, in __exit__
self._graph._pop_control_dependencies_controller(self)
File "/Users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2996, in _pop_control_dependencies_controller
assert self._control_dependencies_stack[-1] is controller
AssertionError

最佳答案

关于tf.Graph Tensorflow 中的类:

Important note: This class is not thread-safe for graph construction. All operations should be created from a single thread, or external synchronization must be provided. Unless otherwise specified, all methods are not thread-safe.

self.w = ... 声明和 local_network.w.assign(...) 操作导致错误。

我知道它基本上会破坏您的多线程实现,但您可以通过将这些声明移至主线程来修复上述代码。然后,您可以使用线程实际运行您规定的操作。例如:

import tensorflow as tf
import threading

class ExampleModel(object):
def __init__(self, graph):
with graph.as_default():
self.w = tf.Variable(tf.constant(1, shape=[1,2]))

sess = tf.InteractiveSession()
graph = tf.get_default_graph()
global_network = ExampleModel(graph)
sess.run(tf.global_variables_initializer())

def example(sess, assign_w):
sess.run(assign_w)

threads = []
for i in range(5):
local_network = ExampleModel(graph)
assign_w = local_network.w.assign(global_network.w)
t = threading.Thread(target=example, args=(sess, assign_w))
threads.append(t)

for t in threads:
t.start()

我还建议您通过 args 参数将变量传递给线程,而不是使用 global

最后,考虑使用 global_variables_initializer 而不是弃用的 initialize_all_variables

关于python - 如何将参数从全局模型复制到线程特定模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39199270/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com