
python - *Why* does multiprocessing serialize my function and closures?


According to https://docs.python.org/3/library/multiprocessing.html, multiprocessing forks (on *nix) to create worker processes that execute tasks. We can verify this by setting a global variable in a module prior to the fork: if the worker function imports that module and finds the variable present, then the process memory has been copied. And so it is:

import os

def f(x):
    import sys
    return sys._mypid # <<< value is returned by subprocess!


def set_state():
    import sys
    sys._mypid = os.getpid()

def g():
    from multiprocessing import Pool
    pool = Pool(4)
    try:
        for z in pool.imap(f, range(1000)):
            print(z)
    finally:
        pool.close()
        pool.join()

if __name__=='__main__':
    set_state()
    g()

However, if that is how things work, what business does multiprocessing have serializing the worker function f?

In this example:

import os

def set_state():
    import sys
    sys._mypid = os.getpid()

def g():
    def f(x):
        import sys
        return sys._mypid

    from multiprocessing import Pool
    pool = Pool(4)
    try:
        for z in pool.imap(f, range(1000)):
            print(z)
    finally:
        pool.close()
        pool.join()

if __name__=='__main__':
    set_state()
    g()

we get:

AttributeError: Can't pickle local object 'g.<locals>.f'

Stack Overflow and the wider internet offer plenty of workarounds for this. (Python's standard pickle module can handle functions, but not closure data.)
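To make that limitation concrete, here is a small sketch (the names top_level, make_closure, and inner are illustrative): pickle records a module-level function as a reference to its qualified name, so a function created inside another function has no importable name and cannot be pickled.

import pickle

def top_level(x):              # importable by qualified name
    return x + 1

def make_closure():
    y = 10
    def inner(x):              # exists only inside make_closure's scope
        return x + y
    return inner

print(len(pickle.dumps(top_level)))        # works: stored as a name reference

try:
    pickle.dumps(make_closure())
except (AttributeError, pickle.PicklingError) as e:
    print("cannot pickle the closure:", e)

The failure it prints is the same "Can't pickle local object" complaint shown above.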

But why do we end up here at all? A copy-on-write version of f sits in the forked process's memory. Why does it need to be serialized?

Best Answer

Derp -- it has to be this way, because:

    pool = Pool(4)                        # <<< processes created here

    for z in pool.imap(f, range(1000)):   # <<< reference to function
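That is, the four workers are already alive and sitting in a generic task loop before f is ever mentioned; the only way to tell them what to run is to push the callable through the task queue, which means pickling it. The copy-on-write image cannot help, because at fork time the workers had no idea which function they would later be asked to call. A minimal, self-contained sketch of that shape (worker_loop, double, and the explicit queues are illustrative stand-ins, not Pool's actual internals):

import multiprocessing as mp

def worker_loop(task_queue, result_queue):
    # Each worker just loops over a queue; the callable arrives inside a
    # pickled task message, not via the memory image copied at fork time.
    while True:
        task = task_queue.get()
        if task is None:                    # sentinel: shut down
            break
        func, arg = task                    # func was pickled by the parent
        result_queue.put(func(arg))

def double(x):                              # module-level: pickles by name
    return 2 * x

if __name__ == "__main__":
    tasks, results = mp.Queue(), mp.Queue()
    workers = [mp.Process(target=worker_loop, args=(tasks, results))
               for _ in range(2)]
    for w in workers:
        w.start()
    for i in range(4):
        tasks.put((double, i))              # the function is pickled here
    for _ in workers:
        tasks.put(None)
    print(sorted(results.get() for _ in range(4)))   # [0, 2, 4, 6]
    for w in workers:
        w.join()

Because double is a module-level function, pickle can store it by qualified name; a closure defined inside g has no such name, which is exactly the AttributeError above.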

FYI... anyone who wants to fork in such a way that the new process has access to the function (thereby avoiding serializing the function) can follow the pattern below:

import collections
import multiprocessing as mp
import os
import pickle
import threading

_STATUS_DATA = 0
_STATUS_ERR = 1
_STATUS_POISON = 2


Message = collections.namedtuple(
    "Message",
    ["status",
     "payload",
     "sequence_id"
     ]
)

def parallel_map(
        target,
        args,
        num_processes,
        inq_maxsize=None,
        outq_maxsize=None,
        serialize=pickle.dumps,
        deserialize=pickle.loads,
        start_method="fork",
        preserve_order=True,
):
    """
    :param target: Target function.
    :param args: Iterable of single-parameter arguments for target.
    :param num_processes: Number of processes.
    :param inq_maxsize:
    :param outq_maxsize:
    :param serialize:
    :param deserialize:
    :param start_method:
    :param preserve_order: If true, results are returned in the order given by args.
        Otherwise, whichever result arrives first is returned first.
    :return:
    """
    if inq_maxsize is None: inq_maxsize = 10 * num_processes
    if outq_maxsize is None: outq_maxsize = 10 * num_processes
    inq = mp.Queue(maxsize=inq_maxsize)
    outq = mp.Queue(maxsize=outq_maxsize)
    poison = serialize(Message(_STATUS_POISON, None, -1))
    deserialize(poison)  # Test

    def work():
        # Runs in each worker (a forked child, or a thread if start_method == "thread").
        # `target` is reachable directly here, so only the task payloads travel over
        # the queues.
        while True:
            obj = inq.get()
            # print("{} - GET .. OK".format(os.getpid()))
            # inq.task_done()

            try:
                msg = deserialize(obj)
                assert isinstance(msg, Message)
                if msg.status == _STATUS_POISON:
                    outq.put(serialize(Message(_STATUS_POISON, None, msg.sequence_id)))
                    # print("{} - RETURN POISON .. OK".format(os.getpid()))
                    return
                else:
                    args, kw = msg.payload
                    result = target(*args, **kw)
                    outq.put(serialize(Message(_STATUS_DATA, result, msg.sequence_id)))

            except Exception as e:
                try:
                    outq.put(serialize(Message(_STATUS_ERR, e, msg.sequence_id)))
                except Exception as e2:
                    try:
                        outq.put(serialize(Message(_STATUS_ERR, None, -1)))
                        # outq.put(serialize(1,Exception("Unable to serialize response")))
                        # TODO. Log exception
                    except Exception as e3:
                        pass

    if start_method == "thread":
        _start_method = threading.Thread
    else:
        _start_method = mp.get_context('fork').Process

    # Workers are created *after* `work` (and therefore `target`) is defined,
    # so the fork gives each child direct access to the function without pickling it.
    processes = [
        _start_method(
            target=work,
            name="parallel_map.work"
        )
        for _ in range(num_processes)]

    for p in processes:
        p.start()

    quitting = []
    def quit_processes():
        if not quitting:
            quitting.append(1)
            # Send poison pills - kill child processes
            for _ in range(num_processes):
                inq.put(poison)

    nsent = [0]
    def send():
        # Send the data
        for seq_id, arg in enumerate(args):
            obj = ((arg,), {})
            inq.put(serialize(Message(_STATUS_DATA, obj, seq_id)))
            nsent[0] += 1
        quit_processes()

    # Publish
    sender = threading.Thread(
        target=send,
        name="parallel_map.sender",
        daemon=True)
    sender.start()

    try:
        # Consume
        nquit = [0]
        buffer = {}
        nyielded = 0
        while True:
            result = outq.get()  # Waiting here
            # outq.task_done()
            msg = deserialize(result)
            assert isinstance(msg, Message)
            if msg.status == _STATUS_POISON:
                nquit[0] += 1
                # print(">>> QUIT ACK {}".format(nquit[0]))
                if nquit[0] >= num_processes:
                    break
            else:
                assert msg.sequence_id >= 0

                if preserve_order:
                    buffer[msg.sequence_id] = msg
                    while True:
                        if nyielded not in buffer:
                            break

                        msg = buffer.pop(nyielded)
                        nyielded += 1
                        if msg.status == _STATUS_ERR:
                            if isinstance(msg.payload, Exception):
                                raise msg.payload
                            else:
                                raise Exception("Unexpected exception")
                        else:
                            assert msg.status == _STATUS_DATA
                            yield msg.payload
                else:
                    if msg.status == _STATUS_ERR:
                        if isinstance(msg.payload, Exception):
                            raise msg.payload
                        else:
                            raise Exception("Unexpected exception")
                    else:
                        assert msg.status == _STATUS_DATA
                        yield msg.payload

            # if nyielded == nsent:
            #     break

    except Exception as e:
        raise
    finally:
        if not quitting:
            quit_processes()
        sender.join()
        for p in processes:
            p.join()



Usage:

import time

def f(x):
    time.sleep(0.01)
    if x == -1:
        raise Exception("Boo")
    return x

for result in parallel_map(target=f,            # <<< not serialized
                           args=range(100),
                           num_processes=8,
                           start_method="fork"):
    pass

... with one caveat: every time you fork, a puppy dies for each thread in your program. Forking a multithreaded process is risky, because only the thread that called fork survives in the child.
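To illustrate the kind of trouble that warning refers to, here is a minimal *nix-only sketch (the names background and lock are made up, and whether the child actually stalls depends on timing): a lock held by any other thread at the instant of the fork is inherited by the child in the locked state, with nobody left to release it.

import os
import threading
import time

lock = threading.Lock()

def background():
    # Grabs and releases the lock in a tight loop; at any given instant it
    # may be holding it.
    while True:
        with lock:
            time.sleep(0.001)

threading.Thread(target=background, daemon=True).start()
time.sleep(0.05)

pid = os.fork()                 # only the calling thread exists in the child
if pid == 0:
    # If the background thread held `lock` at the moment of the fork, the
    # child sees it locked forever; the acquire below then times out.
    print("child acquired lock:", lock.acquire(timeout=1))
    os._exit(0)
else:
    os.wait()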

Regarding python - *Why* does multiprocessing serialize my function and closures?, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/58442111/
