gpt4 book ai didi

python - 过滤NumPy数组:最佳方法是什么?

转载 作者:太空宇宙 更新时间:2023-11-03 13:54:33 25 4
gpt4 key购买 nike

假设我有一个NumPy数组arr要按元素筛选,例如。
我只想得到低于某个阈值的值。
有两种方法,例如:
使用生成器:k
使用布尔掩码切片:np.fromiter((x for x in arr if x < k), dtype=arr.dtype)
使用arr[arr < k]np.where()
使用arr[np.where(arr < k)]np.nonzero()
使用基于Cython的自定义实现
使用基于numba的自定义实现
哪个最快内存效率如何?
(编辑:根据@ShadowRanger评论添加arr[np.nonzero(arr < k)]

最佳答案

定义
使用生成器:

def filter_fromiter(arr, k):
return np.fromiter((x for x in arr if x < k), dtype=arr.dtype)

使用布尔掩码切片:
def filter_mask(arr, k):
return arr[arr < k]

使用 np.where()
def filter_where(arr, k):
return arr[np.where(arr < k)]

使用 np.nonzero()
def filter_nonzero(arr, k):
return arr[np.nonzero(arr < k)]

使用基于Cython的自定义实现:
单程 filter_cy()
两次通过 filter2_cy()
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


cimport numpy as cnp
cimport cython as ccy

import numpy as np
import cython as cy


cdef long NUM = 1048576
cdef long MAX_VAL = 1048576
cdef long K = 1048576 // 2


cdef int smaller_than_cy(long x, long k=K):
return x < k


cdef size_t _filter_cy(long[:] arr, long[:] result, size_t size, long k):
cdef size_t j = 0
for i in range(size):
if smaller_than_cy(arr[i]):
result[j] = arr[i]
j += 1
return j


cpdef filter_cy(arr, k):
result = np.empty_like(arr)
new_size = _filter_cy(arr, result, arr.size, k)
return result[:new_size].copy()


cdef size_t _filtered_size(long[:] arr, size_t size, long k):
cdef size_t j = 0
for i in range(size):
if smaller_than_cy(arr[i]):
j += 1
return j


cpdef filter2_cy(arr, k):
cdef size_t new_size = _filtered_size(arr, arr.size, k)
result = np.empty(new_size, dtype=arr.dtype)
new_size = _filter_cy(arr, result, arr.size, k)
return result

使用基于Numba的自定义实现
单程 filter_np_nb()
两次通过 filter2_np_nb()
import numba as nb


@nb.jit
def filter_func(x, k=K):
return x < k


@nb.jit
def filter_np_nb(arr):
result = np.empty_like(arr)
j = 0
for i in range(arr.size):
if filter_func(arr[i]):
result[j] = arr[i]
j += 1
return result[:j].copy()


@nb.jit
def filter2_np_nb(arr):
j = 0
for i in range(arr.size):
if filter_func(arr[i]):
j += 1
result = np.empty(j, dtype=arr.dtype)
j = 0
for i in range(arr.size):
if filter_func(arr[i]):
result[j] = arr[i]
j += 1
return result

时间基准
基于生成器的 filter_fromiter()方法比其他方法慢得多(大约2个数量级,因此在图表中省略了它)。
计时将取决于输入数组的大小和过滤项的百分比。
作为输入大小的函数
第一个图将计时作为输入大小的函数进行处理(对于大约50%过滤掉的元素):
bm_size
一般来说,基于numba的方法始终是最快的,紧随其后的是cython方法。其中,对于中等和较大的输入,这两种传递方法是最快的在numpy中,基于 np.where()和基于 np.nonzero()的方法基本上是相同的(除了非常小的输入, np.nonzero()似乎稍慢),它们都比布尔掩码切片快,除了非常小的输入(低于~100个元素),布尔掩码切片快。
此外,对于非常小的输入,基于cython的解决方案比基于numpy的解决方案慢。
作为填充功能
第二个图将计时作为通过过滤器的项的函数来处理(对于大约100万个元素的固定输入大小):
bm_filling
第一种观察方法是,当接近50%填充时,所有的方法都是最慢的,而填充量越少,填充速度越快,对填充的速度越快(过滤出的百分比最高,在图的x轴中所显示的通过值的百分比最低)。
同样,Numba和Cython版本通常都比基于NumPy的版本快,Numba几乎总是最快的,Cython在图的最右边部分赢了Numba。
值得注意的例外是,当填充接近100%时,单通道numba/cython版本基本上被复制了大约两次,布尔掩模切片解决方案最终优于它们。
对于较大的充注阀,两个通道的方法具有增加的边际速度增益。
在NumPy中,基于 np.where()和基于 np.nonzero()的方法再次基本相同。
在比较基于NumPy的解决方案时, np.where()/ np.nonzero()解决方案几乎总是优于布尔掩码切片,除了图的最右边部分,布尔掩码切片成为最快的部分。
(完整代码可用 here
内存注意事项
基于生成器的 filter_fromiter()方法只需要最小的临时存储空间,与输入的大小无关。
记忆方面这是最有效的方法。
类似的内存效率是cython/numba两次传递方法,因为输出的大小是在第一次传递期间确定的。
在内存方面,cython和numba的单通道解决方案都需要输入大小的临时数组。
因此,这些是内存效率最低的方法。
布尔掩码切片解决方案需要输入大小为但类型为 bool的临时数组,在numpy中为1位,因此这大约是典型64位系统上numpy数组默认大小的64倍。
基于 np.where()的解决方案与第一步(内部 np.where())中的布尔掩码切片具有相同的要求,后者在第二步(输出 int)中转换为一系列 int64s(通常在64 but系统上 np.where())因此,第二步的内存需求是可变的,这取决于过滤元素的数量。
评论
当指定一个不同的过滤条件时,生成器方法也是最灵活的。
Cython解决方案需要指定数据类型以使其快速
对于numba和cython,可以将过滤条件指定为泛型函数(因此不需要硬编码),但必须在它们各自的环境中指定,并且必须注意确保正确编译该函数以提高速度,否则会观察到明显的减速
单程解决方案在返回之前需要额外的 .copy(),以避免浪费内存
由于 advanced indexing,numpy方法不返回输入的视图,而是返回一个副本:
arr = np.arange(100)
k = 50
print('`arr[arr > k]` is a copy: ', arr[arr > k].base is None)
# `arr[arr > k]` is a copy: True
print('`arr[np.where(arr > k)]` is a copy: ', arr[np.where(arr > k)].base is None)
# `arr[np.where(arr > k)]` is a copy: True
print('`arr[:k]` is a copy: ', arr[:k].base is None)
# `arr[:k]` is a copy: False

(编辑:在单程cython/numba版本中包括基于 np.nonzero()的解决方案和修复的内存泄漏,包括基于@shadowranger、@paulpanzer和@max9111注释的两程cython/numba版本。)

关于python - 过滤NumPy数组:最佳方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58422690/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com