gpt4 book ai didi

python - 如何使 numba(nopython=true) 处理元素数量未知的一维 numpy.ndarray 输入

转载 作者:行者123 更新时间:2023-12-01 00:22:27 24 4
gpt4 key购买 nike

我正在将一个(数学复杂/涉及但操作很少)自制经验分布类从 C++/MATLAB(我都有)移植到 Python。

该文件有大约 1100 行代码,包括注释和测试数据,其中包括

if __name__ == "__main__": 

位于文件底部。

第 83 行有函数声明:def cdf(self, x):

它编译并运行得很好,只是非常慢,所以我想用 @numba.jit(nopython=True) 进行编译以使其运行得更快。

但是,编译在文件 npts=len(x) 第 85 行函数的最早几行之一(仅前面的注释)处终止。

消息结尾为:

[1] During: typing of argument at
C:\Users\kdalbey\Canopy\scripts\empDist.py (85)
--%<-----------------------------------------------------------------

File "Canopy\scripts\empDist.py", line 85

This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class '__main__.empDist'>

现在我确实在文件顶部做了 import numpy as np 但为了清楚地显示下面的消息,我尝试将 np 替换为 numpy。但我可能错过了一些。

如果我使用npts=x.size,我会收到相同的错误消息。

所以我尝试输入 x 为:

@numba.jit(nopython=True)
def cdf(self, x: numpy.ndarray(dtype=numpy.float64)):

我收到以下错误

---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
C:\Users\kdalbey\Canopy\scripts\empDist.py in <module>()
15 np.set_printoptions(precision=16)
16
---> 17 class empDist:
18 def __init__(self, xdata):
19 npts=len(xdata)
C:\Users\kdalbey\Canopy\scripts\empDist.py in empDist()
81
82 @numba.jit(nopython=True)
---> 83 def cdf(self, x: np.ndarray(dtype=np.float64)):
84 # compute the value of cdf at vector of points x
85 npts = x.size
TypeError: Required argument 'shape' (pos 1) not found

但我不知道 1D numpy.ndarray 预先有多少个元素(它是任意的)

我猜我也许可以做

@numba.jit(nopython=True)
def cdf(self, x: numpy.ndarray(shape=(), dtype=numpy.float64)):

它会克服该错误,然后返回到

[1] During: typing of argument at
C:\Users\kdalbey\Canopy\scripts\empDist.py (85)
--%<-----------------------------------------------------------------
File "Canopy\scripts\empDist.py", line 85
This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class '__main__.empDist'>

如果我执行 npts=int(x.size)npts=numpy.int32(x.size) ,也会出现同样的错误,所以我确定问题出在 x 上。

最佳答案

由于多个问题(从 numba 版本 0.46.0 开始),您的方法存在问题:

  • numpy.ndarray(shape=(), dtype=numpy.float64) 确实尝试创建一个 NumPy 数组。您将其用作类型提示并不重要。它仍然执行(并且失败)。
  • 您应该在 jit 中使用更合适的(对于 numba)签名,而不是类型提示。或者甚至更好:完全省略签名,让 numba 来解决。在大多数情况下,numba 更擅长,并且花费您更少的精力(如果您不需要限制类型)。
  • 您无法在 nopython 模式下 jit 方法。更好的方法是创建一个函数并从您的方法中调用它。

所以在你的情况下:

import numba as nb

@nb.njit
def _cdf(x):
# do something with x

class empDist:
def cdf(self, x):
result = _cds(x)
...

您的示例可能更复杂,但这应该为您提供一个良好的起点。如果您需要使用实例属性,只需将它们传递给 _cdf (如果 numba 支持 them )。

<小时/>

总的来说,尝试在所有事情上使用 numba 并不是一个好主意。 Numba 的范围非常有限,但它的适用范围可能会令人惊叹。

就你而言,你说它很慢。那么第一步应该是分析你的代码并找出它慢的原因和位置。然后尝试找出是否可以用更快的方法解决这个瓶颈。通常问题不在于代码本身,而在于算法/方法。检查它是否使用次优方法。如果不是,它是一个数字重的部分,那么使用 numba 可能是有意义的 - 但请注意:通常您根本不需要 numba,因为只需通过优化即可获得足够的性能NumPy 部分。

关于python - 如何使 numba(nopython=true) 处理元素数量未知的一维 numpy.ndarray 输入,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58864004/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com