I would like to create a custom tf.keras.layers.Layer resembling the below function:

import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Conv3D, ReLU

def conv_block(inputs, filters, kernel_size, strides=(1, 1, 1),
                 padding='valid', activation=True, block_name='conv3d'):

    with tf.name_scope(block_name):
      conv = Conv3D(filters=filters, kernel_size=kernel_size, strides=strides,
                    padding=padding, activation=None,
                    name='{}_conv'.format(block_name))(inputs)
      batch_norm = BatchNormalization(
          name='{}_batch_norm'.format(block_name))(conv)

      if activation:
        relu = ReLU(max_value=6, name='{}_relu'.format(block_name))(batch_norm)
        res_layer = relu
      else:
        res_layer = batch_norm
    return res_layer

I went through the documentation available here and here and subsequently I created the below class:

class ConvBlock(tf.keras.layers.Layer):

    def __init__(self, filters, kernel_size, strides=(1, 1, 1), padding='valid', activation=True, **kwargs):
        super(ConvBlock, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.activation = activation

        self.conv_1 = Conv3D(filters=self.filters, 
                             kernel_size=self.kernel_size, 
                             strides=self.strides, 
                             padding=self.padding, 
                             activation=None)

        self.batch_norm_1 = BatchNormalization()
        self.relu_1 = ReLU(max_value=6)

    def call(self, inputs):
        conv = self.conv_1(inputs)
        batch_norm = self.batch_norm_1(conv)

        if self.activation:
            relu = self.relu_1(batch_norm)
            return relu
        else:
            return batch_norm

I want to use this Layer several times throughout my model. I have several questions around this:

  1. The documentation mentions using add_weight() in the build() method. However, would that be necessary in this case?
  2. Do I need to include a build() method at all?
  3. How do I get the output shape of the layer? The documentation mentions using the below function:

    def compute_output_shape(self, input_shape):
        shape = tf.TensorShape(input_shape).as_list()
        shape[-1] = self.output_dim
        return tf.TensorShape(shape)

How can I use this function to compute the shape of the output layer?
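
For reference, here is a rough sketch of the arithmetic such a shape function would have to perform for this block. It is an illustration under stated assumptions, not the Keras implementation: the helper names `conv_output_length` and `conv_block_output_shape` are hypothetical, the input is assumed to be channels_last `(batch, d1, d2, d3, channels)`, dilation is assumed to be 1, and `kernel_size` and `strides` are assumed to be 3-tuples. Batch normalization and ReLU leave the shape unchanged, so only the convolution matters.

```python
import math


def conv_output_length(input_length, kernel_size, stride, padding):
    """Output length along one spatial axis of a convolution (dilation 1)."""
    if input_length is None:  # unknown dimension stays unknown
        return None
    if padding == "same":
        return math.ceil(input_length / stride)
    if padding == "valid":
        return (input_length - kernel_size) // stride + 1
    raise ValueError("padding must be 'same' or 'valid'")


def conv_block_output_shape(input_shape, filters, kernel_size,
                            strides=(1, 1, 1), padding="valid"):
    """Shape produced by Conv3D -> BatchNormalization -> ReLU.

    Assumes channels_last input: (batch, d1, d2, d3, channels).
    """
    batch, *spatial, _ = input_shape
    out_spatial = [
        conv_output_length(n, k, s, padding)
        for n, k, s in zip(spatial, kernel_size, strides)
    ]
    # The channel axis becomes the number of filters.
    return (batch, *out_spatial, filters)


print(conv_block_output_shape((None, 16, 32, 32, 3), filters=8,
                              kernel_size=(3, 3, 3)))
# (None, 14, 30, 30, 8)
```

Inside a layer, the same logic would go into `compute_output_shape`, with `self.output_dim` from the documentation snippet replaced by the number of filters.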

1 Answer

Maybe you can simply wrap your repeated operations in a function instead of subclassing Layer. Subclassing is only the right approach if you need to work with the weights yourself, for example to create custom weights or control how they are initialized.

Example:

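from tensorflow.keras.layers import BatchNormalization, Conv2D

def simple_conv(x, filters, kernel_size):
    # Each call creates new layers, so weights are not shared between calls.
    x = Conv2D(filters, kernel_size)(x)
    x = BatchNormalization()(x)
    return x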
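
If you do want a reusable layer object, a minimal sketch along the lines of the class in the question might look like the following. It assumes TensorFlow 2.x with `tf.keras`. The `training` argument, the layer name `block_1` and the input shape are my own additions for illustration. The sublayers create their own weights the first time the block is called, so this sketch defines no build() method and never calls add_weight().

```python
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Conv3D, Layer, ReLU


class ConvBlock(Layer):
    """Conv3D -> BatchNormalization -> optional ReLU capped at 6."""

    def __init__(self, filters, kernel_size, strides=(1, 1, 1),
                 padding='valid', activation=True, **kwargs):
        # Forward name, dtype, etc. to the base Layer.
        super().__init__(**kwargs)
        self.activation = activation
        self.conv_1 = Conv3D(filters=filters, kernel_size=kernel_size,
                             strides=strides, padding=padding,
                             activation=None)
        self.batch_norm_1 = BatchNormalization()
        self.relu_1 = ReLU(max_value=6)

    def call(self, inputs, training=None):
        x = self.conv_1(inputs)
        # Batch norm behaves differently in training and inference.
        x = self.batch_norm_1(x, training=training)
        return self.relu_1(x) if self.activation else x


# Hypothetical input: batch of 2 volumes, 8x8x8 voxels, 1 channel.
block = ConvBlock(filters=4, kernel_size=3, name='block_1')
outputs = block(tf.zeros((2, 8, 8, 8, 1)))

# With 'valid' padding and kernel size 3, each spatial size 8 becomes 6,
# and the channel axis becomes the number of filters (4).
print(outputs.shape)
```

Calling the block on a tensor and reading `outputs.shape` is the simplest way to find the output shape. Each `ConvBlock(...)` instance owns its own weights, so create one instance per place in the model unless you deliberately want the weights shared.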