gpt4 book ai didi

python - 实现 numpy.isin 后跟 sum 的更快方法

转载 作者:行者123 更新时间:2023-11-28 22:11:08 24 4
gpt4 key购买 nike

我正在使用 python 脚本执行数据分析,并从分析中了解到超过 95% 的计算时间被执行以下操作的行占用 np.sum(C[np.isin(A, b)]) , 其中A , C是等维的二维 NumPy 数组 m x n , 和 b是可变长度的一维数组。我想知道如果不是专用的 NumPy 函数,是否有加速此类计算的方法?

A (int64) 的典型尺寸, C (float64) : 10M x 100

b (int64) 的典型尺寸: 1000

最佳答案

由于您的标签来自一个小的整数范围,您应该通过使用下面的 np.bincount (pp) 获得相当大的加速。或者,您可以通过创建掩码 (p2) 来加速查找。这---与您的原始代码一样---允许将 np.sum 替换为 math.fsum 这保证了机器精度内的精确结果(p3)。或者,我们可以将它进行 pythranize 处理以获得另一个 40% 加速 (p4)。

在我的平台上,numba soln (mx) 的速度与 pp 差不多,但也许我做得不对。

import numpy as np
import math
from subsum import pflat

MAXIND = 120_000

def OP():
return sum(C[np.isin(A, b)])

def pp():
return np.bincount(A.reshape(-1), C.reshape(-1), MAXIND)[np.unique(b)].sum()
def p2():
grid = np.zeros(MAXIND, bool)
grid[b] = True
return C[grid[A]].sum()
def p3():
grid = np.zeros(MAXIND, bool)
grid[b] = True
return math.fsum(C[grid[A]])
def p4():
return pflat(A.ravel(), C.ravel(), b, MAXIND)

import numba as nb

@nb.njit(parallel=True,fastmath=True)
def nb_ss(A,C,b):
s=set(b)
sum=0.
for i in nb.prange(A.shape[0]):
for j in range(A.shape[1]):
if A[i,j] in s:
sum+=C[i,j]
return sum

def mx():
return nb_ss(A,C,b)

sh = 100_000, 100

A = np.random.randint(0, MAXIND, sh)
C = np.random.random(sh)
b = np.random.randint(0, MAXIND, 1000)

print(OP(), pp(), p2(), p3(), p4(), mx())

from timeit import timeit

print("OP", timeit(OP, number=4)*250)
print("pp", timeit(pp, number=10)*100)
print("p2", timeit(p2, number=10)*100)
print("p3", timeit(p3, number=10)*100)
print("p4", timeit(p4, number=10)*100)
print("mx", timeit(mx, number=10)*100)

pythran模块的代码:

[求和.py]

import numpy as np

#pythran export pflat(int[:], float[:], int[:], int)

def pflat(A, C, b, MAXIND):
grid = np.zeros(MAXIND, bool)
grid[b] = True
return C[grid[A]].sum()

编译很简单pythran subsum.py

样本运行:

41330.15849965791 41330.15849965748 41330.15849965747 41330.158499657475 41330.15849965791 41330.158499657446
OP 1963.3807722493657
pp 53.23419079941232
p2 21.8758742994396
p3 26.829131800332107
p4 12.988955597393215
mx 52.37018179905135

关于python - 实现 numpy.isin 后跟 sum 的更快方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56120273/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com