I have a problem or two with the input dimensions of a modified U-Net architecture. To save you time and make my results easier to understand and reproduce, I'll post the code and the output dimensions. The modified U-Net architecture is the MultiResUNet architecture from https://github.com/nibtehaz/MultiResUNet/blob/master/MultiResUNet.py and is based on this paper: https://arxiv.org/abs/1902.04049. Please don't be put off by the length of this code. You can simply copy-paste it, and it shouldn't take longer than 10 seconds to reproduce my results. You also don't need a dataset for this. Tested with TF v1.9 and Keras v2.20.

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, BatchNormalization, Activation, add
from tensorflow.keras.models import Model
from tensorflow.keras.activations import relu 

# 2D Convolutional layers
#
# Arguments:
#     x {keras layer} -- input layer
#     filters {int} -- number of filters
#     num_row {int} -- number of rows in filters
#     num_col {int} -- number of columns in filters
#
# Keyword Arguments:
#     padding {str} -- mode of padding (default: {'same'})
#     strides {tuple} -- stride of convolution operation (default: {(1, 1)})
#     activation {str} -- activation function (default: {'relu'})
#     name {str} -- name of the layer (default: {None})
#
# Returns:
#     [keras layer] -- [output layer]


def conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(1, 1), activation='relu', name=None):

    x = Conv2D(filters, (num_row, num_col), strides=strides, padding=padding, use_bias=False)(x)
    x = BatchNormalization(axis=3, scale=False)(x)
    # Skip the activation layer entirely when none is requested
    if activation is None:
        return x
    x = Activation(activation, name=name)(x)

    return x

# 2D Transposed Convolutional layers (with batch normalization)
#
# Arguments:
#     x {keras layer} -- input layer
#     filters {int} -- number of filters
#     num_row {int} -- number of rows in filters
#     num_col {int} -- number of columns in filters
#
# Keyword Arguments:
#     padding {str} -- mode of padding (default: {'same'})
#     strides {tuple} -- stride of convolution operation (default: {(2, 2)})
#     name {str} -- name of the layer (default: {None})
#
# Returns:
#     [keras layer] -- [output layer]

def trans_conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(2, 2), name=None): 

    x = Conv2DTranspose(filters, (num_row, num_col), strides=strides, padding=padding)(x)
    x = BatchNormalization(axis=3, scale=False)(x)

    return x

# MultiRes Block
#
# Arguments:
#     U {int} -- number of filters in a corresponding UNet stage
#     inp {keras layer} -- input layer
#
# Returns:
#     [keras layer] -- [output layer]

def MultiResBlock(U, inp, alpha = 1.67):

    W = alpha * U

    shortcut = inp

    shortcut = conv2d_bn(shortcut, int(W*0.167) + int(W*0.333) +
                         int(W*0.5), 1, 1, activation=None, padding='same')

    conv3x3 = conv2d_bn(inp, int(W*0.167), 3, 3,
                        activation='relu', padding='same')

    conv5x5 = conv2d_bn(conv3x3, int(W*0.333), 3, 3,
                        activation='relu', padding='same')

    conv7x7 = conv2d_bn(conv5x5, int(W*0.5), 3, 3,
                        activation='relu', padding='same')

    out = concatenate([conv3x3, conv5x5, conv7x7], axis=3)
    out = BatchNormalization(axis=3)(out)

    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    return out

# ResPath
#
# Arguments:
#     filters {int} -- [description]
#     length {int} -- length of ResPath
#     inp {keras layer} -- input layer
#
# Returns:
#     [keras layer] -- [output layer]



def ResPath(filters, length, inp):
    shortcut = inp
    shortcut = conv2d_bn(shortcut, filters, 1, 1,
                         activation=None, padding='same')

    out = conv2d_bn(inp, filters, 3, 3, activation='relu', padding='same')

    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    for i in range(length-1):

        shortcut = out
        shortcut = conv2d_bn(shortcut, filters, 1, 1,
                             activation=None, padding='same')

        out = conv2d_bn(out, filters, 3, 3, activation='relu', padding='same')

        out = add([shortcut, out])
        out = Activation('relu')(out)
        out = BatchNormalization(axis=3)(out)

    return out



# MultiResUNet
#
# Arguments:
#     height {int} -- height of image
#     width {int} -- width of image
#     n_channels {int} -- number of channels in image
#
# Returns:
#     [keras model] -- MultiResUNet model




def MultiResUnet(height, width, n_channels):



    inputs = Input((height, width, n_channels))

    # downsampling part begins here 

    mresblock1 = MultiResBlock(32, inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(mresblock1)
    mresblock1 = ResPath(32, 4, mresblock1)

    mresblock2 = MultiResBlock(32*2, pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(mresblock2)
    mresblock2 = ResPath(32*2, 3, mresblock2)

    mresblock3 = MultiResBlock(32*4, pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(mresblock3)
    mresblock3 = ResPath(32*4, 2, mresblock3)

    mresblock4 = MultiResBlock(32*8, pool3)


    # Upsampling part 

    up5 = concatenate([Conv2DTranspose(
        32*4, (2, 2), strides=(2, 2), padding='same')(mresblock4), mresblock3], axis=3)
    mresblock5 = MultiResBlock(32*8, up5)

    up6 = concatenate([Conv2DTranspose(
        32*4, (2, 2), strides=(2, 2), padding='same')(mresblock5), mresblock2], axis=3)
    mresblock6 = MultiResBlock(32*4, up6)

    up7 = concatenate([Conv2DTranspose(
        32*2, (2, 2), strides=(2, 2), padding='same')(mresblock6), mresblock1], axis=3)
    mresblock7 = MultiResBlock(32*2, up7)


    conv8 = conv2d_bn(mresblock7, 1, 1, 1, activation='sigmoid')

    model = Model(inputs=[inputs], outputs=[conv8])

    return model

Now back to my problem with mismatched dimensions in the U-Net architecture.

If I choose an input height/width of (128, 128), (256, 256) or (512, 512) and run:

 model = MultiResUnet(128, 128,3)
 display(model.summary()) 

TensorFlow prints a correct summary of the whole architecture. But if I do this:

     model = MultiResUnet(36, 36,3)
     display(model.summary()) 

I get this error :

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
in
----> 1 model = MultiResUnet(36, 36,3)
      2 display(model.summary())

in MultiResUnet(height, width, n_channels)
     25
     26     up5 = concatenate([Conv2DTranspose(
---> 27         32*4, (2, 2), strides=(2, 2), padding='same')(mresblock4), mresblock3], axis=3)
     28     mresblock5 = MultiResBlock(32*8, up5)
     29

~/miniconda3/envs/MastersThenv/lib/python3.6/site-packages/tensorflow/python/keras/layers/merge.py in concatenate(inputs, axis, **kwargs)
    682     A tensor, the concatenation of the inputs alongside axis axis.
    683   """
--> 684   return Concatenate(axis=axis, **kwargs)(inputs)
    685
    686

~/miniconda3/envs/MastersThenv/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    694           if all(hasattr(x, 'get_shape') for x in input_list):
    695             input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
--> 696           self.build(input_shapes)
    697
    698         # Check input assumptions set after layer building, e.g. input shape.

~/miniconda3/envs/MastersThenv/lib/python3.6/site-packages/tensorflow/python/keras/utils/tf_utils.py in wrapper(instance, input_shape)
    146     else:
    147       input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
--> 148     output_shape = fn(instance, input_shape)
    149     if output_shape is not None:
    150       if isinstance(output_shape, list):

~/miniconda3/envs/MastersThenv/lib/python3.6/site-packages/tensorflow/python/keras/layers/merge.py in build(self, input_shape)
    388                          'inputs with matching shapes '
    389                          'except for the concat axis. '
--> 390                          'Got inputs shapes: %s' % (input_shape))
    391
    392   def _merge_function(self, inputs):

ValueError: A Concatenate layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 8, 8, 128), (None, 9, 9, 128)]

Why does the Conv2DTranspose give me the wrong dimension

(None, 8, 8, 128)

instead of

(None, 9, 9, 128)

and why doesn't the concatenate function complain when I choose input sizes like (128, 128), (256, 256), etc. (multiples of 32)? To generalize the question: how can I make this U-Net architecture work with any input size, and how can I deal with the Conv2DTranspose layer producing an output whose width/height is one less than the dimension actually needed (when the input size isn't a multiple of 32 or is not symmetric)? Why doesn't this happen with input sizes that are multiples of 32? And what if I had variable input sizes?

Any help would be highly appreciated.

cheers, H

2 Answers

4
votes

The U-Net family of models (such as the MultiResUNet model above) follows an encoder-decoder architecture. The encoder is a downsampling path that performs feature extraction, whereas the decoder is an upsampling path. Feature maps from the encoder are concatenated in the decoder through skip connections. These feature maps are concatenated along the last axis, the 'channel' axis (taking the features to have dimensions [batch_size, height, width, channels]). For features to be concatenated along any axis (the 'channel' axis, in our case), the dimensions along all the other axes must match.

In the above model architecture, 3 downsampling/max-pooling operations are performed (through MaxPooling2D) in the encoder path. In the decoder path, 3 upsampling/transpose-conv operations are performed, aiming to restore the image to its full dimensions. However, for the concatenations (through skip connections) to happen, the height, width and batch_size of the downsampled and upsampled features must be identical at every "level" of the model. I'll illustrate this with the examples you mentioned in the question:

1st case : Input dimensions (128,128,3) : 128 -> 64 -> 32 -> 16 -> 32 -> 64 -> 128

2nd case: Input dimensions (36,36,3) : 36 -> 18 -> 9 -> 4 -> 8 -> 16 -> 32

In the 2nd case, when the height and width of the feature map reach 9 in the encoder path, further downsampling rounds 9/2 down to 4, and the lost row/column cannot be regained in the decoder: upsampling 4 by a factor of 2 gives 8, not 9. Hence, an error is thrown because feature maps of dimensions (None, 8, 8, 128) and (None, 9, 9, 128) cannot be concatenated.
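The two cases above can be traced with a few lines of plain Python. This is only a sketch of the shape arithmetic, not of the model itself: it assumes each MaxPooling2D with pool_size=(2, 2) and default padding halves the size with rounding down, and each Conv2DTranspose with strides=(2, 2) and padding='same' doubles it. The function name trace_sizes is made up for this illustration.

```python
def trace_sizes(size, n_levels):
    """Return the encoder sizes and the (upsampled, skip) size pairs per decoder level."""
    # Encoder: each 2x2 max-pooling step halves the size, rounding down.
    encoder = [size]
    for _ in range(n_levels):
        encoder.append(encoder[-1] // 2)

    # Decoder: each stride-2 transposed convolution doubles the size.
    # Each upsampled size is paired with the encoder size it would be concatenated with.
    decoder = []
    current = encoder[-1]
    for level in range(n_levels - 1, -1, -1):
        current *= 2
        decoder.append((current, encoder[level]))
    return encoder, decoder


for size in (128, 36):
    encoder, decoder = trace_sizes(size, n_levels=3)
    print(size, "encoder:", encoder, "decoder (upsampled, skip):", decoder)
# 128 encoder: [128, 64, 32, 16] decoder (upsampled, skip): [(32, 32), (64, 64), (128, 128)]
# 36 encoder: [36, 18, 9, 4] decoder (upsampled, skip): [(8, 9), (16, 18), (32, 36)]
```

For 36, the first pair (8, 9) is exactly the mismatch reported in the ValueError. The real model stops at that first concatenation, so the later pairs are only shown for completeness.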

In general, for a simple encoder-decoder model (with skip-connections) having 'n' downsampling (MaxPooling2D) layers, the input dimension must be a multiple of 2^n to be able to concatenate the model's encoder features at the decoder. In this case n=3, hence the input must be a multiple of 8 to not run into these dimension mismatch errors.
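The multiple-of-2^n rule can be checked by brute force. The sketch below uses the same simplifying assumptions about the layers (pooling halves with rounding down, transposed convolution doubles); skips_match is a hypothetical helper name, not part of Keras.

```python
def skips_match(size, n_levels):
    """True if every upsampled feature map has the same size as its skip connection."""
    # Encoder size at level k after k poolings (repeated floor halving equals size // 2**k).
    encoder = [size // 2 ** k for k in range(n_levels + 1)]
    bottom = encoder[-1]
    # Going back up from the bottom, level k is reached after (n_levels - k) doublings.
    return all(bottom * 2 ** (n_levels - k) == encoder[k] for k in range(n_levels))


n = 3
valid = [s for s in range(1, 65) if skips_match(s, n)]
print(valid)  # [8, 16, 24, 32, 40, 48, 56, 64]

# The sizes that pass are exactly the multiples of 2**n in that range.
assert valid == [s for s in range(1, 65) if s % 2 ** n == 0]
```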

Hope this helps! :)

0
votes

Thanks @Balraj Ashwath for the great answer! If your input has size h along a spatial dimension and you want to use this architecture with depth d (h >= 2^d), one possibility is to pad that dimension with delta_h zeros, where delta_h is given by the following expression:

import numpy as np

h, d = 36, 3
delta_h = np.ceil(h/(2**d)) * (2**d) - h
print(delta_h)
> 4.0

Then, following the example of @Balraj Ashwath:

40 -> 20 -> 10 -> 5 -> 10 -> 20 -> 40
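A minimal sketch of how the padding could be applied to an image array and undone afterwards, assuming a single (height, width, channels) NumPy array and zero padding split evenly between both sides. The helper names pad_to_multiple and crop_to_original are hypothetical, and where to place the padding is a design choice rather than something the answers above prescribe.

```python
import numpy as np


def pad_to_multiple(image, depth):
    """Zero-pad height and width of an (H, W, C) array up to the next multiple of 2**depth."""
    multiple = 2 ** depth
    h, w = image.shape[:2]
    pad_h = (-h) % multiple  # same value as ceil(h / multiple) * multiple - h, but an int
    pad_w = (-w) % multiple
    top, left = pad_h // 2, pad_w // 2
    padded = np.pad(
        image,
        ((top, pad_h - top), (left, pad_w - left), (0, 0)),
        mode="constant",
    )
    return padded, (top, left)


def crop_to_original(array, offsets, original_hw):
    """Cut the region of the original image back out of a padded array or prediction."""
    top, left = offsets
    h, w = original_hw
    return array[top:top + h, left:left + w]


image = np.ones((36, 36, 3), dtype=np.float32)
padded, offsets = pad_to_multiple(image, depth=3)
print(padded.shape)  # (40, 40, 3)

restored = crop_to_original(padded, offsets, image.shape[:2])
print(restored.shape)  # (36, 36, 3)
```

The padded array can then be fed to a model built for the padded size, for example MultiResUnet(40, 40, 3), and the prediction cropped back with the same offsets.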