I have a problem or two with the input dimensions of a modified U-Net architecture. To save you time and to make my results easier to understand/reproduce, I'll post the code and the output dimensions. The modified U-Net is the MultiResUNet architecture from https://github.com/nibtehaz/MultiResUNet/blob/master/MultiResUNet.py and is based on this paper: https://arxiv.org/abs/1902.04049. Please don't be put off by the length of the code: you can simply copy-paste it, it shouldn't take longer than 10 seconds to reproduce my results, and you don't need a dataset for it. Tested with TF v1.9 and Keras v2.2.0.
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, BatchNormalization, Activation, add
from tensorflow.keras.models import Model
from tensorflow.keras.activations import relu
# 2D convolutional layer followed by batch normalization
# Arguments:
#     x {keras layer}  -- input layer
#     filters {int}    -- number of filters
#     num_row {int}    -- number of rows in filters
#     num_col {int}    -- number of columns in filters
# Keyword Arguments:
#     padding {str}    -- mode of padding (default: {'same'})
#     strides {tuple}  -- stride of convolution operation (default: {(1, 1)})
#     activation {str} -- activation function (default: {'relu'})
#     name {str}       -- name of the layer (default: {None})
# Returns:
#     [keras layer] -- [output layer]
def conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(1, 1), activation='relu', name=None):
    x = Conv2D(filters, (num_row, num_col), strides=strides, padding=padding, use_bias=False)(x)
    x = BatchNormalization(axis=3, scale=False)(x)
    if activation is None:
        return x
    x = Activation(activation, name=name)(x)
    return x
# 2D transposed convolutional layer followed by batch normalization
# Arguments:
#     x {keras layer}  -- input layer
#     filters {int}    -- number of filters
#     num_row {int}    -- number of rows in filters
#     num_col {int}    -- number of columns in filters
# Keyword Arguments:
#     padding {str}    -- mode of padding (default: {'same'})
#     strides {tuple}  -- stride of convolution operation (default: {(2, 2)})
#     name {str}       -- name of the layer (default: {None})
# Returns:
#     [keras layer] -- [output layer]
def trans_conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(2, 2), name=None):
    x = Conv2DTranspose(filters, (num_row, num_col), strides=strides, padding=padding)(x)
    x = BatchNormalization(axis=3, scale=False)(x)
    return x
# MultiRes block
# Arguments:
#     U {int}           -- number of filters in the corresponding UNet stage
#     inp {keras layer} -- input layer
# Returns:
#     [keras layer] -- [output layer]
def MultiResBlock(U, inp, alpha=1.67):
    W = alpha * U
    # 1x1 shortcut with as many filters as the three parallel paths combined
    shortcut = inp
    shortcut = conv2d_bn(shortcut, int(W*0.167) + int(W*0.333) + int(W*0.5),
                         1, 1, activation=None, padding='same')
    # chain of 3x3 convolutions whose concatenation approximates 3x3, 5x5 and 7x7 receptive fields
    conv3x3 = conv2d_bn(inp, int(W*0.167), 3, 3, activation='relu', padding='same')
    conv5x5 = conv2d_bn(conv3x3, int(W*0.333), 3, 3, activation='relu', padding='same')
    conv7x7 = conv2d_bn(conv5x5, int(W*0.5), 3, 3, activation='relu', padding='same')
    out = concatenate([conv3x3, conv5x5, conv7x7], axis=3)
    out = BatchNormalization(axis=3)(out)
    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)
    return out
# ResPath
# Arguments:
#     filters {int}     -- number of filters
#     length {int}      -- length of the ResPath (number of residual blocks)
#     inp {keras layer} -- input layer
# Returns:
#     [keras layer] -- [output layer]
def ResPath(filters, length, inp):
    # first residual block: 1x1 shortcut added to a 3x3 convolution
    shortcut = inp
    shortcut = conv2d_bn(shortcut, filters, 1, 1, activation=None, padding='same')
    out = conv2d_bn(inp, filters, 3, 3, activation='relu', padding='same')
    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)
    # remaining (length - 1) residual blocks
    for i in range(length - 1):
        shortcut = out
        shortcut = conv2d_bn(shortcut, filters, 1, 1, activation=None, padding='same')
        out = conv2d_bn(out, filters, 3, 3, activation='relu', padding='same')
        out = add([shortcut, out])
        out = Activation('relu')(out)
        out = BatchNormalization(axis=3)(out)
    return out
# MultiResUNet
# Arguments:
#     height {int}     -- height of image
#     width {int}      -- width of image
#     n_channels {int} -- number of channels in image
# Returns:
#     [keras model] -- MultiResUNet model
def MultiResUnet(height, width, n_channels):
    inputs = Input((height, width, n_channels))
    # downsampling (encoder) part begins here
    mresblock1 = MultiResBlock(32, inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(mresblock1)
    mresblock1 = ResPath(32, 4, mresblock1)
    mresblock2 = MultiResBlock(32*2, pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(mresblock2)
    mresblock2 = ResPath(32*2, 3, mresblock2)
    mresblock3 = MultiResBlock(32*4, pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(mresblock3)
    mresblock3 = ResPath(32*4, 2, mresblock3)
    mresblock4 = MultiResBlock(32*8, pool3)
    # upsampling (decoder) part
    up5 = concatenate([Conv2DTranspose(32*4, (2, 2), strides=(2, 2), padding='same')(mresblock4),
                       mresblock3], axis=3)
    mresblock5 = MultiResBlock(32*8, up5)
    up6 = concatenate([Conv2DTranspose(32*4, (2, 2), strides=(2, 2), padding='same')(mresblock5),
                       mresblock2], axis=3)
    mresblock6 = MultiResBlock(32*4, up6)
    up7 = concatenate([Conv2DTranspose(32*2, (2, 2), strides=(2, 2), padding='same')(mresblock6),
                       mresblock1], axis=3)
    mresblock7 = MultiResBlock(32*2, up7)
    conv8 = conv2d_bn(mresblock7, 1, 1, 1, activation='sigmoid')
    model = Model(inputs=[inputs], outputs=[conv8])
    return model
So now, back to my problem with mismatched input/output dimensions in the U-Net architecture.
If I choose an input height/width of (128, 128), (256, 256) or (512, 512) and do:
model = MultiResUnet(128, 128, 3)
display(model.summary())
TensorFlow gives me a perfect summary of what the whole architecture looks like. But if I do this:
model = MultiResUnet(36, 36, 3)
display(model.summary())
I get this error:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input> in <module>
----> 1 model = MultiResUnet(36, 36, 3)
      2 display(model.summary())

<ipython-input> in MultiResUnet(height, width, n_channels)
     25
     26     up5 = concatenate([Conv2DTranspose(
---> 27         32*4, (2, 2), strides=(2, 2), padding='same')(mresblock4), mresblock3], axis=3)
     28     mresblock5 = MultiResBlock(32*8, up5)
     29

~/miniconda3/envs/MastersThenv/lib/python3.6/site-packages/tensorflow/python/keras/layers/merge.py in concatenate(inputs, axis, **kwargs)
    682       A tensor, the concatenation of the inputs alongside axis `axis`.
    683   """
--> 684   return Concatenate(axis=axis, **kwargs)(inputs)
    685
    686

~/miniconda3/envs/MastersThenv/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    694     if all(hasattr(x, 'get_shape') for x in input_list):
    695       input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
--> 696       self.build(input_shapes)
    697
    698     # Check input assumptions set after layer building, e.g. input shape.

~/miniconda3/envs/MastersThenv/lib/python3.6/site-packages/tensorflow/python/keras/utils/tf_utils.py in wrapper(instance, input_shape)
    146   else:
    147     input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
--> 148   output_shape = fn(instance, input_shape)
    149   if output_shape is not None:
    150     if isinstance(output_shape, list):

~/miniconda3/envs/MastersThenv/lib/python3.6/site-packages/tensorflow/python/keras/layers/merge.py in build(self, input_shape)
    388                        'inputs with matching shapes '
    389                        'except for the concat axis. '
--> 390                        'Got inputs shapes: %s' % (input_shape))
    391
    392   def _merge_function(self, inputs):

ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 8, 8, 128), (None, 9, 9, 128)]
Why does the Conv2DTranspose give me the wrong dimension,
(None, 8, 8, 128)
instead of
(None, 9, 9, 128),
and why doesn't the Concatenate layer complain when I choose input sizes like (128, 128), (256, 256), etc. (multiples of 32)?
To generalize the question: how can I make this U-Net architecture work with any input size? How can I deal with the Conv2DTranspose layer producing an output that is one smaller in width/height than the dimension actually needed (which happens when the input size isn't a multiple of 32, or isn't symmetric), and why doesn't this happen with input sizes that are multiples of 32? And what if I had variable input sizes?
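To make the shape arithmetic I'm asking about concrete, here is a minimal sketch (my own reduction, not part of the MultiResUNet code above; the 9x9 input and the variable names are just for illustration, standing in for what a 36x36 image has shrunk to after two pooling steps):

from tensorflow.keras.layers import Input, MaxPooling2D, Conv2DTranspose

# stand-in for the feature map after two pooling steps: 36 -> 18 -> 9
x = Input((9, 9, 1))
# MaxPooling2D with the default 'valid' padding floors odd sizes: 9 -> 4
p = MaxPooling2D(pool_size=(2, 2))(x)
# a stride-2 transposed convolution only doubles: 4 -> 8, not back to 9
u = Conv2DTranspose(1, (2, 2), strides=(2, 2), padding='same')(p)
print(x.shape, p.shape, u.shape)  # spatial dims: 9x9 -> 4x4 -> 8x8

As far as I can tell, the 8 vs. 9 in the error above comes from exactly this pooling/upsampling pair, while for 128/256/512 every intermediate size stays even, but I'd like to understand how to handle it properly.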
Any help would be highly appreciated.
cheers, H