gpt4 book ai didi

python - 如何在PyTorch中正确使用Numpy的FFT函数?

转载 作者:行者123 更新时间:2023-11-30 09:50:34 25 4
gpt4 key购买 nike

我最近接触了 PyTorch,并开始浏览该库的文档和教程。在 "Creating extensions using numpy and scipy"教程中的“无参数示例”下,使用 numpy 创建了一个名为 BadFFTFunction 的示例函数。

该函数的描述指出:

"This layer doesn’t particularly do anything useful or mathematicallycorrect.

It is aptly named BadFFTFunction"

该函数及其用法如下:

from numpy.fft import rfft2, irfft2

class BadFFTFunction(Function):

def forward(self, input):
numpy_input = input.numpy()
result = abs(rfft2(numpy_input))
return torch.FloatTensor(result)

def backward(self, grad_output):
numpy_go = grad_output.numpy()
result = irfft2(numpy_go)
return torch.FloatTensor(result)

def incorrect_fft(input):
return BadFFTFunction()(input)

input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
result.backward(torch.randn(result.size()))
print(input.grad)

不幸的是,我最近才接触到信号处理,并且不确定这个函数中的(可能是明显的)错误在哪里。

我想知道,如何修复这个函数,使其前向和后向输出正确?

如何修复 BadFFTFunction 以便在 PyTorch 中使用可微分的 FFT 函数?

最佳答案

我认为错误是:首先,该函数尽管名称中有 FFT,但仅返回 FFT 输出的幅度/绝对值,而不是完整的复系数。另外,仅使用逆 FFT 来计算幅度的梯度在数学上可能没有多大意义(?)。

有一个名为 pytorch-fft 的包试图在 pytorch 中提供 FFT 函数。您可以看到一些 autograd 功能的实验代码 here 。另请注意此issue中的讨论.

关于python - 如何在PyTorch中正确使用Numpy的FFT函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45746504/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com