Python: execute a function when a child process exits


I have a memoizing function wrapper with hit and miss counters. Since I can't access non-local variables from the function, I use a dictionary to count the hits and misses.

The function runs in roughly 1000 parallel processes on 48 cores, each more than a million times, so I use a Manager.dict to manage the score.

Just keeping the score triples my execution time, so I wanted to do something smarter: keep a local counter, which is just a plain dict, and when the process exits, add that score to the shared score dict managed by the manager.

Is there a way to execute a function when a child process exits? Something like atexit that works for spawned children.

Relevant code (note MAGICAL_AT_PROCESS_EXIT_CLASS, which is what I'm looking for):

import atexit
import pickle
from functools import wraps
from multiprocessing import Manager

manager = Manager()

global_score = manager.dict({
    "hits": 0,
    "misses": 0
})

def memoize(func):
    local_score = {
        "hits": 0,
        "misses": 0
    }

    cache = {}

    def process_exit_handler():
        global_score["hits"] += local_score["hits"]
        global_score["misses"] += local_score["misses"]

    MAGICAL_AT_PROCESS_EXIT_CLASS.register(process_exit_handler)

    @wraps(func)
    def wrap(*args):
        cache_key = pickle.dumps(args)
        if cache_key not in cache:
            local_score["misses"] += 1
            cache[cache_key] = func(*args)
        else:
            local_score["hits"] += 1
        return cache[cache_key]

    return wrap


def exit_handler():
    print("Cache", global_score)

atexit.register(exit_handler)

(Yes, I know it caches independently in each process. Yes, that is the desired behaviour.)

Current solution: this is only relevant to my particular use case of the function. I run the function once per process, and on each run it propagates itself about a million times. I changed the wrapper method as follows:

@wraps(func)
def wrap(*args):
    cache_key = pickle.dumps(args)
    if cache_key not in cache:
        local_score["misses"] += 1
        local_score["open"] += 1   # local_score now also starts with "open": 0
        cache[cache_key] = func(*args)
        local_score["open"] -= 1
    else:
        local_score["hits"] += 1

    if local_score["open"] == 0:
        global_score["hits"] += local_score["hits"]
        global_score["misses"] += local_score["misses"]
        local_score["hits"] = 0
        local_score["misses"] = 0

    return cache[cache_key]

Instead of synchronizing a few hundred million writes, I only need to synchronize once per process (1000 times).
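One caveat with this flush (an aside, not from the original post): "+=" on a Manager dict is a read-modify-write round trip through the proxy, so two processes flushing at the same time can lose updates. A minimal sketch of guarding the flush, assuming the manager and local_score from the snippets above; flush_local_score and score_lock are names introduced here:

from multiprocessing import Manager

manager = Manager()
global_score = manager.dict({"hits": 0, "misses": 0})
score_lock = manager.Lock()  # shared lock guarding the read-modify-write

def flush_local_score(local_score):
    # Called once per process; the lock keeps concurrent flushes from
    # interleaving between the read and the write of the proxy dict.
    with score_lock:
        global_score["hits"] += local_score["hits"]
        global_score["misses"] += local_score["misses"]
    local_score["hits"] = local_score["misses"] = 0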

Best Answer

This would be relatively easy to achieve by subclassing Process, pimping it up with memoization and building your own pool out of it, but since you want to use multiprocessing.Pool, it gets complicated. Pool doesn't make this possible by choice, and we have to meddle with its internals to make it possible. Make sure no child processes are watching while you read on.
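For reference, the "subclass Process" route alluded to above could look roughly like this (a sketch under my own naming, not the answer's code): override run() and call the handler in a finally block, so it runs in the child just before it exits.

from multiprocessing import Process

class ExitHookProcess(Process):
    """Process subclass that runs a handler right before the child exits."""
    def __init__(self, *args, at_exit=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._at_exit = at_exit

    def run(self):
        try:
            super().run()  # runs the target function as usual
        finally:
            if self._at_exit is not None:
                self._at_exit()  # executed in the child process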


There are two problems to solve:

  1. Get the child processes to call an exit handler when the process terminates.
  2. Prevent Pool from terminating the child processes before the exit handlers have finished.

To use fork as the start method for the child processes, I found it necessary to monkey-patch multiprocessing.pool.worker. We could use atexit with the start method "spawn" (the default on Windows), but that would only get us a little further and would deprive us of the benefits of fork, so the following code doesn't use atexit. The patch is a wrapper around worker that calls our custom at_exit function when the worker returns, which happens when the process is about to exit.

# at_exit_pool.py

import os
import threading
from functools import wraps
import multiprocessing.pool
from multiprocessing.pool import worker, TERMINATE, Pool
from multiprocessing import util, Barrier
from functools import partial


def finalized(worker):
    """Extend worker function with at_exit call."""
    @wraps(worker)
    def wrapper(*args, **kwargs):
        result = worker(*args, **kwargs)
        at_exit()  # <-- patch
        return result
    return wrapper


worker = finalized(worker)
multiprocessing.pool.worker = worker  # patch
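For completeness, the atexit-under-spawn alternative mentioned above could be sketched like this (my own sketch, not part of the answer's code): handlers registered inside the pool initializer do run when a spawned worker exits normally, but only if the pool is closed and joined rather than terminated, which is one reason the answer discards this route.

# sketch_atexit_spawn.py -- hypothetical alternative, not used below
import atexit
import os
from multiprocessing import get_context

def init_worker():
    # Runs once in every worker; the handler fires when the worker
    # process exits normally (spawned interpreters exit via sys.exit).
    atexit.register(lambda: print(os.getpid(), 'worker exiting'))

if __name__ == '__main__':
    ctx = get_context('spawn')
    pool = ctx.Pool(processes=2, initializer=init_worker)
    pool.map(abs, range(8))
    pool.close()   # let workers finish and exit on their own ...
    pool.join()    # ... instead of terminate(), which would skip atexit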

This solution also subclasses Pool to deal with both issues. PatientPool introduces two mandatory arguments, at_exit and at_exit_args. at_exit takes the exit handler, and PatientPool piggybacks on the initializer of the standard Pool to register the exit handler in the child processes. Here are the functions dealing with registering the exit handler:

# at_exit_pool.py

def at_exit(func=None, barrier=None, *args):
    """Call at_exit function and wait on barrier."""
    func(*args)
    print(os.getpid(), 'barrier waiting')  # DEBUG
    barrier.wait()


def register_at_exit(func, barrier, *args):
    """Register at_exit function."""
    global at_exit
    at_exit = partial(at_exit, func, barrier, *args)


def combi_initializer(at_exit_args, initializer, initargs):
    """Piggyback initializer with register_at_exit."""
    if initializer:
        initializer(*initargs)
    register_at_exit(*at_exit_args)
As you can see in at_exit, we're going to use a multiprocessing.Barrier. Using this synchronization primitive is the solution to our second problem: preventing Pool from terminating the child processes before the exit handlers have done their job.

A barrier works by blocking any process that calls .wait() on it for as long as fewer than its number of "parties" processes have called .wait() on it.

PatientPool initializes such a barrier and passes it on to its child processes. The parties parameter of this barrier is set to the number of child processes + 1. The child processes call .wait() on this barrier as soon as they have completed at_exit. PatientPool itself also calls .wait() on this barrier. This happens within _terminate_pool, a method we override in Pool for this purpose. Doing so prevents the pool from terminating the child processes prematurely, since all processes calling .wait() are only released once all child processes have reached the barrier.
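To see the barrier mechanics in isolation, here is a minimal sketch (not part of the answer's code): with parties set to the number of children + 1, nobody gets past wait() until every child and the parent have arrived.

from multiprocessing import Barrier, Process

def child(barrier):
    print('child done, waiting at barrier')
    barrier.wait()   # blocks until all parties have arrived

if __name__ == '__main__':
    n_children = 2
    barrier = Barrier(n_children + 1)   # children + the parent
    procs = [Process(target=child, args=(barrier,)) for _ in range(n_children)]
    for p in procs:
        p.start()
    barrier.wait()   # the parent waits here too; everyone is released together
    for p in procs:
        p.join()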

# at_exit_pool.py

class PatientPool(Pool):
    """Pool class which awaits completion of exit handlers in child processes
    before terminating the processes."""

    def __init__(self, at_exit, at_exit_args=(), processes=None,
                 initializer=None, initargs=(), maxtasksperchild=None,
                 context=None):
        # changed--------------------------------------------------------------
        self._barrier = self._get_barrier(processes)

        at_exit_args = (at_exit, self._barrier) + at_exit_args
        initargs = (at_exit_args, initializer, initargs)

        super().__init__(
            processes, initializer=combi_initializer, initargs=initargs,
            maxtasksperchild=maxtasksperchild, context=context
        )
        # ---------------------------------------------------------------------

    @staticmethod
    def _get_barrier(processes):
        """Get Barrier object for use in _terminate_pool and
        child processes."""
        if processes is None:  # this will be repeated in super().__init__(...)
            processes = os.cpu_count() or 1
        if processes < 1:
            raise ValueError("Number of processes must be at least 1")

        return Barrier(processes + 1)

    def _terminate_pool(self, taskqueue, inqueue, outqueue, pool,
                        worker_handler, task_handler, result_handler, cache):
        """changed from classmethod to normal method"""
        # this is guaranteed to only be called once
        util.debug('finalizing pool')

        worker_handler._state = TERMINATE
        task_handler._state = TERMINATE

        util.debug('helping task handler/workers to finish')
        self.__class__._help_stuff_finish(inqueue, task_handler, len(pool))  # changed

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)  # sentinel

        # We must wait for the worker handler to exit before terminating
        # workers because we don't want workers to be restarted behind our back.
        util.debug('joining worker handler')
        if threading.current_thread() is not worker_handler:
            worker_handler.join()

        # patch ---------------------------------------------------------------
        print('_terminate_pool barrier waiting')  # DEBUG
        self._barrier.wait()  # <- blocks until all processes have called wait()
        print('_terminate_pool barrier crossed')  # DEBUG
        # ---------------------------------------------------------------------

        # Terminate workers which haven't already finished.
        if pool and hasattr(pool[0], 'terminate'):
            util.debug('terminating workers')
            for p in pool:
                if p.exitcode is None:
                    p.terminate()

        util.debug('joining task handler')
        if threading.current_thread() is not task_handler:
            task_handler.join()

        util.debug('joining result handler')
        if threading.current_thread() is not result_handler:
            result_handler.join()

        if pool and hasattr(pool[0], 'terminate'):
            util.debug('joining pool workers')
            for p in pool:
                if p.is_alive():
                    # worker has not yet exited
                    util.debug('cleaning up worker %d' % p.pid)
                    p.join()
Now in your main module you only have to switch Pool for PatientPool and pass the required at_exit arguments. For simplicity, my exit handler appends local_score to a toml file. Note that local_score needs to be a global variable so the exit handler can access it.

import os
from functools import wraps
# from multiprocessing import log_to_stderr, set_start_method
# import logging
import toml
from at_exit_pool import register_at_exit, PatientPool


local_score = {
    "hits": 0,
    "misses": 0
}


def memoize(func):

    cache = {}

    @wraps(func)
    def wrap(*args):
        cache_key = str(args)  # ~14% faster than pickle.dumps(args)
        if cache_key not in cache:
            local_score["misses"] += 1
            cache[cache_key] = func(*args)
        else:
            local_score["hits"] += 1
        return cache[cache_key]

    return wrap


@memoize
def foo(x):
    for _ in range(int(x)):
        x - 1
    return x


def dump_score(pathfile):
    with open(pathfile, 'a') as fh:
        toml.dump({str(os.getpid()): local_score}, fh)


if __name__ == '__main__':

    # set_start_method('spawn')
    # logger = log_to_stderr()
    # logger.setLevel(logging.DEBUG)

    PATHFILE = 'score.toml'
    N_WORKERS = 4

    arguments = [10e6 + i for i in range(10)] * 5
    # print(arguments[:10])

    with PatientPool(at_exit=dump_score, at_exit_args=(PATHFILE,),
                     processes=N_WORKERS) as pool:

        results = pool.map(foo, arguments, chunksize=3)
        # print(results[:10])
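To connect this back to the question, the at_exit hook could just as well flush local_score into a Manager dict instead of a toml file. A sketch reusing foo, arguments and N_WORKERS from the example above (flush_score and score_lock are names introduced here; the lock is needed because += on the proxy dict is not atomic):

from multiprocessing import Manager

def flush_score(shared_score, lock):
    # runs once per worker, just before the worker exits
    with lock:
        shared_score["hits"] += local_score["hits"]
        shared_score["misses"] += local_score["misses"]

if __name__ == '__main__':
    manager = Manager()
    shared_score = manager.dict({"hits": 0, "misses": 0})
    score_lock = manager.Lock()

    with PatientPool(at_exit=flush_score,
                     at_exit_args=(shared_score, score_lock),
                     processes=N_WORKERS) as pool:
        pool.map(foo, arguments, chunksize=3)

    print("Cache", dict(shared_score))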

Running this example produces terminal output like the following, where "_terminate_pool barrier crossed" always executes last, while the flow before that line can vary:

555 barrier waiting
_terminate_pool barrier waiting
554 barrier waiting
556 barrier waiting
557 barrier waiting
_terminate_pool barrier crossed

Process finished with exit code 0

The toml file containing the scores from this run looks like this:

[555]
hits = 3
misses = 8
[554]
hits = 3
misses = 9
[556]
hits = 2
misses = 10
[557]
hits = 5
misses = 10
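After the run, the per-process scores can be summed back into a single tally, for example (a sketch; it assumes score.toml holds the tables from a single run, since the exit handler appends to the file):

import toml

scores = toml.load('score.toml')   # one table per worker pid
total = {
    "hits": sum(s["hits"] for s in scores.values()),
    "misses": sum(s["misses"] for s in scores.values()),
}
print(total)   # for the file above: {'hits': 13, 'misses': 37}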

Regarding executing a function in Python when a child process exits, there is a similar question on Stack Overflow: https://stackoverflow.com/questions/52543336/
