gpt4 book ai didi

python - Tensorflow,多处理中的更新权重

转载 作者:行者123 更新时间:2023-12-01 09:11:13 28 4
gpt4 key购买 nike

我定义了一个网络,每个范围包含每个进程的权重,每个进程分配其相应的权重,这是我的演示代码

from multiprocessing import Process

import tensorflow as tf


def init_network(name):
with tf.name_scope(name):
x = tf.Variable(int(name))
return x


def f(name, sess):
print('step into f()')
vars = tf.trainable_variables(name)
print(sess.run(vars[0]))
sess.run(vars[0].assign(int(name)+10))


if __name__ == '__main__':
sess = tf.Session()
x1 = init_network('1')
x2 = init_network('2')
sess.run(tf.global_variables_initializer())
p1 = Process(target=f, args=('1', sess))
p2 = Process(target=f, args=('2', sess))

p1.start()
p2.start()

p1.join()
p2.join()
print(sess.run([x1, x2]))

演示代码卡住了,sess 似乎无法在不同进程中共享,如何更新多处理设置中的权重?

最佳答案

谷歌搜索了一段时间后,我发现多处理不适用于TensorFlow,因此,我使用线程

from threading import Thread

import tensorflow as tf

def init_network(name):
with tf.name_scope(name):
x = tf.Variable(int(name))
return x

def f(name, sess):
with sess.as_default(), sess.graph.as_default():
print('step into f()')
vars = tf.trainable_variables(name)
print(vars)
sess.run(vars[0].assign(int(name)+10))
print(sess.run(vars[0]))


if __name__ == '__main__':
sess = tf.Session()
coord = tf.train.Coordinator()

x1 = init_network('1')
x2 = init_network('2')
sess.run(tf.global_variables_initializer())
print(sess.run([x1, x2]))

p1 = Thread(target=f, args=('1', sess))
p2 = Thread(target=f, args=('2', sess))
p1.start()
p2.start()
coord.join([p1, p2])
print(sess.run([x1, x2]))

现在可以了,默认 session 是当前线程的属性。如果您创建一个新线程并希望在该线程中使用默认 session ,则必须在该线程的函数中显式添加 with sess.as_default(): 。并且您必须显式输入 with sess.graph.as_default(): block 以使 sess.graph 成为默认图表。

tf.train.Coordinator 加入线程非常方便。还可以使用 thread.join() 方法来加入线程。

关于python - Tensorflow,多处理中的更新权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51636756/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com