gpt4 book ai didi

python - 如何在 numba.njit 中进行离散傅立叶变换 (FFT)?

转载 作者:行者123 更新时间:2023-12-03 23:46:49 25 4
gpt4 key购买 nike

各位程序员大家好

我正在尝试在下面这个最小可运行示例(minimal working example)中,在 numba.njit 装饰的函数里实现离散傅立叶变换(discrete Fourier transform):

import numba
import numpy as np
import scipy
import scipy.fftpack

@numba.njit
def main():
    # Minimal reproducer: a 3 x 4 x 6 nested list of floats is handed to an
    # FFT routine inside a numba.njit-compiled function.
    wave = [[[0.09254795, 0.10001078, 0.10744892, 0.07755555, 0.08506225, 0.09254795],
             [0.09907245, 0.10706145, 0.11502401, 0.08302302, 0.09105898, 0.09907245],
             [0.09565098, 0.10336405, 0.11105158, 0.08015589, 0.08791429, 0.09565098],
             [0.00181467, 0.001961, 0.00210684, 0.0015207, 0.00166789, 0.00181467]],
            [[-0.45816267, - 0.46058367, - 0.46289091, - 0.45298182, - 0.45562851, -0.45816267],
             [-0.49046506, - 0.49305676, - 0.49552669, - 0.48491893, - 0.48775223, -0.49046506],
             [-0.47352483, - 0.47602701, - 0.47841162, - 0.46817027, - 0.4709057, -0.47352483],
             [-0.00898358, - 0.00903105, - 0.00907629, - 0.008882, - 0.00893389, -0.00898358]],
            [[0.36561472, 0.36057289, 0.355442, 0.37542627, 0.37056626, 0.36561472],
             [0.39139261, 0.38599531, 0.38050268, 0.40189591, 0.39669325, 0.39139261],
             [0.37787385, 0.37266296, 0.36736003, 0.38801438, 0.38299141, 0.37787385],
             [0.00716892, 0.00707006, 0.00696945, 0.0073613, 0.00726601, 0.00716892]]]

    # This line fails to compile: numba's nopython frontend reports
    # "Unknown attribute 'fft'" for scipy.fftpack (see the traceback below).
    new_fft = scipy.fftpack.fft(wave)


if __name__ == '__main__':
    # Calling main() triggers numba compilation, which raises the TypingError.
    main()

输出:
C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
main()
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'fft' of type Module(<module 'scipy.fftpack' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fftpack\\__init__.py'>)

File "test2.py", line 21:
def main():
<source elided>

new_fft = scipy.fftpack.fft(wave)
^

[1] During: typing of get attribute at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)

File "test2.py", line 21:
def main():
<source elided>

new_fft = scipy.fftpack.fft(wave)
^


Process finished with exit code 1

不幸的是,scipy.fftpack.fft 似乎是 numba 不支持的遗留函数,所以我寻找了替代方案,找到了两个:

1. scipy.fft(wave):这是上述遗留函数的新版本。它产生如下错误输出:
C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
main()
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>) with parameters (list(list(list(float64))))
No type info available for Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>) as a callable.
[1] During: resolving callee type: Module(<module 'scipy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\scipy\\fft\\__init__.py'>)
[2] During: typing of call at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)


File "test2.py", line 21:
def main():
<source elided>

new_fft = scipy.fft(wave)
^


Process finished with exit code 1

2. np.fft.fft(wave):它似乎是受支持的,但同样会报错:
C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 25, in <module>
main()
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'fft' of type Module(<module 'numpy.fft' from 'C:\\Users\\Artur\\Anaconda\\lib\\site-packages\\numpy\\fft\\__init__.py'>)

File "test2.py", line 21:
def main():
<source elided>

new_fft = np.fft.fft(wave)
^

[1] During: typing of get attribute at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (21)

File "test2.py", line 21:
def main():
<source elided>

new_fft = np.fft.fft(wave)
^


Process finished with exit code 1

您知道有哪个 fft 函数可以在 numba.njit 装饰的函数中使用吗?

最佳答案

如果一维 DFT 就能满足您的需求,不妨直接使用 FFT。
下面给出一个对 Numba 友好的实现 fft_1d(),它可以处理任意长度的输入:

import cmath
import numpy as np
import numba as nb


@nb.njit
def ilog2(n):
    """Return floor(log2(abs(n))) for a nonzero integer, or -1 for ``n == 0``.

    Counts how many right shifts it takes to exhaust the magnitude of ``n``.
    Decorated with ``njit`` (nopython mode), like every other helper in this
    module, so it cannot silently fall back to slow object mode.
    """
    result = -1
    if n < 0:
        n = -n
    while n > 0:
        n >>= 1
        result += 1
    return result


@nb.njit(fastmath=True)
def reverse_bits(val, width):
    """Return the ``width`` lowest bits of ``val`` in reversed order.

    Bits of ``val`` above ``width`` are ignored; a non-positive ``width``
    yields 0.
    """
    reversed_val = 0
    bits_left = width
    while bits_left > 0:
        # Shift the accumulator left and append the current lowest bit.
        reversed_val = (reversed_val << 1) | (val & 1)
        val >>= 1
        bits_left -= 1
    return reversed_val


@nb.njit(fastmath=True)
def fft_1d_radix2_rbi(arr, direct=True):
    """Radix-2 decimation-in-time FFT with bit-reversed input ordering.

    Parameters
    ----------
    arr : array_like
        1D input, converted to ``complex128``. Its length must be 0 or a
        power of two.
    direct : bool
        ``True`` for the forward transform (twiddles ``exp(-2j*pi*k/n)``),
        ``False`` for the opposite sign. The ``False`` transform is NOT
        divided by ``n``; callers must normalise the result themselves.

    Returns
    -------
    numpy.ndarray
        Complex128 array holding the transform of ``arr`` (empty for an
        empty input).

    Raises
    ------
    ValueError
        If ``len(arr)`` is not a power of two.
    """
    arr = np.asarray(arr, dtype=np.complex128)
    n = len(arr)
    result = np.empty_like(arr)
    if n == 0:
        # Nothing to transform; also avoids dividing by zero below.
        return result
    if n & (n - 1):
        # Reversing only ilog2(n) bits would drop the high bits of the index
        # and silently produce a wrong transform.
        raise ValueError("fft_1d_radix2_rbi: len(arr) must be a power of two")
    levels = ilog2(n)
    # The butterflies only ever read twiddle indices k < n // 2.
    half_n = n // 2
    e_arr = np.empty(half_n, dtype=np.complex128)
    coeff = (-2j if direct else 2j) * cmath.pi / n
    for i in range(half_n):
        e_arr[i] = cmath.exp(coeff * i)
    # Load the input in bit-reversed order so the butterflies run in place.
    for i in range(n):
        result[i] = arr[reverse_bits(i, levels)]
    # Radix-2 decimation-in-time FFT
    size = 2
    while size <= n:
        half_size = size // 2
        step = n // size
        for i in range(0, n, size):
            k = 0
            for j in range(i, i + half_size):
                temp = result[j + half_size] * e_arr[k]
                result[j + half_size] = result[j] - temp
                result[j] += temp
                k += step
        size *= 2
    return result


@nb.njit(fastmath=True)
def fft_1d_arb(arr, fft_1d_r2=fft_1d_radix2_rbi):
    """1D FFT for arbitrary inputs using chirp z-transform (Bluestein).

    The length-``n`` DFT is rewritten as a convolution with the chirp
    ``exp(-1j*pi*k**2/n)``. That convolution is evaluated with power-of-two
    FFTs of length ``m >= 2*n - 1``.

    Parameters
    ----------
    arr : array_like
        1D input, converted to ``complex128``. Any length is accepted.
    fft_1d_r2 : callable
        Power-of-two FFT kernel with signature ``(arr, direct=True)``, used
        for the internal convolution.

    Returns
    -------
    numpy.ndarray
        Complex128 array of length ``n`` with the forward DFT of ``arr``.
    """
    arr = np.asarray(arr, dtype=np.complex128)
    n = len(arr)
    # A power of two with m >= 2*n - 1, so the circular convolution below
    # does not wrap around.
    m = 1 << (ilog2(n) + 2)
    two_n = 2 * n
    e_arr = np.empty(n, dtype=np.complex128)
    for i in range(n):
        # The chirp is periodic with period 2*n in i*i. Reducing the exponent
        # keeps the phase argument small and avoids precision loss for large n.
        e_arr[i] = cmath.exp(-1j * cmath.pi * ((i * i) % two_n) / n)
    result = np.zeros(m, dtype=np.complex128)
    result[:n] = arr * e_arr
    coeff = np.zeros_like(result)
    coeff[:n] = e_arr.conjugate()
    # Place the negative-index chirp taps at the end of the buffer. The
    # explicit start ``m - n + 1`` (rather than ``-n + 1``) keeps the slice
    # empty for n <= 1, where ``coeff[0:]`` / ``coeff[1:]`` would otherwise
    # be assigned an empty array and fail.
    coeff[m - n + 1:] = e_arr[:0:-1].conjugate()
    # fft_convolve leaves the result scaled by m (unnormalised inverse).
    return fft_convolve(result, coeff, fft_1d_r2)[:n] * e_arr / m


@nb.njit(fastmath=True)
def fft_convolve(a_arr, b_arr, fft_1d_r2=fft_1d_radix2_rbi):
    """Circularly convolve two equal-length power-of-two arrays via FFT.

    The two spectra are multiplied and transformed back with the
    opposite-sign kernel (``direct=False``). That kernel does not normalise,
    so the result is the circular convolution scaled by ``len(a_arr)``.
    """
    spectrum = fft_1d_r2(a_arr) * fft_1d_r2(b_arr)
    return fft_1d_r2(spectrum, False)


@nb.njit(fastmath=True)
def fft_1d(arr):
    """Forward 1D FFT of an input of any length.

    Lengths with ``n & (n - 1) == 0`` (powers of two, and also 0) go straight
    to the radix-2 kernel. Every other length uses the chirp z-transform path.
    """
    n = len(arr)
    is_pow2 = (n & (n - 1)) == 0
    if is_pow2:
        return fft_1d_radix2_rbi(arr)
    return fft_1d_arb(arr)

与朴素的 DFT 算法(dft_1d(),与链接中的实现基本相同)相比,速度可提升几个数量级,但通常仍比 np.fft.fft() 慢很多。

vs_dft

相对速度随输入长度而变化。
对于长度为 2 的幂的输入,其速度通常与 np.fft.fft() 相差在一个数量级以内。

pow2

对于长度不是 2 的幂的输入,其速度通常与 np.fft.fft() 相差在两个数量级以内。

not-pow2

对于最坏情况(长度为质数或与之类似的情况,这里测试的是 2 的幂加 1),其速度与 np.fft.fft() 的差距约在两倍以内。

primes

FFT 耗时的非线性表现,源于处理长度不是 2 的幂的任意输入时需要更复杂的算法。这一点同时影响本实现和 np.fft.fft() 的实现,但 np.fft.fft() 包含更多优化,因此平均表现更好。

长度为 2 的幂的 FFT 的另一种实现见此处(here)。

关于python - 如何在 numba.njit 中进行离散傅立叶变换 (FFT)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62213330/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com