I am currently investigating how to efficiently build a convolutional neural network that operates on arrays with up to 5 or 6 dimensions.
I am aware that many of the tools used for convolutional neural networks do not really handle N-D convolutions, so I decided to try writing an implementation of helix convolution, in which an N-D convolution is treated as one large 1-D convolution (see Reference 1, http://sepwww.stanford.edu/public/docs/sep95/jon1/paper_html/node2.html , and Reference 2, https://sites.ualberta.ca/~mostafan/Files/Papers/md_convolution_TLE2009.pdf , for more details of the concept).
I did this under the (possibly incorrect) assumptions that one large 1-D convolution is likely to be easier on a GPU than a multidimensional one, and that the method scales trivially to N dimensions.
In particular, Reference 2 states:
We have not found important gains in computational efficiency between N-D standard convolution versus using the algorithm described in the text. We have, however, found that writing codes for seismic data regularization with the described trick leads to algorithms that can easily handle regularization problems with any number of spatial dimensions (Naghizadeh and Sacchi, 2009).
I have written an implementation of the function below and compare it against `signal.fftconvolve`. It is slower than that function on the CPU, but I would nonetheless like to see how it performs on the GPU in PyTorch as a forward convolutional layer.
Can someone kindly help me port this code to PyTorch so I can verify how it behaves?
"""
HELIX CONVOLUTION FUNCTION
Shrink:
CROPS THE CONVOLVED SIGNAL DOWN TO THE SIZE OF THE ORIGINAL INPUT ARRAY, KEEPING THE CENTRE.
Pad:
PADS THE DIFFERENCE BETWEEN THE ORIGINAL SHAPE AND THE DESIRED, CONVOLVED SHAPE FOR KERNEL AND SIGNAL.
GetLength:
EXTRACTS THE LENGTH OF THE UNWOUND STRIP OF THE SIGNAL AND KERNEL THAT IS TO BE CONVOLVED.
FFTConvolve:
USES THE NUMPY FFT PACKAGE TO PERFORM FFT-BASED LINEAR CONVOLUTION OF THE FLATTENED 1D SIGNALS.
Convolve:
USES HELIX CONVOLUTION ON AN INPUT ARRAY AND KERNEL.
"""
import numpy as np
from numpy import *
from scipy import signal
import operator
import time
class HelixCPU:
    """N-D convolution computed as one long 1-D ("helix") convolution on the CPU.

    Both operands are zero-padded to the full linear-convolution shape
    ``m + n - 1``, flattened in C order, trimmed of trailing padding,
    convolved in 1-D with an FFT, and reshaped back to the padded N-D shape.

    Every axis is padded to the full output length, so the index sum of two
    elements never "carries" into the next axis. The 1-D result is therefore
    exactly the flattened N-D result.
    """

    @classmethod
    def Shrink(cls, array, bounding):
        """Return the central ``bounding``-shaped slice of ``array``.

        On each axis the slice starts at ``(array_len - bound_len) // 2``.
        This is the same centring rule ``scipy.signal.fftconvolve(..., "same")``
        uses.
        """
        start = tuple(map(lambda a, da: (a - da) // 2, array.shape, bounding))
        end = tuple(map(operator.add, start, bounding))
        slices = tuple(map(slice, start, end))
        return array[slices]

    @classmethod
    def Pad(cls, array, target_shape):
        """Zero-pad ``array`` at the end of every axis up to ``target_shape``.

        ``target_shape`` may be any sequence of ints (tuple, list or ndarray)
        with one entry per axis of ``array``; each entry must be >= the
        corresponding dimension of ``array``.
        """
        # np.subtract accepts tuples as well as arrays (plain ``-`` on two
        # tuples would raise TypeError).
        diff = np.subtract(target_shape, array.shape)
        padder = [(0, int(val)) for val in diff]
        return np.pad(array, padder, 'constant')

    @classmethod
    def GetLength(cls, array_shape, padded_shape):
        """Return the length of the flattened prefix that contains the data.

        An ``array_shape`` block sits at the origin of a zero-padded array of
        ``padded_shape``. Its last element ``(s0-1, s1-1, ...)`` is located at
        C-order flat index ``ravel_multi_index(shape - 1, padded_shape)``, and
        everything after it in the flattened buffer is padding. The returned
        length is that flat index plus one.
        """
        last_index = tuple(int(s) - 1 for s in array_shape)
        dims = tuple(int(d) for d in padded_shape)
        return int(np.ravel_multi_index(last_index, dims)) + 1

    @classmethod
    def FFTConvolve(cls, in1, in2, len1=None, len2=None):
        """Full 1-D linear convolution of ``in1`` and ``in2`` computed with the FFT.

        Parameters
        ----------
        in1, in2 : 1-D array_like
            Signals to convolve.
        len1, len2 : int, optional
            Lengths used to size the output (``len1 + len2 - 1`` samples).
            They default to the sizes of ``in1`` and ``in2``.

        Returns
        -------
        ndarray
            The first ``len1 + len2 - 1`` samples of the convolution. The
            result is real when both inputs are real, complex otherwise.

        Raises
        ------
        ValueError
            If the requested output length is less than one.
        """
        in1 = np.asarray(in1)
        in2 = np.asarray(in2)
        s1 = in1.size if len1 is None else int(len1)
        s2 = in2.size if len2 is None else int(len2)
        shape = s1 + s2 - 1
        if shape < 1:
            raise ValueError("FFTConvolve needs non-empty inputs")
        # Smallest power of two >= shape, using exact integer arithmetic
        # (no floating-point log2 rounding).
        fsize = 1 << (shape - 1).bit_length()
        if np.iscomplexobj(in1) or np.iscomplexobj(in2):
            spectrum = np.fft.fft(in1, fsize) * np.fft.fft(in2, fsize)
            conv = np.fft.ifft(spectrum)
        else:
            # Real inputs: half-spectrum transforms are cheaper and give a real
            # result, matching the dtype scipy.signal.fftconvolve returns.
            spectrum = np.fft.rfft(in1, fsize) * np.fft.rfft(in2, fsize)
            conv = np.fft.irfft(spectrum, fsize)
        # Copy so the returned array does not keep the whole FFT buffer alive.
        return conv[:shape].copy()

    @classmethod
    def Convolve(cls, array, kernel):
        """Convolve ``array`` with ``kernel`` using helix (1-D) convolution.

        Returns an array with the shape of ``array``: the centre of the full
        N-D linear convolution. This is equivalent to
        ``scipy.signal.fftconvolve(array, kernel, "same")``.

        Raises
        ------
        ValueError
            If the operands are 0-D, differ in dimensionality, or are empty.
        """
        array = np.asarray(array)
        kernel = np.asarray(kernel)
        if array.ndim != kernel.ndim:
            raise ValueError("array and kernel must have the same number of dimensions")
        if array.ndim == 0:
            raise ValueError("array and kernel must be at least 1-D")
        if array.size == 0 or kernel.size == 0:
            raise ValueError("array and kernel must be non-empty")
        m = array.shape
        n = kernel.shape
        # Full linear-convolution extent on every axis.
        mn = tuple(int(a + b - 1) for a, b in zip(m, n))
        k_pad = cls.Pad(kernel, mn)
        a_pad = cls.Pad(array, mn)
        # Drop the trailing padding after each operand's last element. The 1-D
        # convolution of the trimmed strips then has exactly prod(mn) samples.
        length_k = cls.GetLength(n, mn)
        length_a = cls.GetLength(m, mn)
        k_flat = k_pad.ravel()[:length_k]
        a_flat = a_pad.ravel()[:length_a]
        conv = cls.FFTConvolve(a_flat, k_flat, length_a, length_k)
        # reshape rather than np.resize: the size is exact, and any mismatch
        # should fail loudly instead of silently repeating data.
        conv = conv.reshape(mn)
        return cls.Shrink(conv, m)
def main(array_shape=(25, 25, 41, 51), kernel_shape=(10, 10, 10, 10)):
    """Benchmark HelixCPU.Convolve against scipy.signal.fftconvolve and print a report.

    Parameters
    ----------
    array_shape, kernel_shape : tuple of int, optional
        Shapes of the uniformly random input array and kernel. The defaults
        reproduce the original 4-D benchmark; pass smaller shapes for a quick
        run.
    """
    array = np.random.rand(*array_shape)
    kernel = np.random.rand(*kernel_shape)

    start2 = time.process_time()
    test2 = HelixCPU.Convolve(array, kernel)
    end2 = time.process_time()

    start1 = time.process_time()
    test1 = signal.fftconvolve(array, kernel, "same")
    end1 = time.process_time()

    time_original = end1 - start1
    time_helix = end2 - start2
    # process_time() can report 0 for very fast runs; avoid ZeroDivisionError.
    ratio = time_helix / time_original if time_original > 0 else float("inf")

    print("")
    print("========================")
    print("SOME LARGE CONVOLVED RANDOM ARRAYS. ")
    print("========================")
    print("")
    print("Random Calorimeter Image of Size {0} Created".format(array.shape))
    print("Random Kernel of Size {0} Created".format(kernel.shape))
    print("")
    # The timing row carries a third (Helix / Original) column, so label it.
    print("Value\tOriginal\tHelix\tRatio")
    print("Time Taken [s]\t{0}\t{1}\t{2}".format(time_original, time_helix, ratio))
    print("Maximum Value\t{:03.2f}\t{:13.2f}".format(np.max(test1), np.max(test2)))
    print("Matrix Norm \t{:03.2f}\t{:13.2f}".format(np.linalg.norm(test1), np.linalg.norm(test2)))
    print("All Close?\t{0}".format(np.allclose(test1, test2)))