gpt4 book ai didi

numpy - 使用 numpy 进行 numba 编译有什么问题?

转载 作者:行者123 更新时间:2023-12-04 01:14:17 30 4
gpt4 key购买 nike

我无法编译这段代码:

import numpy as np
import numba
from numba import jit, float64, complex128
import math

@jit(complex128[:](float64,float64[:],float64))
def GaborWavelet(omega, t, Gabor_coef):

c1 = 0.3251520240633*math.sqrt(omega)
c2 = -0.5*Gabor_coef
c3 = omega*0.187390625129278

res = np.array(c2*(t * c3)**2, dtype = np.complex128)

res.imag = omega*t

return c1*np.exp(res)

它提出:

由于函数“GaborWavelet”因以下原因导致类型推断失败,因此编译正在回退到对象模式并启用循环提升:未找到用于签名的函数 Function() 的实现:

数组(数组(float64, 1d, C), dtype=class(complex128))

有 2 个候选实现:- 其中 2 个不匹配,原因是:函数“array”中的重载:文件:numba\core\typing\npydecl.py:第 504 行。使用参数:'(array(float64, 1d, C), dtype=class(complex128))':由于实现引发了特定错误而被拒绝:TypingError: array(float64, 1d, C) not allowed in a homogeneous sequence

res = np.array(c2*(t * c3)**2, dtype = np.complex128)
^

我做错了什么?

如何编译这段代码(里面有numpy方法)?

最佳答案

Numba 不支持您使用的两个东西,但支持等效选项:

  1. 通过np.array(arr, dtype=type) 进行类型转换。而是使用 arr.astype(type)

  2. 为复杂数据类型设置 arr.imag=values。而是使用 arr += values*1j

以下代码在我的机器上运行并且应该产生相同的结果:

import numpy as np
import numba
from numba import jit, float64, complex128
import math

@jit(complex128[:](float64,float64[:],float64))
def GaborWavelet(omega, t, Gabor_coef):

c1 = 0.3251520240633*math.sqrt(omega)
c2 = -0.5*Gabor_coef
c3 = omega*0.187390625129278

res = (c2*(t * c3)**2).astype(np.complex128)
res += omega*t*1j

return c1*np.exp(res)

关于numpy - 使用 numpy 进行 numba 编译有什么问题?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63845350/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com