Python torch.rfft-基于fft的卷积产生不同于空间卷积的输出

Python torch.rfft-基于fft的卷积产生不同于空间卷积的输出,python,image-processing,pytorch,fft,convolution,Python,Image Processing,Pytorch,Fft,Convolution,我在Pytorch中实现了基于FFT的卷积,并通过conv2d()函数将结果与空间卷积进行了比较。使用的卷积滤波器是平均滤波器。conv2d()函数根据预期的平均滤波产生平滑输出,但基于fft的卷积返回更模糊的输出。 我在这里附上了代码和输出- 空间卷积- from PIL import Image, ImageOps import torch from matplotlib import pyplot as plt from torchvision.transforms import ToTe

我在Pytorch中实现了基于FFT的卷积,并通过conv2d()函数将结果与空间卷积进行了比较。使用的卷积滤波器是平均滤波器。conv2d()函数根据预期的平均滤波产生平滑输出,但基于fft的卷积返回更模糊的输出。 我在这里附上了代码和输出-

空间卷积-

from PIL import Image, ImageOps
import torch
from matplotlib import pyplot as plt
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import numpy as np

im = Image.open("/kaggle/input/tiger.jpg")
im = im.resize((256,256))
gray_im = im.convert('L') 
gray_im = ToTensor()(gray_im)
gray_im = gray_im.squeeze()

fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]])

conv_gray_im = gray_im.unsqueeze(0).unsqueeze(0)
conv_fil = fil.unsqueeze(0).unsqueeze(0)

conv_op = F.conv2d(conv_gray_im,conv_fil)

conv_op = conv_op.squeeze()

plt.figure()
plt.imshow(conv_op, cmap='gray')
def fftshift(image):
    sh = image.shape
    x = np.arange(0, sh[2], 1)
    y = np.arange(0, sh[3], 1)
    xm, ym  = np.meshgrid(x,y)
    shifter = (-1)**(xm + ym)
    shifter = torch.from_numpy(shifter)
    return image*shifter

shift_im = fftshift(conv_gray_im)
padded_fil = F.pad(conv_fil, (0, gray_im.shape[0]-fil.shape[0], 0, gray_im.shape[1]-fil.shape[1]))
shift_fil = fftshift(padded_fil)
fft_shift_im = torch.rfft(shift_im, 2, onesided=False)
fft_shift_fil = torch.rfft(shift_fil, 2, onesided=False)
shift_prod = fft_shift_im*fft_shift_fil
shift_fft_conv = fftshift(torch.irfft(shift_prod, 2, onesided=False))

fft_op = shift_fft_conv.squeeze()
plt.figure('shifted fft')
plt.imshow(fft_op, cmap='gray')
基于FFT的卷积-

from PIL import Image, ImageOps
import torch
from matplotlib import pyplot as plt
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import numpy as np

im = Image.open("/kaggle/input/tiger.jpg")
im = im.resize((256,256))
gray_im = im.convert('L') 
gray_im = ToTensor()(gray_im)
gray_im = gray_im.squeeze()

fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]])

conv_gray_im = gray_im.unsqueeze(0).unsqueeze(0)
conv_fil = fil.unsqueeze(0).unsqueeze(0)

conv_op = F.conv2d(conv_gray_im,conv_fil)

conv_op = conv_op.squeeze()

plt.figure()
plt.imshow(conv_op, cmap='gray')
def fftshift(image):
    sh = image.shape
    x = np.arange(0, sh[2], 1)
    y = np.arange(0, sh[3], 1)
    xm, ym  = np.meshgrid(x,y)
    shifter = (-1)**(xm + ym)
    shifter = torch.from_numpy(shifter)
    return image*shifter

shift_im = fftshift(conv_gray_im)
padded_fil = F.pad(conv_fil, (0, gray_im.shape[0]-fil.shape[0], 0, gray_im.shape[1]-fil.shape[1]))
shift_fil = fftshift(padded_fil)
fft_shift_im = torch.rfft(shift_im, 2, onesided=False)
fft_shift_fil = torch.rfft(shift_fil, 2, onesided=False)
shift_prod = fft_shift_im*fft_shift_fil
shift_fft_conv = fftshift(torch.irfft(shift_prod, 2, onesided=False))

fft_op = shift_fft_conv.squeeze()
plt.figure('shifted fft')
plt.imshow(fft_op, cmap='gray')
原始图像-

空间卷积输出-

基于fft的卷积输出-


有人能解释一下这个问题吗?

你的代码的主要问题是Torch不做复数,它的FFT输出是一个3D数组,第三维有两个值,一个用于实部,一个用于虚部。因此,乘法不会进行复数乘法

目前Torch中没有定义复数乘法(请参阅),我们必须定义自己的乘法


如果要比较这两个卷积运算,下面是一个小问题,但也很重要:

FFT在第一个元素(图像的左上像素)中获取其输入的原点。为了避免输出移位,您需要生成一个填充内核,其中内核的原点是左上角的像素。事实上,这很棘手

您当前的代码:

fil=火炬张量([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9])
conv_fil=fil.unsqueze(0)。unsqueze(0)
填充填充=F.pad(转换填充,(0,灰色填充形状[0]-填充形状[0],0,灰色填充形状[1]-填充形状[1]))
生成一个填充内核,其中原点是像素(1,1),而不是(0,0)。它需要在每个方向上移动一个像素。NumPy有一个功能
roll
,对这一点很有用,我不知道火炬的等价物(我对火炬一点都不熟悉)。这应该起作用:

fil=火炬张量([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9])
padded_fil=fil.unsqueze(0).unsqueze(0).numpy()
填充薄膜=np.pad(填充薄膜,((0,灰色薄膜形状[0]-薄膜形状[0]),(0,灰色薄膜形状[1]-薄膜形状[1]))
填充薄膜=np卷(填充薄膜,-1,轴=(0,1))
填充填充=火炬。从填充填充(填充填充)

<> P>最后,应用到空间域图像中的<<代码> fftSHIFT 函数,使频域图像(应用于图像的FFT的结果)移位,使得原点位于图像的中间,而不是左上角。当查看FFT的输出时,该移位是有用的,但在计算卷积时,该移位是无意义的


把这些东西放在一起,卷积现在是:

def复数乘法(t1,t2):
real1,imag1=t1[:,:,0],t1[:,:,1]
real2,imag2=t2[:,:,0],t2[:,:,1]
返回火炬堆栈([real1*real2-imag1*imag2,real1*imag2+imag1*real2],dim=-1)
fft\u im=火炬.rfft(灰色\u im,2,单面=假)
fft_fil=火炬.rfft(填充_fil,2,单侧=假)
fft\u conv=torch.irfft(复数乘法(fft\u im,fft\u fil),2,单边=False)
请注意,您可以执行单边FFT以节省一点计算时间:

fft\u im=torch.rfft(灰色\u im,2,单面=真)
fft_fil=火炬.rfft(填充_fil,2,单侧=真)
fft\u conv=torch.irfft(复数乘法(fft\u im,fft\u fil),2,单边=True,信号大小=gray\u im.shape)

这里的频域大约是完整FFT的一半大小,但只剩下多余的部分。卷积的结果不变。

如何生成填充的fil?请看!哦,对不起,没听到那句话。我已经更新了代码。谢谢你的回答,但是我打印了填充内核,它的第一个值是(0,0),所以在执行np.roll之后,它将一些值移动到图像的最后一列和最后一行。所以我认为内核没有任何问题。使用这个左滚动内核的代码的fft输出是一些半倒置和半直立的混合图像。此外,我首先在没有fftshift函数的情况下进行了基于fft的卷积运算,它给出了与问题中所示相同的额外模糊输出,但却是一个反向输出(180度)。所以我做了fftshift部分,至少得到了一个垂直的输出。@psj:内核的原点是它的中心,如果你定义它,否则你会看到一个偏移的输出。将内核的中心置于(0,0)会导致内核的一部分(在本例中为1个像素)出现在图像的右端和底端。对于FFT,图像是周期性的。@psj:好的,我安装了Torch来了解发生了什么。事实证明,Torch不理解复数,这使得提供FFT变得毫无意义。我已经用工作代码更新了这个答案,但是如果有一个更好的工具,所有这些都会容易得多。直接使用NumPy,或者任何真正的图像处理软件包。谢谢,这很有效!IFFT输出图像现在看起来与conv输出类似。但是,当我打印两个矩阵时——conv_op和fft_conv(我尝试了两次裁剪fft_conv,以获得等于“有效”卷积的输出——一次从中间,一次从左上角),它们似乎并不相等——甚至不在一个小的误差范围内。这是两种方法之间的近似误差吗?另外,我应该从IFFT输出中选择哪种裁剪-我尝试了中间和左上角,但不知道逻辑上应该选择哪一种。@psj:执行的计算非常不同,因此结果在数值上会有所不同。这些差异将向图像边缘增加,其中FFT卷积与空域卷积有所不同。您应该显示两个图像之间的差异
plt.imshow(conv\u op-fft\u conv)
。这应该显示没有原始图像任何细节的图像,只是(结构化)噪声。