用pytorch验证卷积定理

用pytorch验证卷积定理,pytorch,fft,convolution,theorem-proving,Pytorch,Fft,Convolution,Theorem Proving,基本上,这个定理的公式如下: F(F*g)=F(F)xF(g) 我知道这个定理,但我就是不能用pytorch来重现这个结果 以下是可复制的代码: import torch import torch.nn.functional as F # calculate f*g f = torch.ones((1,1,5,5)) g = torch.tensor(list(range(9))).view(1,1,3,3).float() conv = F.conv2d(f, g, bias=None, p

基本上,这个定理的公式如下:

F(F*g)=F(F)xF(g)

我知道这个定理,但我就是不能用pytorch来重现这个结果

以下是可复制的代码:

import torch
import torch.nn.functional as F

# calculate f*g
f = torch.ones((1,1,5,5))
g = torch.tensor(list(range(9))).view(1,1,3,3).float()
conv = F.conv2d(f, g, bias=None, padding=2)

# calculate F(f*g)
F_fg = torch.rfft(conv, signal_ndim=2, onesided=False)

# calculate F x G
f = f.squeeze()
g = g.squeeze()

# need to pad into at least [w1+w2-1, h1+h2-1], which is 7 in our case.
size = f.size(0) + g.size(0) - 1 

f_new = torch.zeros((7,7))
g_new = torch.zeros((7,7))

f_new[1:6,1:6] = f
g_new[2:5,2:5] = g

F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
FxG = torch.mul(F_f, F_g)

print(FxG - F_fg)
以下是打印的结果(FxG-F_fg)

你可以看到,差异并不总是0

有人能告诉我为什么以及如何正确地做到这一点吗


谢谢

所以我仔细看了一下你到目前为止所做的工作。我已经确定了代码中的三个错误源。在这里,我将尽力充分阐述每一个问题

1.复数运算 PyTorch目前不支持复数乘法(AFAIK)。FFT运算只是返回一个实维数和虚维数的张量。我们需要显式地对复数乘法进行编码,而不是使用
torch.mul
*
运算符

(a+ib)*(c+id)=(a*c-b*d)+i(a*d+b*c)

2.卷积的定义 CNN文献中经常使用的“卷积”定义实际上与讨论卷积定理时使用的定义不同。我不会详细介绍,但是在滑动和乘法之前,内核会翻转。相反,pytorch、tensorflow、caffe等中的卷积运算。。。不做这种翻转

为了说明这一点,我们可以在应用FFT之前简单地翻转
g
(水平和垂直)

3.锚位 当使用卷积定理时,假设锚定点是填充的
g
的左上角。再说一次,我不会详细讨论这个问题,但它是如何计算出来的


第二点和第三点通过一个例子可能更容易理解。假设您使用了以下
g

[1 2 3]
[4 5 6]
[7 8 9]
而不是新的

[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 2 3 0 0]
[0 0 4 5 6 0 0]
[0 0 7 8 9 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
事实上应该是这样

[5 4 0 0 0 0 6]
[2 1 0 0 0 0 3]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[8 7 0 0 0 0 9]
我们垂直和水平翻转内核,然后应用循环移位,使内核的中心位于左上角


我最终重写了你的大部分代码,并对其进行了一些概括。最复杂的操作是正确定义
g_new
。我决定使用网格和模运算来同时翻转和移动索引。如果这里有什么对你没有意义,请留下评论,我会尽力澄清

导入火炬
导入torch.nn.功能为F
def conv2d_pyt(f,g):
断言len(f.size())==2
断言len(g.size())==2
f_new=f.unsqueze(0)。unsqueze(0)
g_new=g.unsqueze(0)。unsqueze(0)
pad_y=(g.size(0)-1)//2
pad_x=(g.size(1)-1)//2
fcg=F.conv2d(F_new,g_new,bias=None,padding=(pad_y,pad_x))
返回fcg[0,0,:,:]
def conv2d_fft(f,g):
断言len(f.size())==2
断言len(g.size())==2
#一般来说,输入不一定是奇数形状的,但会使生活更轻松
断言f.size(0)%2==1
断言f.size(1)%2==1
断言g.size(0)%2==1
断言g.size(1)%2==1
尺寸y=f.尺寸(0)+g.尺寸(0)-1
尺寸x=f.尺寸(1)+g.尺寸(1)-1
f_new=火炬零点((尺寸_y,尺寸_x))
g_new=火炬零点((尺寸_y,尺寸_x))
#将f复制到中心
f_pad_y=(f_new.size(0)-f.size(0))//2
f_pad_x=(f_new.size(1)-f.size(1))//2
f_new[f_pad_y:-f_pad_y,f_pad_x:-f_pad_x]=f
#g的锚定为0,0(翻转g并包裹圆形)
g_中心_y=g.size(0)//2
g_中心_x=g.size(1)//2
g_y,g_x=火炬网网格(火炬网网格(g.size(0)),火炬网网格(g.size(1)))
g_new_y=(g_y.flip(0)-g_center_y)%g_new.size(0)
g_new_x=(g_x.flip(1)-g_center_x)%g_new.size(1)
g_new[g_new_y,g_new_x]=g[g_y,g_x]
#对f和g进行fft
F_F=torch.rfft(F_new,signal_ndim=2,单边=False)
F_g=torch.rfft(g_new,signal_ndim=2,单边=False)
#复数乘法
FxG_real=F_F[:,:,0]*F_g[:,:,0]-F_F[:,:,1]*F_g[:,:,1]
FxG_imag=F_F[:,:,0]*F_g[:,:,1]+F_F[:,:,1]*F_g[:,:,0]
FxG=torch.stack([FxG_real,FxG_imag],dim=2)
#逆fft
fcg=火炬。irfft(FxG,信号=2,单侧=假)
#返回前的作物中心
返回fcg[f_pad_y:-f_pad_y,f_pad_x:-f_pad_x]
#计算f*g
f=火炬。随机数(11,7)
g=火炬。随机数(5,3)
fcg_pyt=conv2d_pyt(f,g)
fcg_fft=conv2d_fft(f,g)
平均差值=火炬平均值(火炬绝对值(fcg_pyt-fcg_fft))。项目()
打印(“平均差异:”,平均差异)
这让我

Average difference: 4.6866085767760524e-07

这非常接近于零。我们得不到精确零的原因仅仅是由于浮点错误。

CNN文献中所谓的“卷积”实际上是信号处理行话中的相关滤波。基本上,内核在CNN中滑动和乘法之前不会翻转。尝试
F_g=torch.rfft(g_new.flip(0).flip(1),…
,这将使您更接近结果。由于DFT假设信号是周期性的(傅里叶变换离散化所必需的),因此可能也存在一些填充差异。我稍后将对此进行验证。什么是了解更多信息的好资源,尤其是内核的循环移位和零填充?@Kiran 1)信号在频率上是离散的,当且仅当其在时间上是周期性的(类似地,信号在时间上是离散的,当且仅当其在频率上是周期性的)。因此,从离散时间到离散频率的DFT假设信号在时间上是周期性的。这就解释了为什么所有的转换都是循环的。2) 锚定位置是由于将DFT的输入解释为从t=0开始的信号的单个周期的约定。你可以在任何DSP书籍中学习这一点,我最喜欢的是奥本海姆的《离散时间信号处理》。
Average difference: 4.6866085767760524e-07