gpt4 book ai didi

python - Numba nopython 模式不能接受二维 bool 索引

转载 作者:太空宇宙 更新时间:2023-11-03 23:55:20 58 4
gpt4 key购买 nike

我正在尝试使用 numba 加速代码(目前我正在使用 numba 0.45.1)并遇到 bool 索引问题。代码如下:

from numba import njit
import numpy as np

n_max = 1000

n_arr = np.hstack((np.arange(1,3),
np.arange(3,n_max, 3)
))

@njit
def func(arr):
idx = np.arange(arr[-1]).reshape((-1,1)) < arr -2
result = np.zeros(idx.shape)
result[idx] = 10.1
return result

new_arr = func(n_arr)

运行代码后,我立即收到以下消息

TypingError: Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 2d, C), array(bool, 2d, C), float64)
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
In definition 2:
All templates rejected with literals.
In definition 3:
All templates rejected without literals.
In definition 4:
All templates rejected with literals.
In definition 5:
All templates rejected without literals.
In definition 6:
All templates rejected with literals.
In definition 7:
All templates rejected without literals.
In definition 8:
TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
In definition 9:
TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of setitem at C:/Users/User/Desktop/all python file/5.5.5/numba index broadcasting2.py (29)

注意最后一行的(29)对应的是第29行,也就是result[idx] = 10.1,我尝试给result赋值的那一行其索引为 idx,一个二维 bool 索引。


我想解释一下,必须在 @njit 中包含该语句 result[idx] = 10.1 。尽管我想在 @njit 中排除这条语句,但我做不到,因为这一行恰好位于我正在处理的代码的中间。

如果我坚持要在 @njit 中包含赋值语句 result[idx] = 10.1,到底需要更改什么才能使其正常工作?如果可能的话,我希望在 @njit 中看到一些可以运行的涉及二维 bool 索引的代码示例。

谢谢

最佳答案

Numba 当前不支持使用二维数组进行花式索引。见:

https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#array-access

但是,您可以通过明确地使用 for 循环而不是依赖广播重写您的函数来获得等效的行为:

from numba import njit
import numpy as np

n_max = 1000

n_arr = np.hstack((np.arange(1,3),
np.arange(3,n_max, 3)
))

def func(arr):
idx = np.arange(arr[-1]).reshape((-1,1)) < arr -2
result = np.zeros(idx.shape)
result[idx] = 10.1
return result

@njit
def func2(arr):
M = arr[-1]
N = arr.shape[0]
result = np.zeros((M, N))
for i in range(M):
for j in range(N):
if i < arr[j] - 2:
result[i, j] = 10.1

return result

new_arr = func(n_arr)
new_arr2 = func2(n_arr)
print(np.allclose(new_arr, new_arr2)) # True

在我的机器上,使用您提供的示例输入,func2func 快大约 3.5 倍。

关于python - Numba nopython 模式不能接受二维 bool 索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57915632/

58 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com