
I'm going through the PyTorch Transfer Learning tutorial at: link

In the data augmentation stage, there is the following step to normalize images:

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

I understand why it's doing this, but I can't find how the mean and std values were calculated. I tried calculating the mean on the training set, and the values I get are:

array([ 0.11727478,  0.04542569, -0.28624609], dtype=float32)

3 Answers

10
votes

Your numbers don't look right to me: since the output of the ToTensor transform is in the range [0.0, 1.0], it shouldn't be possible to get a negative mean.
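To make the range argument concrete, here is a minimal sketch of the two per-channel operations, written in plain Python (lists instead of tensors, so the arithmetic is visible). It assumes 8-bit input, which ToTensor divides by 255, and uses the constants from the tutorial's Normalize call; the helper names to_unit_range and normalize are hypothetical stand-ins, not torchvision functions. Only the Normalize step, (x - mean[c]) / std[c], can produce negative values.

```python
# Constants from the tutorial's transforms.Normalize call
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def to_unit_range(rgb_uint8):
    """Hypothetical stand-in for ToTensor on one 8-bit RGB pixel: 0..255 -> 0.0..1.0."""
    return [v / 255.0 for v in rgb_uint8]


def normalize(rgb_unit):
    """Hypothetical stand-in for Normalize on one pixel: (x - mean[c]) / std[c]."""
    return [(x - m) / s for x, m, s in zip(rgb_unit, MEAN, STD)]


pixel = to_unit_range([0, 128, 255])
print(pixel)             # every value lies in [0.0, 1.0]
print(normalize(pixel))  # the first channel is now negative
```

A pixel exactly equal to MEAN maps to zero in every channel, and anything darker than the mean goes negative.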

If I calculate the mean with

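import torch
from torchvision import datasets, transforms

traindata = datasets.ImageFolder(data_dir + '/train', transforms.ToTensor())  # data_dir as in the tutorial
image_means = torch.stack([t.mean(1).mean(1) for t, c in traindata])  # one (C,) mean per image
image_means.mean(0)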

For the training set I get (0.5143, 0.4760, 0.3487), and for the validation set (0.5224, 0.4799, 0.3564). These are closer to the numbers in the tutorial. If you search for the exact numbers, you'll see that they appear in the ImageNet example, so my guess is that they are the per-channel means (and standard deviations) of the ImageNet dataset, of which the tutorial dataset is a subset.
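Averaging per-image means, as the snippet above does, gives every image equal weight. When images differ in size (ImageFolder with only ToTensor does not resize them), that can differ from the mean over all pixels. The plain-Python sketch below uses two hypothetical single-channel "images", flattened to lists with values chosen so the arithmetic is exact; weighting each image by its pixel count recovers the pixel-level mean.

```python
from statistics import fmean

# Two hypothetical single-channel "images", flattened to lists of pixel values
small = [0.25, 0.25]      # 2 pixels
large = [0.75] * 6        # 6 pixels
images = [small, large]

# What the snippet above does: mean of each image, then mean of those means
mean_of_means = fmean(fmean(img) for img in images)

# Mean over every pixel, i.e. each image mean weighted by its pixel count
total = sum(sum(img) for img in images)
count = sum(len(img) for img in images)
pixel_mean = total / count

print(mean_of_means)  # 0.5
print(pixel_mean)     # 0.625
```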

0
votes

You can calculate the mean and standard deviation of the whole dataset by iterating over all of the images, like this:

You need PyTorch and torchvision:

torch~=1.8.0
torchvision~=0.9.0

Code

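import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

train_set = torchvision.datasets.ImageFolder(
    root='/Path/',
    transform=transforms.Compose([
        transforms.ToTensor()  # RGB values scaled to [0.0, 1.0]
    ])
)

# batch_size=1, so images of different sizes never have to be stacked
loader = DataLoader(train_set, batch_size=1, num_workers=4)

# Running per-channel sums over every pixel of every image (float64 to limit rounding)
channel_sum = torch.zeros(3, dtype=torch.float64)
channel_sq_sum = torch.zeros(3, dtype=torch.float64)
pixel_count = 0
for images, _ in loader:  # images has shape (1, 3, H, W)
    channel_sum += images.sum(dim=(0, 2, 3))
    channel_sq_sum += (images ** 2).sum(dim=(0, 2, 3))
    pixel_count += images.size(0) * images.size(2) * images.size(3)

mean = channel_sum / pixel_count
std = (channel_sq_sum / pixel_count - mean ** 2).sqrt()  # Var = E[x^2] - E[x]^2
print("Mean", mean)
print("Std", std)
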
0
votes
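def get_mean_std(loader):
    """Per-channel mean and std, averaged over the images yielded by `loader`.

    `loader` must yield (images, labels) batches with images of shape (B, C, H, W).
    """
    mean = 0.
    std = 0.
    for images, _ in loader:
        batch_samples = images.size(0)  # batch size (the last batch can be smaller!)
        images = images.view(batch_samples, images.size(1), -1)  # (B, C, H*W)
        mean += images.mean(2).sum(0)  # add up each image's per-channel mean
        std += images.std(2).sum(0)    # add up each image's per-channel std

    mean /= len(loader.dataset)
    std /= len(loader.dataset)
    return mean, std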
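get_mean_std averages the per-image standard deviations, which approximates, but is not equal to, the standard deviation over all pixels, because it ignores how much the image means differ from one another. The plain-Python sketch below uses two hypothetical, equally sized single-channel "images" and computes the pooled value from the identity Var(x) = E[x^2] - E[x]^2, which needs only running sums of x and x^2. It uses the population std (statistics.pstdev); torch's std defaults to the unbiased estimator, which makes no difference for these flat images.

```python
from math import sqrt
from statistics import fmean, pstdev

# Two hypothetical single-channel "images" of equal size, flattened
dark = [0.25, 0.25, 0.25, 0.25]
bright = [0.75, 0.75, 0.75, 0.75]
images = [dark, bright]

# Average of per-image standard deviations (what get_mean_std computes)
mean_of_stds = fmean(pstdev(img) for img in images)

# Pooled standard deviation over every pixel, via E[x^2] - E[x]^2
pixels = [x for img in images for x in img]
mu = sum(pixels) / len(pixels)
pooled_std = sqrt(sum(x * x for x in pixels) / len(pixels) - mu * mu)

print(mean_of_stds)  # 0.0: each image is flat
print(pooled_std)    # 0.25: the images differ from each other
```

Because both images have the same size, the means agree; only the std estimate differs.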