gpt4 book ai didi

python - 无法使用带有 *args 参数的函数的 jit 编译函数

转载 作者:行者123 更新时间:2023-12-01 07:02:28 25 4
gpt4 key购买 nike

我正在尝试编译一个接受 numpy 数组和元组的函数使用 numba 的 *arg 形式的参数。

import numba as nb
import numpy as np

@nb.njit(cache=True)
def myfunc(t, *p):
val = 0
for j in range(0, len(p), 2):
val += p[j]*np.exp(-p[j+1]*t)
return val

T = np.arange(12)
pars = (1.0, 2.0, 3.0, 4.0)
mfunc = myfunc(T, *pars)

但是我得到了这个结果

In [1]: run numba_test.py                                                                                                                                                                  
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
~/Programs/my-python/numba_test.py in <module>
12
13 T = np.arange(12)
---> 14 mfunc = myfunc(T, 1.0, 2.0, 3.0, 4.0)

...
...
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function iadd>) with argument(s) of type(s): (Literal[int](0), array(float64, 1d, C))
Known signatures:
* (int64, int64) -> int64
* (int64, uint64) -> int64
* (uint64, int64) -> int64
* (uint64, uint64) -> uint64
* (float32, float32) -> float32
* (float64, float64) -> float64
* (complex64, complex64) -> complex64
* (complex128, complex128) -> complex128
* parameterized
In definition 0:
All templates rejected with literals.
...
...
All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at /home/cshugert/Programs/my-python/numba_test.py (9)

File "numba_test.py", line 9:
def myfunc(t, *p):
<source elided>
for j in range(0, len(p), 2):
val += p[j]*np.exp(-p[j+1]*t)
^

Numba 确实支持使用元组,所以我认为可能有我可以在 jit 编译器中添加一些签名。但是,我不确定到底该放什么。难道是numba编译器的情况无法处理以 *args 作为参数的函数?我可以做些什么以使我的功能能够正常工作吗?

最佳答案

我们再看一下错误信息

TypingError: Failed in nopython mode pipeline (step: nopython frontend)                                                                                                                    
Invalid use of Function(<built-in function iadd>) with argument(s)
of type(s): (Literal[int](0), array(float64, 1d, C))
Known signatures:
* (int64, int64) -> int64
* (int64, uint64) -> int64
* (uint64, int64) -> int64
* (uint64, uint64) -> uint64
* (float32, float32) -> float32
* (float64, float64) -> float64
* (complex64, complex64) -> complex64
* (complex128, complex128) -> complex128
* parameterized

错误是 <built-in function iadd> ,即+ .如果您查看该错误,则不是由于 *args 的传递所致但由于以下声明:

val += p[j]*np.exp(-p[j+1]*t)

基本上是 + 的所有兼容类型提到,它不支持添加 integerarray (有关详细信息,请参阅错误消息和已知签名)。

您可以通过设置 val 来解决此问题使用 np.zeros 作为零数组(参见文档 here )。

import numba as nb
import numpy as np

@nb.njit
def myfunc(t, *p):
val = np.zeros(12) #<------------ Set it as an array of zeros
for j in range(0, len(p), 2):
val += p[j]*np.exp(-p[j+1]*t)
return val

T = np.arange(12)
pars = (1.0, 2.0, 3.0, 4.0)
mfunc_val = myfunc(T, *pars)

您可以在 this Google Colab notebook 中查看代码.

关于python - 无法使用带有 *args 参数的函数的 jit 编译函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58576591/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com