gpt4 book ai didi

python - 从数组中选择最小的 n 个元素的最快方法是什么?

转载 作者:太空宇宙 更新时间:2023-11-03 14:53:40 28 4
gpt4 key购买 nike

我在写 quick select algorithm 时很开心使用 numba 并希望分享结果。

考虑数组x

np.random.seed([3,1415])
x = np.random.permutation(np.arange(10))
x

array([9, 4, 5, 1, 7, 6, 8, 3, 2, 0])

拉取最小的 n 个元素的最快方法是什么。

我试过了
np.partition

np.partition(x, 5)[:5]

array([0, 1, 2, 3, 4])

pd.Series.nsmallest

pd.Series(x).nsmallest(5).values

array([0, 1, 2, 3, 4])

最佳答案

一般来说,我不建议尝试打败 NumPy。很少有人可以竞争(对于长数组),找到更快的实现就更少了。即使速度更快,也可能不会快 2 倍。所以它很少值得。

但是我最近尝试自己做这样的事情,所以我可以分享一些有趣的结果。

这不是我自己想出来的。我的方法基于 numbas (re-)implementation of np.median . 他们可能知道他们在做什么。

我最终得到的是:

import numba as nb
import numpy as np

@nb.njit
def _partition(A, low, high):
"""copied from numba source code"""
mid = (low + high) >> 1
if A[mid] < A[low]:
A[low], A[mid] = A[mid], A[low]
if A[high] < A[mid]:
A[high], A[mid] = A[mid], A[high]
if A[mid] < A[low]:
A[low], A[mid] = A[mid], A[low]
pivot = A[mid]

A[high], A[mid] = A[mid], A[high]

i = low
for j in range(low, high):
if A[j] <= pivot:
A[i], A[j] = A[j], A[i]
i += 1

A[i], A[high] = A[high], A[i]
return i

@nb.njit
def _select_lowest(arry, k, low, high):
"""copied from numba source code, slightly changed"""
i = _partition(arry, low, high)
while i != k:
if i < k:
low = i + 1
i = _partition(arry, low, high)
else:
high = i - 1
i = _partition(arry, low, high)
return arry[:k]

@nb.njit
def _nlowest_inner(temp_arry, n, idx):
"""copied from numba source code, slightly changed"""
low = 0
high = n - 1
return _select_lowest(temp_arry, idx, low, high)

@nb.njit
def nlowest(a, idx):
"""copied from numba source code, slightly changed"""
temp_arry = a.flatten() # does a copy! :)
n = temp_arry.shape[0]
return _nlowest_inner(temp_arry, n, idx)

我在计时之前加入了一些热身电话。预热是为了让编译时间不包括在计时中:

rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)

由于计算机速度(慢得多),我稍微更改了元素数量和重复次数。但结果似乎表明我(好吧,numba 开发人员做到了)已经打败了 NumPy:

results = pd.DataFrame(
index=pd.Index([100, 500, 1000, 5000, 10000, 50000, 100000, 500000], name='Size'),
columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir', 'nlowest'], name='Method')
)

rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)

for i in results.index:
x = np.random.rand(i)
n = i // 2
for j in results.columns:
stmt = '{}(x, n)'.format(j)
setp = 'from __main__ import {}, x, n'.format(j)
results.set_value(i, j, timeit(stmt, setp, number=100))

print(results)

Method nsmall_np nsmall_pd nsmall_pir nlowest
Size
100 0.00343059 0.561372 0.00190855 0.000935566
500 0.00428461 1.79398 0.00326862 0.00187225
1000 0.00560669 3.36844 0.00432595 0.00364284
5000 0.0132515 0.305471 0.0142569 0.0108995
10000 0.0255161 0.340215 0.024847 0.0248285
50000 0.105937 0.543337 0.150277 0.118294
100000 0.2452 0.835571 0.333697 0.248473
500000 1.75214 3.50201 2.20235 1.44085

enter image description here

关于python - 从数组中选择最小的 n 个元素的最快方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44338676/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com