1 vote

I'm using TensorFlow for a multi-target regression problem; specifically, a fully convolutional residual network for pixel-wise labeling, with the input being an image and the label a mask. In my case the images are brain MR scans and the labels are masks of the tumors.

I have accomplished a fairly decent result using my net, although I am sure there is still room for improvement. Therefore, I wanted to add batch normalization, which I implemented as follows:

# Convolutional Layer 1: conv -> batch norm -> ReLU
Z10 = tf.nn.conv2d(X, W_conv10, strides=[1, 1, 1, 1], padding='SAME')
Z10 = tf.contrib.layers.batch_norm(Z10, center=True, scale=True, is_training=train_flag)
A10 = tf.nn.relu(Z10)

# Convolutional Layer 2: stride-2 conv on the activated output -> batch norm -> ReLU
Z1 = tf.nn.conv2d(A10, W_conv1, strides=[1, 2, 2, 1], padding='SAME')
Z1 = tf.contrib.layers.batch_norm(Z1, center=True, scale=True, is_training=train_flag)
A1 = tf.nn.relu(Z1)

for each of the conv and transposed-conv layers of my net. But the results are not what I expected: the net with batch normalization performs terribly. In the loss plot, orange is the net without batch normalization and blue is the net with it.

Not only is the net learning more slowly, the predicted labels are also much worse in the net using batch normalization.

Does anyone know why this might be the case? Could it be my cost function? I am currently using:

loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dA1, labels=Y)
cost = tf.reduce_mean(loss)
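
For context, here is a minimal sketch of how this cost and the train_flag used above are typically wired together in TF 1.x. The optimizer, shapes, and the one-layer stand-in network below are hypothetical and only illustrate the plumbing, not my actual net:

import tensorflow as tf

X = tf.placeholder(tf.float32, [None, 240, 240, 1])   # hypothetical input shape
Y = tf.placeholder(tf.float32, [None, 240, 240, 1])   # hypothetical mask shape
train_flag = tf.placeholder(tf.bool)                   # fed to every batch_norm's is_training

# Stand-in for the real network: a single 1x1 conv producing the logits dA1
W = tf.Variable(tf.truncated_normal([1, 1, 1, 1], stddev=0.1))
dA1 = tf.nn.conv2d(X, W, strides=[1, 1, 1, 1], padding='SAME')

loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dA1, labels=Y)
cost = tf.reduce_mean(loss)

# tf.contrib.layers.batch_norm puts its moving-average updates into
# tf.GraphKeys.UPDATE_OPS, so they have to run together with the train step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cost)

# Training: feed train_flag=True; evaluation: feed train_flag=False so the
# stored moving averages are used instead of the batch statistics.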

3 Answers

3 votes

Batch normalization is a terrible normalization choice for tasks in which semantic information has to be passed through the network: it washes away the per-image semantic information. Look into conditional normalization methods (Adaptive Instance Normalization, etc.) to understand my point, and also this paper on spatially-adaptive normalization (SPADE): https://arxiv.org/abs/1903.07291.
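
For illustration, a minimal sketch of Adaptive Instance Normalization in the same TF 1.x style as the question (the function and tensor names are hypothetical). Each sample is normalized with its own per-channel spatial statistics and then rescaled with statistics from a conditioning feature map, so per-image information is not averaged away across the batch the way batch normalization does:

import tensorflow as tf

def adain(content, style, eps=1e-5):
    # Per-sample, per-channel statistics over the spatial axes only
    # (batch norm would also average over the batch axis).
    c_mean, c_var = tf.nn.moments(content, axes=[1, 2], keep_dims=True)
    s_mean, s_var = tf.nn.moments(style, axes=[1, 2], keep_dims=True)
    normalized = (content - c_mean) / tf.sqrt(c_var + eps)
    # Re-inject the scale and shift of the conditioning (style/semantic) features
    return normalized * tf.sqrt(s_var + eps) + s_mean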

0 votes

It might be a naive guess, but maybe your batch size is too small. Normalizing helps only if the batch is large enough to represent the distribution of input values for the layer; if the batch is too small, information is lost because the batch statistics are too noisy. I also had problems with batch normalization on a semantic segmentation task where the batch size had to be small (<10) because of the input image size (1600x1200x3).
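
A quick, purely illustrative NumPy sketch of this effect (the distribution and numbers are made up): the per-batch mean that batch norm uses is a noisy estimate of the layer's true mean, and the noise grows as the batch shrinks.

import numpy as np

rng = np.random.default_rng(0)
# Pretend these are one channel's activations over the whole dataset
activations = rng.normal(loc=2.0, scale=3.0, size=100_000)

for batch_size in (4, 16, 64, 256):
    batch_means = [rng.choice(activations, batch_size).mean() for _ in range(1_000)]
    print(batch_size, np.std(batch_means))  # spread of the batch-mean estimate shrinks as the batch grows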

0 votes

I tried Batch Normalization on the FCN-8 architecture with the PASCAL VOC2012 dataset. It gave terrible results, as others have mentioned above, but the model performed well without the Batch Normalization layers. One of my hypotheses for why the network performs badly is that in the decoder we are mainly concerned with upsampling the feature space in a learnable fashion, using convolutions as the medium, because the feature map for the problem is already fixed by the 1x1 conv performed at the end of the base net, which extracts the features.

We also add outputs from earlier encoder layers to the decoder (inspired by the ResNet architecture); the reason for doing so is to reduce the effect of the vanishing-gradient problem in deeper architectures.
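
As a rough sketch of such a decoder step with a skip connection, in the same TF 1.x style as the question (the shapes and variable names are hypothetical, not the actual FCN-8 code): a transposed convolution upsamples the decoder features and the matching encoder feature map is added element-wise.

import tensorflow as tf

# Hypothetical shapes: a 16x16x64 decoder feature upsampled to match a 32x32x32 encoder feature
decoder_in = tf.placeholder(tf.float32, [None, 16, 16, 64])
encoder_skip = tf.placeholder(tf.float32, [None, 32, 32, 32])

# conv2d_transpose filters are [height, width, out_channels, in_channels]
W_up = tf.Variable(tf.truncated_normal([3, 3, 32, 64], stddev=0.1))

# Learnable upsampling of the decoder features
up = tf.nn.conv2d_transpose(decoder_in, W_up,
                            output_shape=tf.shape(encoder_skip),
                            strides=[1, 2, 2, 1], padding='SAME')

# Element-wise addition of the encoder feature map (the skip connection)
fused = tf.nn.relu(up + encoder_skip)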

Batch Normalization works really well when we want to predict classes for a whole picture or a sub-region of the picture, because there we don't have a decoder architecture upsampling a predicted feature space.

Please correct me if I am wrong.