gpt4 book ai didi

python - Cython - 有效地过滤类型化的内存 View

转载 作者:太空宇宙 更新时间:2023-11-04 07:53:53 25 4
gpt4 key购买 nike

此 Cython 函数返回 numpy 数组元素中的一个随机元素,该元素在一定范围内:

cdef int search(np.ndarray[int] pool):
cdef np.ndarray[int] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)

这很好用。但是,此功能对我的代码的性能非常关键。类型化的内存 View 显然比 numpy 数组快得多,但它们不能像上面那样被过滤。

我如何使用类型化的内存 View 编写一个与上面的功能相同的函数?还是有其他方法可以提高功能的性能?

最佳答案

好吧,让我们从使代码更通用开始,稍后我会谈到性能方面。

我通常不使用:

import numpy as np
cimport numpy as np

我个人喜欢为 cimported 包使用不同的名称,因为它有助于将 C 端和 NumPy-Python 端分开。所以对于这个答案,我将使用

import numpy as np
cimport numpy as cnp

此外,我将制作函数的 lower_limitupper_limit 参数。也许这些是在您的情况下静态(或全局)定义的,但它使示例更加独立。因此,起点是对您的代码稍作修改:

cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
cdef cnp.ndarray[int] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)

Cython 中一个非常好的功能是 fused types ,因此您可以轻松地将此函数概括为不同的类型。您的方法仅适用于 32 位整数数组(至少如果 int 在您的计算机上是 32 位)。很容易支持更多的数组类型:

ctypedef fused int_or_float:
cnp.int32_t
cnp.int64_t
cnp.float32_t
cnp.float64_t

cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
cdef cnp.ndarray[int_or_float] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)

当然你可以根据需要添加更多类型。优点是新版本可以在旧版本失败的地方工作:

>>> search_1(np.arange(100, dtype=np.float_), 10, 20)
ValueError: Buffer dtype mismatch, expected 'int' but got 'double'
>>> search_2(np.arange(100, dtype=np.float_), 10, 20)
19.0

现在它更通用了,让我们看看您的函数实际做了什么:

  • 您创建一个 bool 数组,其中元素高于下限
  • 您创建一个 bool 数组,其中元素低于上限
  • 您通过按位和两个 bool 数组创建一个 bool 数组。
  • 您创建一个新数组,其中仅包含 bool 掩码为真的元素
  • 你只从最后一个数组中提取一个元素

为什么要创建这么多数组?我的意思是你可以简单地计算有多少元素在限制内,取一个介于 0 和限制内的元素数之间的随机整数,然后在结果数组中的那个索引处取的任何元素.

cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
cdef int_or_float element

# Count the number of elements that are within the limits
cdef Py_ssize_t num_valid = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
num_valid += 1

# Take a random index
cdef Py_ssize_t random_index = np.random.randint(0, num_valid)

# Go through the array again and take the element at the random index that
# is within the bounds
cdef Py_ssize_t clamped_index = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
if clamped_index == random_index:
return element
clamped_index += 1

它不会更快,但会节省大量内存。因为你没有中间数组,你根本不需要内存 View ——但如果你愿意,你可以将参数列表中的 cnp.ndarray[int_or_float] arr 替换为 int_or_float [:] 甚至 int_or_float[::1] arr 并在 memoryview 上操作(它可能不会更快,但也不会更慢)。

我通常更喜欢 numba 而不是 Cython(至少如果我正在使用它)所以让我们将它与该代码的 numba 版本进行比较:

import numba as nb
import numpy as np

@nb.njit
def search_numba(arr, lower, upper):
num_valids = 0
for item in arr:
if item >= lower and item <= upper:
num_valids += 1

random_index = np.random.randint(0, num_valids)

valid_index = 0
for item in arr:
if item >= lower and item <= upper:
if valid_index == random_index:
return item
valid_index += 1

还有一个 numexpr 变体:

import numexpr

np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])

好吧,让我们做一个基准测试:

from simple_benchmark import benchmark, MultiArgument

arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]

b = benchmark(funcs, arguments, argument_name='array size')

enter image description here

因此,通过不使用中间数组,你可以快大约 5 倍,如果你使用 numba,你可以得到另一个因子 5(好像我在那里遗漏了一些可能的 Cython 优化,numba 通常快 2 倍或像 Cython 一样快)。因此,您可以使用 numba 解决方案将速度提高约 20 倍。

numexpr 在这里并没有真正的可比性,主要是因为您不能在那里使用 bool 数组索引。

差异将取决于数组的内容和限制。您还必须衡量应用程序的性能。


顺便说一句:如果下限和上限通常不会改变,最快的解决方案是过滤数组一次,然后多次调用 np.random.choice .这可能会快几个数量级

lower_limit = ...
upper_limit = ...
filtered_array = pool[(pool >= lower_limit) & (pool <= upper_limit)]

def search_cached():
return np.random.choice(filtered_array)

%timeit search_cached()
2.05 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

几乎快了 1000 倍,而且根本不需要 Cython 或 numba。但这是一种特殊情况,可能对您没有用。


如果你想自己做,基准设置在这里(基于 Jupyter Notebook/Lab,因此 %-symbols):

%load_ext cython

%%cython

cimport numpy as cnp
import numpy as np

cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
cdef cnp.ndarray[int] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)

ctypedef fused int_or_float:
cnp.int32_t
cnp.int64_t
cnp.float32_t
cnp.float64_t

cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
cdef cnp.ndarray[int_or_float] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)

cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
cdef int_or_float element
cdef Py_ssize_t num_valid = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
num_valid += 1

cdef Py_ssize_t random_index = np.random.randint(0, num_valid)

cdef Py_ssize_t clamped_index = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
if clamped_index == random_index:
return element
clamped_index += 1

import numexpr
import numba as nb
import numpy as np

def search_numexpr(arr, l, u):
return np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])

@nb.njit
def search_numba(arr, lower, upper):
num_valids = 0
for item in arr:
if item >= lower and item <= upper:
num_valids += 1

random_index = np.random.randint(0, num_valids)

valid_index = 0
for item in arr:
if item >= lower and item <= upper:
if valid_index == random_index:
return item
valid_index += 1

from simple_benchmark import benchmark, MultiArgument

arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]

b = benchmark(funcs, arguments, argument_name='array size')

%matplotlib widget

import matplotlib.pyplot as plt

plt.style.use('ggplot')
b.plot()

关于python - Cython - 有效地过滤类型化的内存 View ,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51427792/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com