During training, the BatchNorm layer tries to do two things:
- estimate the mean and variance of the entire training set (population statistics)
- normalize the mean and variance of its inputs, so that they behave like a standard Gaussian (zero mean, unit variance)
In the ideal case, one would use the population statistics of the entire dataset for the second point. However, these are unknown and change during training. There are also some other issues with this.
A work-around is to normalize the input by
gamma * (x - mean) / sigma + b
based on the mini-batch statistics mean and sigma.
During training, the running average of mini-batch statistics is used to approximate the population statistics.
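As a rough illustration of the training-time behaviour described above, here is a minimal NumPy sketch. It is not taken from any particular framework: the function name `batchnorm_train_step`, the `momentum` and `eps` parameters, and the use of an exponential moving average as the "running average" are assumptions that mirror common implementations.

```python
import numpy as np

def batchnorm_train_step(x, gamma, b, running_mean, running_var,
                         momentum=0.1, eps=1e-5):
    """One training-mode BatchNorm step on a batch x of shape (N, features).

    momentum and eps are assumed hyperparameters, not values from the text.
    """
    mean = x.mean(axis=0)            # mini-batch mean, per feature
    var = x.var(axis=0)              # mini-batch (biased) variance, per feature
    sigma = np.sqrt(var + eps)       # eps avoids division by zero

    # Normalize with the mini-batch statistics, then scale and shift.
    y = gamma * (x - mean) / sigma + b

    # Running averages that approximate the population statistics.
    new_running_mean = (1 - momentum) * running_mean + momentum * mean
    new_running_var = (1 - momentum) * running_var + momentum * var
    return y, new_running_mean, new_running_var

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(32, 4))   # toy batch of 32 examples
y, rm, rv = batchnorm_train_step(
    x, gamma=np.ones(4), b=np.zeros(4),
    running_mean=np.zeros(4), running_var=np.ones(4),
)
print(y.mean(axis=0), y.std(axis=0))   # per-feature statistics after normalization
```

Note that the output `y` depends only on the current mini-batch; the running averages are merely updated on the side and play no role in the normalization during training.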
Now, the original BatchNorm formulation uses the approximated mean and variance of the entire dataset for normalization during inference. As the network is fixed at that point, the approximation of the mean and variance should be pretty good. While it seems to make sense to use the population statistics now, it is a critical change: from mini-batch statistics to statistics of the entire training data.
This becomes critical when the batches are not iid or when the batch size is very small during training. (But I have also observed it with batches of size 32.)
BatchNorm as originally proposed simply assumes, implicitly, that both statistics are very similar. In particular, training on mini-batches of size 1, as in pix2pix or DualGAN, gives very poor information about the population statistics. In that case, the mini-batch and population statistics may contain totally different values.
In a deep network, the later layers expect their inputs to be normalized batches (in the sense of mini-batch statistics), since that is the kind of data they were trained on. Using the entire-dataset statistics during inference violates this assumption.
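To make the mismatch concrete, here is a small NumPy sketch under stated assumptions: the data is a hypothetical toy dataset in which every "image" has its own brightness offset, the batch size is 1, and the population statistics are computed exactly over the whole dataset instead of through a running average. It only illustrates the effect and is not a reproduction of any particular model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: 100 single-channel 8x8 "images", each with its own
# brightness offset, so the statistics of one image differ from the dataset's.
offsets = rng.uniform(-5.0, 5.0, size=(100, 1, 1))
images = offsets + rng.normal(size=(100, 8, 8))

eps = 1e-5

def normalize(x, mean, var):
    """Plain normalization without the learned scale and shift."""
    return (x - mean) / np.sqrt(var + eps)

# Population statistics over the entire dataset.
pop_mean, pop_var = images.mean(), images.var()

img = images[0:1]   # a mini-batch of size 1

# What the next layer saw during training (mini-batch statistics) ...
train_out = normalize(img, img.mean(), img.var())
# ... and what it sees at inference (population statistics).
infer_out = normalize(img, pop_mean, pop_var)

print("training mode :", train_out.mean(), train_out.std())
print("inference mode:", infer_out.mean(), infer_out.std())
```

In training mode the output always has zero mean and unit standard deviation, while in inference mode the same image comes out with a different scale (and usually a different mean). The later layers never saw inputs like that during training.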
How to solve this issue? Either use mini-batch statistics during inference as well, as in the implementations you mentioned. Or use BatchReNormalization, which introduces 2 additional terms to remove the difference between the mini-batch and population statistics. Or simply use InstanceNormalization (for regression tasks), which is in fact the same as BatchNorm but treats each example in the batch individually and does not use population statistics.
I ran into this issue in my own research as well and now use the InstanceNorm layer for the regression task.
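For illustration, the InstanceNorm idea can be sketched in NumPy as follows. This is a simplified version under assumptions: inputs have shape (N, C, H, W), the learned scale and shift are omitted, and the name `instance_norm` and the `eps` value are mine.

```python
import numpy as np

def instance_norm(x, eps=1e-5):
    """Normalize every example and channel over its own spatial dimensions.

    x has shape (N, C, H, W). No batch or population statistics are involved.
    """
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
batch = rng.normal(loc=2.0, scale=3.0, size=(4, 3, 8, 8))

out_batch = instance_norm(batch)        # first example processed within a batch
out_single = instance_norm(batch[:1])   # the same example processed alone

# The result for an example does not depend on what else is in the batch.
print(np.allclose(out_batch[:1], out_single))
```

Because each example is normalized with its own statistics, the computation is identical in training and inference, so the mismatch described above cannot occur.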