I implemented FFT-based convolution in PyTorch and compared the result with spatial convolution computed by the conv2d() function. The convolution filter is a 3x3 average filter. As expected, conv2d() produced a smoothed output, but the FFT-based convolution returned a noticeably blurrier output. The code and outputs are attached below -
spatial convolution -
from PIL import Image, ImageOps
import torch
from matplotlib import pyplot as plt
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import numpy as np
im = Image.open("/kaggle/input/tiger.jpg")
im = im.resize((256,256))
gray_im = im.convert('L')
gray_im = ToTensor()(gray_im)
gray_im = gray_im.squeeze()
fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]])
conv_gray_im = gray_im.unsqueeze(0).unsqueeze(0)
conv_fil = fil.unsqueeze(0).unsqueeze(0)
conv_op = F.conv2d(conv_gray_im,conv_fil)
conv_op = conv_op.squeeze()
plt.figure()
plt.imshow(conv_op, cmap='gray')
FFT-based convolution -
def fftshift(image):
    sh = image.shape
    x = np.arange(0, sh[2], 1)
    y = np.arange(0, sh[3], 1)
    xm, ym = np.meshgrid(x,y)
    shifter = (-1)**(xm + ym)
    shifter = torch.from_numpy(shifter)
    return image*shifter
shift_im = fftshift(conv_gray_im)
padded_fil = F.pad(conv_fil, (0, gray_im.shape[0]-fil.shape[0], 0, gray_im.shape[1]-fil.shape[1]))
shift_fil = fftshift(padded_fil)
fft_shift_im = torch.rfft(shift_im, 2, onesided=False)
fft_shift_fil = torch.rfft(shift_fil, 2, onesided=False)
shift_prod = fft_shift_im*fft_shift_fil
shift_fft_conv = fftshift(torch.irfft(shift_prod, 2, onesided=False))
fft_op = shift_fft_conv.squeeze()
plt.figure('shifted fft')
plt.imshow(fft_op, cmap='gray')
original image -
spatial convolution output -
fft-based convolution output -
Could someone kindly explain the issue?



padded_fil? Please see minimal reproducible example! - Cris Luengo
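
A possible explanation, offered as a sketch rather than a confirmed diagnosis: in the PyTorch versions that still provide torch.rfft, the result is a real tensor whose last dimension holds the (real, imaginary) pair, so `fft_shift_im*fft_shift_fil` multiplies real by real and imaginary by imaginary instead of forming a complex product. The NumPy snippet below emulates that layout and compares both products against a circular convolution computed directly. The helper names `as_real_pairs` and `from_real_pairs`, the 8x8 random stand-in image, and all variable names are assumptions made for illustration; none of them come from the code in the question.

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((8, 8))          # hypothetical stand-in for the grayscale image
kernel = np.zeros((8, 8))
kernel[:3, :3] = 1.0 / 9.0          # 3x3 average filter, zero-padded at the top-left

F_im = np.fft.fft2(image)
F_k = np.fft.fft2(kernel)

def as_real_pairs(z):
    """Emulate the legacy layout: trailing axis holds (real, imag)."""
    return np.stack([z.real, z.imag], axis=-1)

def from_real_pairs(p):
    """Turn a (..., 2) real array back into a complex array."""
    return p[..., 0] + 1j * p[..., 1]

# What `a * b` does on the (real, imag) layout: real*real and imag*imag.
naive = from_real_pairs(as_real_pairs(F_im) * as_real_pairs(F_k))

# Complex product: (a + ib)(c + id) = (ac - bd) + i(ad + bc).
a, b = F_im.real, F_im.imag
c, d = F_k.real, F_k.imag
correct = (a * c - b * d) + 1j * (a * d + b * c)

naive_out = np.fft.ifft2(naive).real
correct_out = np.fft.ifft2(correct).real

# Reference: circular convolution computed directly in the spatial domain.
reference = np.zeros_like(image)
for dy in range(3):
    for dx in range(3):
        reference += np.roll(image, shift=(dy, dx), axis=(0, 1)) / 9.0

# "Valid" cross-correlation, i.e. what conv2d without padding computes.
valid = np.zeros((6, 6))
for m in range(3):
    for n in range(3):
        valid += image[m:m + 6, n:n + 6] / 9.0

print(np.allclose(correct_out, reference))       # complex product vs. reference
print(np.allclose(naive_out, reference))         # elementwise product vs. reference
print(np.allclose(correct_out[2:, 2:], valid))   # cropped FFT result vs. valid output
```

If the elementwise product turns out to be the culprit here as well, replacing it with the explicit (ac - bd, ad + bc) computation on the two channels would be the step to try first. The last check is a reminder that conv2d() without padding returns a smaller "valid" output, which lines up with the circular FFT result only after cropping the first two rows and columns when a 3x3 kernel is zero-padded at the top-left corner.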