I am new to PyTorch and I want to use VGG16 for transfer learning. I want to delete the fully connected layers and add some new ones, and I also want to use grayscale input instead of RGB. For the grayscale input, I plan to sum the three channels' weights of the first convolutional layer into a single-channel weight.
I managed to delete the fully connected layers, but I am stuck on the grayscale part. I sum the three channel weights into a new weight tensor and then try to load it into the VGG model's state dict, which raises an error. The network's code is below:
    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision import models

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            vgg = models.vgg16(pretrained=True).features[:30]
            w1 = vgg.state_dict()['0.weight'][:, 0, :, :]  # first channel of the first conv layer's weight
            w2 = vgg.state_dict()['0.weight'][:, 1, :, :]
            w3 = vgg.state_dict()['0.weight'][:, 2, :, :]
            w4 = w1 + w2 + w3     # sum the three channels' weights
            w4 = w4.unsqueeze(1)  # make it 4-dimensional again
            a = vgg.state_dict()  # take a copy of the state dict
            a['0.weight'] = w4    # replace the first layer's weight in the copy
            vgg.load_state_dict(a)  # this line raises the error
            self.vgg = nn.Sequential(vgg)
            self.fc1 = nn.Linear(14 * 14 * 512, 1000)
            self.fc2 = nn.Linear(1000, 2)

        def forward(self, x):
            x = self.vgg(x)
            x = x.view(-1, 14 * 14 * 512)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return x
This gives the following error:
RuntimeError: Error(s) in loading state_dict for Sequential: size mismatch for 0.weight: copying a param with shape torch.Size([64, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
So it doesn't allow me to replace the weight with a differently sized one. Is there a solution to this problem, or is there anything else I can try? All I want to do is use VGG's layers up to the fully connected layers and change the first layer's weights.
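To show the shape mismatch in isolation, here is a minimal sketch of just the weight-summing step. It uses a fresh `nn.Conv2d(3, 64, 3)` as a stand-in for VGG16's first conv layer (so no pretrained download is needed); the shapes are the same as in my network:

    import torch
    import torch.nn as nn

    # Stand-in for VGG16's first conv layer: 64 filters, 3 input channels, 3x3 kernels
    conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)

    w = conv.weight.data                 # shape [64, 3, 3, 3]
    w_gray = w.sum(dim=1, keepdim=True)  # sum the RGB channel weights -> [64, 1, 3, 3]

    # The summed tensor no longer matches the layer's expected weight shape,
    # which is exactly why load_state_dict raises the size-mismatch error.
    print(tuple(w.shape))       # (64, 3, 3, 3)
    print(tuple(w_gray.shape))  # (64, 1, 3, 3)

(`w.sum(dim=1, keepdim=True)` is equivalent to my `w1 + w2 + w3` followed by `unsqueeze(1)`.)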