Say I have a batch of images in the form of tensors with dimensions (B x C x W x H), where B is the batch size, C is the number of channels in the image, and W and H are the width and height of the image respectively. I'm looking to use the transforms.Normalize() function to normalize my images with respect to the mean and standard deviation of the dataset across the C image channels, meaning that I want the mean and the standard deviation each as a tensor of shape 1 x C (one value per channel). Is there a straightforward way to do this?
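For reference, this is roughly how I intend to use the statistics once I have them (the values below are just placeholders, not my dataset's actual statistics):

    import torch
    from torchvision import transforms

    # Placeholder per-channel statistics -- these are exactly the numbers
    # I want to compute from my own dataset instead of hard-coding.
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])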
I tried calling .view(C, -1).mean(1) and .view(C, -1).std(1) on the batch tensor, but I get the error:
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
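For what it's worth, I can reproduce the error whenever the batch tensor is non-contiguous (mine presumably ends up that way somewhere upstream; the permute below is just one way to produce such a layout):

    import torch

    B, C, W, H = 8, 3, 32, 32  # toy sizes

    # Building the tensor channel-first and permuting to (B, C, W, H) leaves
    # it non-contiguous, which triggers the same error as above.
    batch = torch.rand(C, B, W, H).permute(1, 0, 2, 3)

    batch.view(C, -1)  # RuntimeError: view size is not compatible ...

Following the error's suggestion, .reshape(C, -1) does run, but each row then mixes values from different channels, because the first dimension of the batch is B, not C, so the statistics come out wrong anyway.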
Edit
After looking into how view() works in PyTorch, I now realize why my approach doesn't work; however, I still can't figure out how to compute the per-channel mean and standard deviation.
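Edit 2

For anyone who lands here with the same problem: this is what I eventually came up with, by reducing over the batch and spatial dimensions while keeping the channel dimension. I'm not sure it's the idiomatic way, so corrections are welcome:

    import torch

    B, C, W, H = 8, 3, 32, 32  # toy sizes
    batch = torch.rand(B, C, W, H)

    # Reduce over the batch and spatial dimensions, keeping the channel dim:
    mean = batch.mean(dim=(0, 2, 3))  # shape: (C,)
    std = batch.std(dim=(0, 2, 3))    # shape: (C,)

    # Equivalent: move channels to the front, then flatten everything else.
    # reshape() copies when necessary, so it also works on non-contiguous tensors.
    mean_alt = batch.permute(1, 0, 2, 3).reshape(C, -1).mean(dim=1)
    std_alt = batch.permute(1, 0, 2, 3).reshape(C, -1).std(dim=1)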