gpt4 book ai didi

python - 将 Tensorflow 与多处理一起使用时无法腌制 'weakref' 对象

转载 作者:行者123 更新时间:2023-12-04 14:58:24 25 4
gpt4 key购买 nike

我想同时训练多个神经网络,我正在尝试使用 multiprocessing 模块,以便每个网络都可以在单独的过程中进行训练,但我遇到了一个问题.当我运行下面的演示代码时(由于 apply_async 函数没有给出错误提示,我暂时将其更改为 apply 函数):

import tensorflow as tf
import multiprocessing as mp


class SeqModel(tf.keras.Sequential):
def __init__(self, input_size, hidden_sizes, output_size):
super().__init__()
self.add(tf.keras.layers.Dense(hidden_sizes[0], activation="relu", input_shape=(input_size,)))
for hidden_size in hidden_sizes[1:]: self.add(tf.keras.layers.Dense(hidden_size, activation="relu"))
if output_size is not None: self.add(tf.keras.layers.Dense(output_size))


class Partition:
def __init__(self, partition_id):
self.partition_id = partition_id
self.model = None

def initialization(self):
self.model = SeqModel(10,[10,10],10)

def test(self):
print(f'partition {self.partition_id} testing...')


def func():
partition_list = [Partition(i) for i in range(4)]

for partition in partition_list: partition.initialization()

p = mp.Pool(4)
for partition in partition_list:
p.apply(partition.test)
p.close()
p.join()


if __name__ == '__main__':
func()

我收到以下错误:

Traceback (most recent call last):
File "C:/Users/Administrator/Dropbox (ASU)/Work/Traffic State Estimation/traffic state estimation/dataset/mp/mp_net.py", line 43, in <module>
func()
File "C:/Users/Administrator/Dropbox (ASU)/Work/Traffic State Estimation/traffic state estimation/dataset/mp/mp_net.py", line 37, in func
p.apply(partition.test)
File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 357, in apply
return self.apply_async(func, args, kwds).get()
File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 771, in get
raise self._value
File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 537, in _handle_tasks
put(task)
File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\connection.py", line 206, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
TypeError: cannot pickle 'weakref' object

如果我不进行分区初始化(分区实例中不涉及 SeqModel),代码运行没有问题。这是否意味着我不能在子进程中使用 tf 模型?

最佳答案

要使用 Pool,您的对象必须是可挑选的,因为 Pool 方法使用 mp.SimpleQueue 将任务发送到进程,并且 mp.SimpleQueue 只接受腌制对象。

虽然默认情况下 Tensorflow 模型不可选取,因此您不能轻松地将 Pool 与 Tensorflow 模型一起使用。查看 TensorFlow 中的未解决问题,使 Model 可挑选。

但是,您可以尝试通过讨论中建议的变通方法使 Model 可选 https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883

关于python - 将 Tensorflow 与多处理一起使用时无法腌制 'weakref' 对象,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67440375/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com