8
votes

Say I have a batch of images in the form of tensors with dimensions (B x C x W x H) where B is the batch size, C is the number of channels in the image, and W and H are the width and height of the image respectively. I'm looking to use the transforms.Normalize() function to normalize my images with respect to the mean and standard deviation of the dataset across the C image channels, meaning that I want a resulting tensor in the form 1 x C. Is there a straightforward way to do this?

I tried torch.view(C, -1).mean(1) and torch.view(C, -1).std(1) but I get the error:

view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Edit

After looking into how view() works in PyTorch, I now realize why my approach doesn't work; however, I still can't figure out how to get the per-channel mean and standard deviation.

2 Answers

9
votes

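Note that variances add, not standard deviations, so per-image standard deviations should not be averaged directly. Accumulate the variances instead and take the square root at the end. See the detailed explanation here: https://apcentral.collegeboard.org/courses/ap-statistics/classroom-resources/why-variances-add-and-why-it-matters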

Here is the modified code:

import torch

nimages = 0
mean = 0.0
var = 0.0
# trainloader is assumed to be a DataLoader whose items start with the image batch
for batch_target in trainloader:
    batch = batch_target[0]
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Accumulate the per-image, per-channel mean and variance
    mean += batch.mean(2).sum(0)
    var += batch.var(2).sum(0)

# Average over all images, then convert the variance to a standard deviation
mean /= nimages
var /= nimages
std = torch.sqrt(var)

print(mean)
print(std)
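
Averaging per-image variances leaves out the spread between the image means, so the result can differ from the variance taken over every pixel in the dataset. As a minimal sketch, assuming the statistic wanted is the per-channel standard deviation over all pixels, running sums of x and x^2 give it in a single pass. The random `data`, `labels`, and `loader` below are stand-ins rather than the asker's dataset, and the population (biased) variance is used.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data; in practice the loader would be the real training loader.
torch.manual_seed(0)
data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
loader = DataLoader(TensorDataset(data, labels), batch_size=8)

n_pixels = 0  # number of pixels seen per channel
channel_sum = torch.zeros(3, dtype=torch.float64)
channel_sq_sum = torch.zeros(3, dtype=torch.float64)

for batch, _ in loader:
    batch = batch.double()  # accumulate in float64 to limit rounding error
    n_pixels += batch.size(0) * batch.size(2) * batch.size(3)
    # Reduce over batch, width and height, keeping the channel dimension
    channel_sum += batch.sum(dim=(0, 2, 3))
    channel_sq_sum += (batch ** 2).sum(dim=(0, 2, 3))

mean = channel_sum / n_pixels
var = channel_sq_sum / n_pixels - mean ** 2  # population variance: E[x^2] - E[x]^2
std = torch.sqrt(var)

print(mean)
print(std)
```

Because only three running totals are kept, this does not need the whole dataset in memory, and the result does not depend on the batch size.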

8
votes

You just need to rearrange the batch tensor from [B, C, W, H] to [B, C, W * H]:

batch = batch.view(batch.size(0), batch.size(1), -1)
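
One common way to run into the error quoted in the question is calling view() on a non-contiguous tensor, for example after permute(). The sketch below uses a made-up random tensor (the shapes are arbitrary assumptions) to show view() failing in that situation while reshape() succeeds.

```python
import torch

x = torch.randn(4, 3, 5, 6)     # [B, C, W, H], contiguous
x_perm = x.permute(1, 0, 2, 3)  # [C, B, W, H], a non-contiguous view of x

try:
    x_perm.view(3, -1)          # strides cannot be merged without a copy
    view_failed = False
except RuntimeError:
    view_failed = True

# reshape() copies when necessary, giving one row of pixels per channel
flat = x_perm.reshape(3, -1)    # [C, B * W * H]
channel_mean = flat.mean(dim=1)
channel_std = flat.std(dim=1)

print(view_failed)
print(channel_mean, channel_std)
```

Calling x_perm.contiguous().view(3, -1) would work as well; reshape() simply performs the copy only when it is needed.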

Here is a complete usage example on random data:

Code:

import torch
from torch.utils.data import TensorDataset, DataLoader

data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=8)

nimages = 0
mean = 0.
std = 0.
for batch, _ in loader:
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Accumulate the per-image, per-channel mean and std
    # (averaging per-image stds only approximates the dataset-wide std)
    mean += batch.mean(2).sum(0)
    std += batch.std(2).sum(0)

# Final step
mean /= nimages
std /= nimages

print(mean)
print(std)

Output:

tensor([-0.0029, -0.0022, -0.0036])
tensor([0.9942, 0.9939, 0.9923])
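
If the whole dataset fits in memory as a single tensor, the reshape can be skipped by reducing over the batch, width and height dimensions directly. This is a sketch on synthetic data (the scale and offset are arbitrary assumptions), followed by a manual per-channel normalization via broadcasting to check the statistics.

```python
import torch

torch.manual_seed(0)
data = torch.randn(64, 3, 28, 28) * 2.0 + 5.0  # synthetic [B, C, W, H] data

# Reduce over every dimension except the channel dimension
mean = data.mean(dim=(0, 2, 3))  # shape [C]
std = data.std(dim=(0, 2, 3))    # shape [C]

# Broadcast the [C] statistics against [B, C, W, H] to normalize each channel
normalized = (data - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

print(mean, std)
print(normalized.mean(dim=(0, 2, 3)), normalized.std(dim=(0, 2, 3)))
```

The resulting mean and std can be converted with .tolist() and passed as the two arguments of transforms.Normalize().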