
When classifying CIFAR10 in PyTorch, there are normally 50,000 training samples and 10,000 test samples. However, if I need to create a validation set, I can do it by splitting the training set into 40,000 training samples and 10,000 validation samples. I used the following code:

import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

train_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

cifar_train_L = CIFAR10('./data', download=True, train=True, transform=train_transform)
cifar_test = CIFAR10('./data', download=True, train=False, transform=test_transform)

train_size = int(0.8 * len(cifar_train_L))   # 40,000
val_size = len(cifar_train_L) - train_size   # 10,000
cifar_train, cifar_val = torch.utils.data.random_split(cifar_train_L, [train_size, val_size])

train_dataloader = torch.utils.data.DataLoader(cifar_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_dataloader = torch.utils.data.DataLoader(cifar_test, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_dataloader = torch.utils.data.DataLoader(cifar_val, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

Normally, when augmenting data in PyTorch, the augmentation steps are added to the transforms.Compose pipeline (e.g., transforms.RandomHorizontalFlip()). However, if I apply these augmentations before splitting the training set and validation set, the augmented data will also end up in the validation set. Is there any way I can fix this?

In short, I want to manually split the training dataset into training and validation sets, and apply data augmentation only to the new training set.


1 Answer


You can manually override the transform used by the validation split. Note that random_split returns Subset objects, which simply forward indexing to the underlying dataset, so assigning to cifar_val.transforms has no effect, and changing cifar_train_L.transform directly would change the training split as well. Instead, point the validation Subset at a second copy of the dataset that uses the test transform:

cifar_train, cifar_val = torch.utils.data.random_split(cifar_train_L, [train_size, val_size])
# Replace the underlying dataset with a copy that uses the non-augmenting transform;
# the Subset keeps its indices, so the train/val split itself is unchanged.
cifar_val.dataset = CIFAR10('./data', train=True, transform=test_transform)
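To sanity-check that this pattern behaves as intended without downloading CIFAR10, the same idea can be demonstrated on a toy dataset (ToyDataset and the stand-in transforms below are made up for illustration):

```python
import torch
from torch.utils.data import Dataset, random_split

class ToyDataset(Dataset):
    """Minimal stand-in for CIFAR10: applies a transform on the fly."""
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        return self.transform(x) if self.transform else x

data = torch.arange(10.0)
aug = lambda x: x + 100        # stand-in for the augmenting train transform
plain = lambda x: x            # stand-in for the plain test transform

full = ToyDataset(data, transform=aug)
train_set, val_set = random_split(full, [8, 2],
                                  generator=torch.Generator().manual_seed(0))

# Point the validation Subset at a second copy using the plain transform;
# its indices are untouched, so the split itself stays the same.
val_set.dataset = ToyDataset(data, transform=plain)

# Training samples pass through the augmenting transform, validation samples do not.
assert all(float(train_set[i]) >= 100 for i in range(len(train_set)))
assert all(float(val_set[i]) < 100 for i in range(len(val_set)))
```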