The following code is to train an MLP with images of size 64*64, while using the loss ||output - input||^2.
For some reason, my weights per epoch are not being updated as shown at the end.
class MLP(nn.Module):
def __init__(self, size_list):
super(MLP, self).__init__()
layers = []
self.size_list = size_list
for i in range(len(size_list) - 2):
layers.append(nn.Linear(size_list[i],size_list[i+1]))
layers.append(nn.ReLU())
layers.append(nn.Linear(size_list[-2], size_list[-1]))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
model_1 = MLP([4096, 64, 4096])
And for training each epoch:
def train_epoch(model, train_loader, criterion, optimizer):
model.train()
model.to(device)
running_loss = 0.0
start_time = time.time()
# train batch
for batch_idx, (data) in enumerate(train_loader):
optimizer.zero_grad()
data = data.to(device)
outputs = model(data)
loss = criterion(outputs, data)
running_loss += loss.item()
loss.backward()
optimizer.step()
end_time = time.time()
weight_ll = model.net[0].weight
running_loss /= len(train_loader)
print('Training Loss: ', running_loss, 'Time: ',end_time - start_time, 's')
return running_loss, outputs, weight_ll
for training the data:
n_epochs = 20
Train_loss = []
weights=[]
criterion = nn.MSELoss()
optimizer = optim.SGD(model_1.parameters(), lr = 0.1)
for i in range(n_epochs):
train_loss, output, weights_ll = train_epoch(model_1, trainloader, criterion, optimizer)
Train_loss.append(train_loss)
weights.append(weights_ll)
print('='*20)
Now, when I print the weights of the first fully connected layer per epoch they aren't being updated.
print(weights[0][0])
print(weights[19][0])
The output for the above is (showing the weight in epoch 0 and in epoch 19):
tensor([ 0.0086, 0.0069, -0.0048, ..., -0.0082, -0.0115, -0.0133],
grad_fn=<SelectBackward>)
tensor([ 0.0086, 0.0069, -0.0048, ..., -0.0082, -0.0115, -0.0133],
grad_fn=<SelectBackward>)
What may be going wrong? Looking at my loss, it's decreasing at a steady rate but there is no change in the weights.
loss.backword()
? – zihaozhihao