gpt4 book ai didi

python - Python中的高效累加和

转载 作者:行者123 更新时间:2023-12-01 19:18:48 24 4
gpt4 key购买 nike

我有一个已知大小为 N 的向量 a,因此 np.sum(a) 为 1,并且 np.all(a>=0 ) 是正确的。我想确定达到阈值 t 的最小条目数。例如,我会做类似的事情:

idx = np.argsort(a)
asorted = a[idx][::-1]
sum_ = 0
number = 0
while sum_ < t:
number += 1
sum_ = np.sum(asorted[:number])

一旦 sum_ 大于 t,程序就会停止,变量 number 告诉我求和的最小条目数该阈值。

我正在寻找获取此数字的最有效方法,因为我必须执行此操作数百万次。

最佳答案

(已编辑)

(EDIT2:添加了更专业的 JIT 版本,以解决将 np.sort()numba 一起使用时出现的问题。)

(EDIT3:包括从 @hilberts_drinking_problem's answer 开始中值旋转的递归方法的计时)

我并不是100%你想要的,因为你的代码的前两行似乎什么也没做,但是在@hilberts_drinking_problem之后我编辑了我的答案,我假设你有一个错字并且:

sum_ = np.sum(arr[:i])

应该是:

sum_ = np.sum(asorted[:i])
<小时/>

然后,您的解决方案可以编写为如下函数:

import numpy as np


def min_sum_threshold_orig(arr, threshold=0.5):
idx = np.argsort(arr)
arr_sorted = arr[idx][::-1]
sum_ = 0
i = 0
while sum_ < threshold:
i += 1
sum_ = np.sum(arr_sorted[:i])
return i

但是:

  1. 您可以直接使用 np.sort() 而不是 np.argsort() 和索引
  2. 无需在每次迭代时计算整个总和,但您可以使用上一次迭代的总和
  3. 使用 while 循环是有风险的,因为如果 threshold 足够高(您的假设为 > 1.0),那么循环将永远不会结束

解决这些问题可以:

def min_sum_threshold(arr, threshold=0.5):
arr = np.sort(arr)[::-1]
sum_ = 0
for i in range(arr.size):
sum_ += arr[i]
if sum_ >= threshold:
break
return i + 1

在上面,显式循环成为瓶颈。解决这个问题的一个好方法是使用 numba:

import numba as nb


min_sum_threshold_nbn = nb.jit(min_sum_threshold)
min_sum_threshold_nbn.__name__ = 'min_sum_threshold_nbn'

但这可能不是最有效的方法,因为创建新数组时 numba 相对较慢。一种可能更快的方法是使用 arr.sort() 代替 np.sort() ,因为它是就地的,从而避免创建新数组:

@nb.jit
def min_sum_thres_nb_inplace(arr, threshold=0.5):
arr.sort()
sum_ = 0
for i in range(arr.size - 1, -1, -1):
sum_ += arr[i]
if sum_ >= threshold:
break
return arr.size - i

或者,可以仅 JIT 排序后的代码部分:

@nb.jit
def _min_sum_thres_nb(arr, threshold=0.5):
sum_ = 0.0
for i in range(arr.size):
sum_ += arr[i]
if sum_ >= threshold:
break
return i + 1


def min_sum_thres_nb(arr, threshold=0.5):
return _min_sum_thres_nb(np.sort(arr)[::-1], threshold)

对于较大的输入,两者之间的差异将很小。对于较小的情况,min_sum_thres_nb() 将由相对较慢的额外函数调用主导。由于修改其输入的基准测试函数存在缺陷,因此基准测试中省略了 min_sum_thres_nb_inplace(),但要理解的是,对于非常小的输入,其速度与 min_sum_thres_nbn() 一样快,并且对于较大的,它的性能与 min_sum_thres_nb() 基本相同。

<小时/>

或者可以使用矢量化方法,如@yatu's answer :

def min_sum_threshold_np_sum(arr, threshold=0.5):
return np.sum(np.cumsum(np.sort(arr)[::-1]) < threshold) + 1

或者,更好的是,使用np.searchsorted(),这样可以避免通过比较创建不必要的临时数组:

def min_sum_threshold_np_ss(arr, threshold=0.5):
return np.searchsorted(np.cumsum(np.sort(arr)[::-1]), threshold) + 1

或者,假设对整个数组进行排序的成本不必要地高:

def min_sum_threshold_np_part(arr, threshold=0.5):
n = arr.size
m = np.int(size * threshold) + 1
part_arr = np.partition(arr, n - m)[n - m:]
return np.searchsorted(np.cumsum(np.sort(arr)[::-1]), threshold) + 1

使用递归和中值旋转的更复杂的方法是:

def min_sum_thres_rec(arr, threshold=0.5, cutoff=64):
n = arr.size
if n <= cutoff:
return np.searchsorted(np.cumsum(np.sort(arr)[::-1]), threshold) + 1
else:
m = n // 2
partitioned = np.partition(arr, m)
low = partitioned[:m]
high = partitioned[m:]
sum_high = np.sum(high)
if sum_high >= threshold:
return min_sum_thres_rec(high, threshold)
else:
return min_sum_thres_rec(low, threshold - sum_high) + high.size

(后三个改编自@hilberts_drinking_problem's answer)

<小时/>

使用由此生成的输入对这些进行基准测试:

def gen_input(n, a=0, b=10000):
arr = np.random.randint(a, b, n)
arr = arr / np.sum(arr)
return arr

给出以下内容:

bm_full bm_zoom

这些表明,对于足够小的输入,numba 方法是最快的,但是一旦输入超过朴素方法的约 600 个元素或优化方法的约 900 个元素, em> 第一,使用 np.partition() 的 NumPy 方法虽然内存效率较低,但速度更快。

最终,超过约 4000 个元素,min_sum_thres_rec() 比所有其他建议的方法更快。也许可以为此方法编写一个更快的基于 numba 的实现。

请注意,优化 numba 方法比经过测试的简单 NumPy 方法要快。

关于python - Python中的高效累加和,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60661015/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com