Batch Normalization behaves differently in the training phase and the testing phase.

For example, when using tf.contrib.layers.batch_norm in TensorFlow, we should set a different value for is_training in each phase.

My question is: what happens if I still set is_training=True when testing? That is, what if I still use the training mode in the testing phase?

The reason I ask is that the released code of both Pix2Pix and DualGAN does not set is_training=False when testing. Moreover, it seems that if is_training=False is set when testing, the quality of the generated images can be very bad.

Could someone please explain this? Thanks.

During training, the BatchNorm layer tries to do two things:

  • estimate the mean and variance of the entire training set (population statistics)
  • normalize the inputs' mean and variance, such that they behave like a Gaussian

Ideally, one would use the population statistics of the entire dataset for the second step. However, these are unknown and change during training, because the inputs of each layer shift whenever the weights of the preceding layers are updated. There are also other issues with this approach.

A work-around is to normalize the input as

gamma * (x - mean) / sigma + b

using the mini-batch statistics mean and sigma.
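Written out per channel, with m the number of values the statistics are computed over (the batch size, times the number of spatial positions for convolutional features) and epsilon a small constant for numerical stability, the mini-batch normalization above reads:

```latex
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B = \sqrt{\frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2 + \epsilon}, \qquad
y_i = \gamma \, \frac{x_i - \mu_B}{\sigma_B} + b
```

Note that with m small (in the extreme, a single image), mu_B and sigma_B describe only that image, not the dataset.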

During training, the running average of mini-batch statistics is used to approximate the population statistics.
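To make the two modes concrete, here is a minimal NumPy sketch of a BatchNorm layer. The `is_training` flag mirrors the one in the question; the class itself, the exponential-moving-average update and the values `momentum=0.99` and `eps=1e-3` are illustrative assumptions, not the exact implementation or defaults of any particular library.

```python
import numpy as np


class BatchNorm:
    """Minimal batch-norm sketch that normalizes over all axes except the last (channels)."""

    def __init__(self, num_channels, momentum=0.99, eps=1e-3):
        # Learnable scale (gamma) and shift (b); left at their initial values here.
        self.gamma = np.ones(num_channels)
        self.b = np.zeros(num_channels)
        # Running estimates of the population statistics.
        self.running_mean = np.zeros(num_channels)
        self.running_var = np.ones(num_channels)
        self.momentum = momentum
        self.eps = eps

    def __call__(self, x, is_training):
        # Reduce over the batch axis and any spatial axes.
        axes = tuple(range(x.ndim - 1))
        if is_training:
            # Training mode: normalize with the statistics of this mini-batch...
            mean = x.mean(axis=axes)
            var = x.var(axis=axes)
            # ...and fold them into an exponential moving average.
            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
        else:
            # Inference mode: use the approximated population statistics instead.
            mean, var = self.running_mean, self.running_var
        return self.gamma * (x - mean) / np.sqrt(var + self.eps) + self.b
```

Calling the layer with `is_training=True` at test time, as Pix2Pix and DualGAN do, simply keeps taking the first branch: the running averages are ignored and every test batch is normalized by its own statistics.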

Now, the original BatchNorm formulation uses these approximations of the mean and variance of the entire dataset for normalization during inference. Since the network is fixed by then, the approximation should be fairly good. While it seems sensible to use the population statistics at this point, it is a critical change: the normalization switches from mini-batch statistics to statistics of the entire training data.

This is critical when batches are not i.i.d. or when the batch size during training is very small (though I have also observed it with batches of size 32).

The original BatchNorm implicitly assumes that both sets of statistics are very similar. In particular, training on mini-batches of size 1, as in pix2pix or DualGAN, gives a very poor estimate of the population statistics, and the two can contain completely different values.

In a deep network, the later layers expect their inputs to be batches normalized with mini-batch statistics, because that is exactly the kind of data they were trained on. Using the entire-dataset statistics during inference violates this assumption.
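The mismatch can be illustrated numerically. The sketch below assumes hypothetical single-channel 8x8 "images", each with its own brightness offset, and compares training-mode normalization with batch size 1 (each image normalized by its own statistics) against inference-mode normalization with the population statistics. The data and sizes are made up purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
eps = 1e-3

# Hypothetical data: 1000 single-channel 8x8 images, each with its own brightness offset.
offsets = rng.normal(0.0, 3.0, size=(1000, 1, 1, 1))
images = offsets + rng.normal(0.0, 1.0, size=(1000, 8, 8, 1))

# Population statistics (what a well-converged running average approaches).
pop_mean = images.mean(axis=(0, 1, 2))
pop_var = images.var(axis=(0, 1, 2))


def normalize(x, mean, var):
    return (x - mean) / np.sqrt(var + eps)


# Training mode with batch size 1: each image is normalized by its own statistics.
train_out = normalize(images,
                      images.mean(axis=(1, 2), keepdims=True),
                      images.var(axis=(1, 2), keepdims=True))
# Inference mode: every image is normalized by the same population statistics.
test_out = normalize(images, pop_mean, pop_var)

print(np.abs(train_out.mean(axis=(1, 2, 3))).max())  # essentially 0: every image is centred
print(test_out.mean(axis=(1, 2, 3)).std())           # roughly 0.95: images are no longer centred
print(test_out.std(axis=(1, 2, 3)).mean())           # roughly 0.3: nor unit-scaled
```

The layers after this one only ever saw centred, unit-scale activations during training; at inference they receive shifted and shrunken ones instead.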

How can this be solved? Either also use mini-batch statistics during inference, as the implementations you mentioned do; or use Batch Renormalization, which introduces two additional terms to remove the difference between the mini-batch and population statistics; or simply use Instance Normalization (for regression tasks), which is, in fact, the same as BatchNorm but treats each example in the batch individually and also does not use population statistics.
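The two additional terms of Batch Renormalization are usually written r and d. The following is a rough sketch of its training-mode forward pass in NumPy, as I understand the method; `batch_renorm_train` is a hypothetical helper, gamma and b are omitted, and the default clipping limits `r_max` and `d_max` are arbitrary illustrative values.

```python
import numpy as np


def batch_renorm_train(x, running_mean, running_var, r_max=3.0, d_max=5.0, eps=1e-3):
    """Training-mode Batch Renormalization over the last (channel) axis, gamma=1 and b=0."""
    axes = tuple(range(x.ndim - 1))
    mu_b = x.mean(axis=axes)
    sigma_b = np.sqrt(x.var(axis=axes) + eps)
    sigma = np.sqrt(running_var + eps)
    # Correction terms relating mini-batch and population statistics.
    # In a real framework they are treated as constants (no gradient flows through them).
    r = np.clip(sigma_b / sigma, 1.0 / r_max, r_max)
    d = np.clip((mu_b - running_mean) / sigma, -d_max, d_max)
    # When r and d are not clipped, this equals (x - running_mean) / sigma,
    # i.e. the normalization that will be used at inference.
    return (x - mu_b) / sigma_b * r + d
```

With `r_max=1` and `d_max=0` the corrections vanish and this reduces to ordinary training-mode BatchNorm; the limits are then relaxed so that training gradually moves towards the population statistics.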

I also ran into this issue during my research and now use the InstanceNorm layer for regression tasks.
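For comparison, a minimal NumPy sketch of instance normalization, assuming an NHWC layout; `instance_norm` is a hypothetical helper, not a specific library's API. For a batch of size 1 it computes exactly what training-mode BatchNorm computes, but it keeps no running averages, so training and inference behave identically and the result for one image does not depend on the rest of the batch.

```python
import numpy as np


def instance_norm(x, gamma=1.0, b=0.0, eps=1e-3):
    """Normalize each example and channel over its spatial axes (NHWC layout assumed)."""
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    # No population statistics: the same computation is used for training and inference.
    return gamma * (x - mean) / np.sqrt(var + eps) + b
```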