I'm attempting to implement a Fourier Convolutional Neural Network using tf.keras, where the input and kernel are transformed to the frequency domain, element-wise multiplication is performed, and then the output is inverse-transformed and cropped. The model summary shows there are no trainable parameters for the kernel in my FConv2D layer, even though I declare them using self.add_weight
. There should be (3*3*in_channels*no_of_kernels) parameters.
class FConv2D(tf.keras.layers.Layer):
def __init__(self, no_of_kernels, **kwargs):
self.no_of_kernels = no_of_kernels
super(FConv2D, self).__init__(**kwargs)
def build(self, input_shape):
self.kernel_shape = (3, 3 , int(input_shape[3]), self.no_of_kernels)
print(input_shape, self.kernel_shape)
self.kernel = self.add_weight(shape=(3,3, input_shape[-1], self.no_of_kernels),
initializer='random_normal',
trainable=True)
self.bias = self.add_weight(shape=(self.no_of_kernels,),
initializer='random_normal',
trainable=True)
super(FConv2D, self).build(input_shape)
def call(self, x):
crop_size = self.kernel.get_shape().as_list()[0] // 2
shape = x.get_shape().as_list()[1] + self.kernel.get_shape().as_list()[0] - 1
x = tf.transpose(x, perm=[0,3,1,2])
self.kernel = tf.transpose(self.kernel, perm=[3,2,0,1])
x = tf.signal.rfft2d(x, [shape, shape])
self.kernel = tf.signal.rfft2d(self.kernel, [shape, shape])
x = tf.einsum('imkl,jmkl->ijkl', x, self.kernel)
x = tf.signal.irfft2d(x, [shape, shape])
x = tf.transpose(x, perm=[0,2,3,1])
x = tf.nn.bias_add(x, self.bias)[:,crop_size:-1*crop_size, crop_size:-1*crop_size, :]
x = tf.nn.elu(x)
return x
When I build the model, it shows trainable parameters only for the bias term, not for the kernels.
m = tf.keras.models.Sequential()
m.add(FConv2D(32, input_shape=(32,32,3)))
m.summary()
Output:
Model: "sequential_37"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
f_conv2d_88 (FConv2D) (None, 32, 32, 32) 32
=================================================================
Total params: 32
Trainable params: 32
Non-trainable params: 0
The problem seems to be in call(self, x)
, because if I replace the Fourier convolution operation with a call to tf.nn.conv2d
, the expected number of parameters is listed (3*3*3*32+32=896).
I've confirmed that the parameters are not trainable by eliminating the bias term, and calling model.fit
, which does not run because there are no parameters to train.
What am I missing? Is Keras not able to have these complex operations inside a custom layer?