
python - Setting the maximum number of threads at runtime on numpy/openblas

Reposted. Author: 太空狗. Updated: 2023-10-29 20:10:49

I would like to know whether it is possible to change, at (Python) runtime, the maximum number of threads that OpenBLAS uses behind numpy.

I know it can be set through the environment variable OMP_NUM_THREADS before the interpreter is started, but I would like to change it at runtime.
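
For reference, the environment-variable route might look like the following minimal sketch. It starts a child interpreter with OMP_NUM_THREADS already set, so the value is in place before any library in that process is loaded. The helper name run_with_thread_limit and the child snippet are illustrative assumptions, not part of the original question.

```python
import os
import subprocess
import sys


def run_with_thread_limit(code, n_threads):
    """Run `code` in a fresh interpreter with OMP_NUM_THREADS preset.

    Hypothetical helper, for illustration only.
    """
    env = dict(os.environ)
    env["OMP_NUM_THREADS"] = str(n_threads)
    result = subprocess.run(
        [sys.executable, "-c", code],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()


if __name__ == "__main__":
    # The child only echoes the variable; a real script would import numpy here.
    print(run_with_thread_limit(
        "import os; print(os.environ['OMP_NUM_THREADS'])", 2))
```

This fixes the limit at start-up only, which is the restriction the question is trying to get around.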

Typically, when MKL is used instead of OpenBLAS, this is possible:

import mkl
mkl.set_num_threads(n)

Best answer

You can do this by calling the openblas_set_num_threads function through ctypes. I often find myself wanting to do this, so I wrote a small context manager:

import contextlib
import ctypes
from ctypes.util import find_library

# Prioritize a hand-compiled OpenBLAS library over the version in /usr/lib/
# from the Ubuntu repos
try_paths = ['/opt/OpenBLAS/lib/libopenblas.so',
             '/lib/libopenblas.so',
             '/usr/lib/libopenblas.so.0',
             find_library('openblas')]
openblas_lib = None
for libpath in try_paths:
    if libpath is None:
        # find_library() returns None when it finds nothing; skip it
        continue
    try:
        openblas_lib = ctypes.cdll.LoadLibrary(libpath)
        break
    except OSError:
        continue
if openblas_lib is None:
    raise EnvironmentError('Could not locate an OpenBLAS shared library')


def set_num_threads(n):
    """Set the current number of threads used by the OpenBLAS server."""
    openblas_lib.openblas_set_num_threads(int(n))


# At the time of writing these symbols were very new:
# https://github.com/xianyi/OpenBLAS/commit/65a847c
try:
    openblas_lib.openblas_get_num_threads()
    def get_num_threads():
        """Get the current number of threads used by the OpenBLAS server."""
        return openblas_lib.openblas_get_num_threads()
except AttributeError:
    def get_num_threads():
        """Dummy function (symbol not present), returns -1."""
        return -1

try:
    openblas_lib.openblas_get_num_procs()
    def get_num_procs():
        """Get the total number of physical processors"""
        return openblas_lib.openblas_get_num_procs()
except AttributeError:
    def get_num_procs():
        """Dummy function (symbol not present), returns -1."""
        return -1


@contextlib.contextmanager
def num_threads(n):
    """Temporarily changes the number of OpenBLAS threads.

    Example usage:

        print("Before: {}".format(get_num_threads()))
        with num_threads(n):
            print("In thread context: {}".format(get_num_threads()))
        print("After: {}".format(get_num_threads()))
    """
    old_n = get_num_threads()
    set_num_threads(n)
    try:
        yield
    finally:
        set_num_threads(old_n)

You can use it like this:

with num_threads(8):
    np.dot(x, y)
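
The save-and-restore behaviour of the context manager can be checked without an OpenBLAS build. The sketch below reimplements the same pattern against a hypothetical FakeBlas stand-in (not a real library; its method names merely mirror the OpenBLAS symbols used in the answer) and suggests that the previous thread count is restored even when the body of the with block raises.

```python
import contextlib


class FakeBlas:
    """Hypothetical stand-in for the ctypes handle; not a real library."""

    def __init__(self, n_threads):
        self._n = n_threads

    def openblas_set_num_threads(self, n):
        self._n = int(n)

    def openblas_get_num_threads(self):
        return self._n


@contextlib.contextmanager
def num_threads(lib, n):
    """Temporarily set the thread count on `lib`, restoring it on exit."""
    old_n = lib.openblas_get_num_threads()
    lib.openblas_set_num_threads(n)
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises.
        lib.openblas_set_num_threads(old_n)


def demo():
    """Return the thread counts seen before, inside, and after the block."""
    lib = FakeBlas(4)
    seen = [lib.openblas_get_num_threads()]
    try:
        with num_threads(lib, 1):
            seen.append(lib.openblas_get_num_threads())
            raise RuntimeError("simulated failure inside the block")
    except RuntimeError:
        pass
    seen.append(lib.openblas_get_num_threads())
    return seen
```

Unlike the answer's version, this variant takes the library handle as an explicit argument so that the stand-in can be passed in.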

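As noted in the comments, openblas_get_num_threads and openblas_get_num_procs were very new features at the time of writing, so they may not be available unless you compiled OpenBLAS from the latest version of the source code.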
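
One way to cope with an older build is to probe for each symbol before relying on it. The following is a minimal sketch, assuming the library handle behaves like a ctypes CDLL object, where looking up a missing symbol raises AttributeError. The helper available_symbols is a hypothetical name introduced here, and the demonstration uses a plain stand-in object rather than a real OpenBLAS handle.

```python
import types

OPTIONAL_SYMBOLS = ("openblas_get_num_threads", "openblas_get_num_procs")


def available_symbols(lib, names=OPTIONAL_SYMBOLS):
    """Return the subset of `names` that `lib` actually exports.

    hasattr() works here because attribute lookup on a ctypes library
    raises AttributeError for symbols that are missing.
    """
    return [name for name in names if hasattr(lib, name)]


if __name__ == "__main__":
    # Stand-in object that only "exports" one of the two getters.
    old_build = types.SimpleNamespace(openblas_get_num_threads=lambda: 4)
    print(available_symbols(old_build))
```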

Regarding python - Setting the maximum number of threads at runtime on numpy/openblas, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/29559338/
