gpt4 book ai didi

python - 在 numba nopython 函数中计算阶乘的最快方法

转载 作者:太空宇宙 更新时间:2023-11-03 15:09:30 24 4
gpt4 key购买 nike

我有一个想要用 numba 编译的函数,但是我需要计算该函数内的阶乘。不幸的是 numba 不支持 math.factorial:

import math
import numba as nb

@nb.njit
def factorial1(x):
return math.factorial(x)

factorial1(10)
# UntypedAttributeError: Failed at nopython (nopython frontend)

我看到它支持 math.gamma (可用于计算阶乘),但是与真正的 math.gamma 函数相反,它不返回代表“整数值”的 float :

@nb.njit
def factorial2(x):
return math.gamma(x+1)

factorial2(10)
# 3628799.9999999995 <-- not exact

math.gamma(11)
# 3628800.0 <-- exact

math.factorial 相比,它的速度较慢:

%timeit factorial2(10)
# 1.12 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit math.factorial(10)
# 321 ns ± 6.12 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

所以我决定定义自己的函数:

@nb.njit
def factorial3(x):
n = 1
for i in range(2, x+1):
n *= i
return n

factorial3(10)
# 3628800

%timeit factorial3(10)
# 821 ns ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

它仍然比 math.factorial 慢,但比基于 math.gamma 的 numba 函数更快,并且值是“精确的”。

因此,我正在寻找在 nopython numba 函数内计算正整数(<= 20;以避免溢出)的阶乘的最快方法。

最佳答案

对于 <= 20 的值,Python 使用查找表,正如评论中所建议的那样。 https://github.com/python/cpython/blob/3.6/Modules/mathmodule.c#L1452

LOOKUP_TABLE = np.array([
1, 1, 2, 6, 24, 120, 720, 5040, 40320,
362880, 3628800, 39916800, 479001600,
6227020800, 87178291200, 1307674368000,
20922789888000, 355687428096000, 6402373705728000,
121645100408832000, 2432902008176640000], dtype='int64')

@nb.jit
def fast_factorial(n):
if n > 20:
raise ValueError
return LOOKUP_TABLE[n]

由于 numba 调度开销,从 python 调用它比 python 版本稍慢。

In [58]: %timeit math.factorial(10)
10000000 loops, best of 3: 79.4 ns per loop

In [59]: %timeit fast_factorial(10)
10000000 loops, best of 3: 173 ns per loop

但在另一个 numba 函数内调用可能会快得多。

def loop_python():
for i in range(10000):
for n in range(21):
math.factorial(n)

@nb.njit
def loop_numba():
for i in range(10000):
for n in range(21):
fast_factorial(n)

In [65]: %timeit loop_python()
10 loops, best of 3: 36.7 ms per loop

In [66]: %timeit loop_numba()
10000000 loops, best of 3: 73.6 ns per loop

关于python - 在 numba nopython 函数中计算阶乘的最快方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44346188/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com