2
votes

I am trying to apply Batch Normalization with tf.layers.batch_normalization(), and my code looks like this:

import tensorflow as tf  # TensorFlow 1.x (uses tf.layers and tf.contrib)


def create_conv_exp_model(fingerprint_input, model_settings, is_training):
  # Dropout placeholder
  if is_training:
    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')

  # Mode placeholder
  mode_placeholder = tf.placeholder(tf.bool, name="mode_placeholder")

  he_init = tf.contrib.layers.variance_scaling_initializer(mode="FAN_AVG")

  # Input Layer
  input_frequency_size = model_settings['bins']
  input_time_size = model_settings['spectrogram_length']
  net = tf.reshape(fingerprint_input,
                   [-1, input_time_size, input_frequency_size, 1],
                   name="reshape")
  net = tf.layers.batch_normalization(net, 
                                      training=mode_placeholder,
                                      name='bn_0')

  for i in range(1, 6):
    net = tf.layers.conv2d(inputs=net,
                           filters=8*(2**i),
                           kernel_size=[5, 5],
                           padding='same',
                           kernel_initializer=he_init,
                           name="conv_%d"%i)
    net = tf.layers.batch_normalization(net,
                                        training=mode_placeholder,
                                        name='bn_%d'%i)
    with tf.name_scope("relu_%d"%i):
      net = tf.nn.relu(net)
    net = tf.layers.max_pooling2d(net, [2, 2], [2, 2], 'SAME', 
                                  name="maxpool_%d"%i)

  net_shape = net.get_shape().as_list()
  net_height = net_shape[1]
  net_width = net_shape[2]
  net = tf.layers.conv2d( inputs=net,
                          filters=1024,
                          kernel_size=[net_height, net_width],
                          strides=(net_height, net_width),
                          padding='same',
                          kernel_initializer=he_init,
                          name="conv_f")
  net = tf.layers.batch_normalization( net, 
                                        training=mode_placeholder,
                                        name='bn_f')
  with tf.name_scope("relu_f"):
    net = tf.nn.relu(net)

  net = tf.layers.conv2d( inputs=net,
                          filters=model_settings['label_count'],
                          kernel_size=[1, 1],
                          padding='same',
                          kernel_initializer=he_init,
                          name="conv_l")

  ### Squeeze
  squeezed = tf.squeeze(net, axis=[1, 2], name="squeezed")

  if is_training:
    return squeezed, dropout_prob, mode_placeholder
  else:
    return squeezed, mode_placeholder

And my train step looks like this:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_input)
  gvs = optimizer.compute_gradients(cross_entropy_mean)
  capped_gvs = [(tf.clip_by_value(grad, -2., 2.), var) for grad, var in gvs]
  train_step = optimizer.apply_gradients(capped_gvs)

During training, I am feeding the graph with:

train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
    [
        merged_summaries, evaluation_step, cross_entropy_mean, train_step,
        increment_global_step
    ],
    feed_dict={
        fingerprint_input: train_fingerprints,
        ground_truth_input: train_ground_truth,
        learning_rate_input: learning_rate_value,
        dropout_prob: 0.5,
        mode_placeholder: True
    })

During validation,

validation_summary, validation_accuracy, conf_matrix = sess.run(
                [merged_summaries, evaluation_step, confusion_matrix],
                feed_dict={
                    fingerprint_input: validation_fingerprints,
                    ground_truth_input: validation_ground_truth,
                    dropout_prob: 1.0,
                    mode_placeholder: False
                })

My loss and accuracy curves (orange is training, blue is validation):

[Figure: plot of loss vs. number of iterations]
[Figure: plot of accuracy vs. number of iterations]

The validation loss (and accuracy) seems very erratic. Is my implementation of Batch Normalization wrong? Or is this normal with Batch Normalization, and I should just wait for more iterations?

3
Does it work normally if you remove BatchNorm? It seems like there could be many things causing this. – mxbi
@mxbi it works normally when I remove BatchNorm. I trained the BatchNorm model for 32 epochs but the validation loss did not decrease. – Zafarullah Mahmood
Is this still a problem? – gab

3 Answers

1
votes

You need to pass is_training to tf.layers.batch_normalization(..., training=is_training) so that it is False at inference time. Otherwise the layer normalizes the inference minibatches using their own minibatch statistics instead of the moving averages collected during training, which is wrong.
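To make the difference between the two modes concrete, here is a minimal pure-Python sketch of what the normalization does to a batch of scalars in each mode. It is not the TensorFlow implementation: the helper `batch_norm`, the example numbers, and the omission of the learned scale and offset are all assumptions made for illustration. The `eps` value matches the layer's documented default of 0.001.

```python
import math


def batch_norm(batch, moving_mean, moving_var, training, eps=1e-3):
    """Normalize a list of scalars the way a BN layer does in each mode.

    Hypothetical helper for illustration only (no learned gamma/beta).
    """
    if training:
        # Training mode: use the statistics of the current minibatch.
        mean = sum(batch) / len(batch)
        var = sum((x - mean) ** 2 for x in batch) / len(batch)
    else:
        # Inference mode: use the moving averages collected during training.
        mean, var = moving_mean, moving_var
    return [(x - mean) / math.sqrt(var + eps) for x in batch]


# Made-up numbers: statistics accumulated during training,
# and a small validation batch that sits above the training mean.
moving_mean, moving_var = 10.0, 4.0
val_batch = [13.0, 14.0, 15.0]

out_train_mode = batch_norm(val_batch, moving_mean, moving_var, training=True)
out_infer_mode = batch_norm(val_batch, moving_mean, moving_var, training=False)

print(out_train_mode)  # centered on zero, whatever the batch contains
print(out_infer_mode)  # keeps the offset relative to the training statistics
```

In training mode the output is always centered on zero, so the information that this batch lies above the training mean is lost. In inference mode that offset is preserved, which is what the rest of the network was trained to expect.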

0
votes

There are mainly two things to check.

1. Are you sure that you are using batch normalization (BN) correctly in the train op?

If you read the layer documentation:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. Also, be sure to add any batch_normalization ops before getting the update_ops collection. Otherwise, update_ops will be empty, and training/inference will not work properly.

For example:

x_norm = tf.layers.batch_normalization(x, training=training)

# ...
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

2. Otherwise, try lowering the "momentum" in the BN.

During training, BN normalizes with the statistics of the current minibatch, and it also maintains two moving averages, of the mean and of the variance, that are supposed to approximate the population statistics. These are initialized to 0 and 1 respectively; at each step the old value is multiplied by the momentum (default 0.99), and the new batch statistic, weighted by 1 - momentum (0.01), is added to it. At inference (test) time, the normalization uses these moving averages. For this reason, it takes a while for them to arrive at the "real" mean and variance of the data, and a lower momentum lets them adapt faster.
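The moving-average update described above can be sketched in a few lines of plain Python to show how slowly it converges with the default momentum. The function name `moving_average_after` and the constant batch statistic of 5.0 are assumptions made for illustration; a real layer sees a different batch statistic at every step.

```python
def moving_average_after(steps, batch_value, momentum=0.99, init=0.0):
    """Apply the BN moving-average update `steps` times.

    Hypothetical helper: assumes every minibatch reports the same
    statistic `batch_value`, which real data will not do exactly.
    """
    moving = init
    for _ in range(steps):
        # moving <- momentum * moving + (1 - momentum) * batch statistic
        moving = momentum * moving + (1.0 - momentum) * batch_value
    return moving


# Moving mean starts at 0; suppose the true mean of the data is 5.0.
slow = moving_average_after(100, 5.0, momentum=0.99)
fast = moving_average_after(100, 5.0, momentum=0.9)

print(slow)  # still well short of 5.0 after 100 steps
print(fast)  # practically at 5.0 after 100 steps
```

With momentum 0.99 the moving mean has covered only about 63% of the distance to the true value after 100 steps, while with momentum 0.9 it has effectively converged. If validation is run early or training has few steps, this lag alone can make the inference-mode statistics inaccurate.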

Sources:

https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization

https://github.com/keras-team/keras/issues/7265

https://github.com/keras-team/keras/issues/3366

The original BN paper can be found here:

https://arxiv.org/abs/1502.03167

0
votes

I also observed oscillations in the validation loss when adding batch norm before the ReLU. Moving the batch norm after the ReLU resolved the issue in my case.