2
votes

Hey, so I'm trying my hand at image classification/transfer learning using the monkey species dataset and resnet50 with a modified final fc layer to predict just the 10 classes. Everything works until I use model.train() and model.eval(); then, after the first epoch, the loss starts to return NaNs and the accuracy drops off, as you'll see below. Why does this happen only when switching between train/eval?

First I import the model, attach the classifier, and freeze the parameters:

%%capture
resnet = models.resnet50(pretrained=True)

for param in resnet.parameters():
  param.required_grad = False

in_features = resnet.fc.in_features


# Build custom classifier
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(in_features, 512)),
                                        ('relu', nn.ReLU()),
                                        ('drop', nn.Dropout(0.05)),
                                        ('fc2', nn.Linear(512, 10)),
                                        ]))

# ('output', nn.LogSoftmax(dim=1))
resnet.classifier = classifier

resnet.to(device)

Then I set my loss function, optimizer, and scheduler:

# Step : Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
# pass the optimizer to the appended classifier layer
optimizer = torch.optim.SGD(resnet.parameters(), lr=0.01)
# Scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.05)  

Then I set up the training and validation loops:

epochs = 20


tr_losses = []
avg_epoch_tr_loss = []
tr_accuracy = []


val_losses = []
avg_epoch_val_loss = []
val_accuracy = []
val_loss_min = np.Inf


resnet.train()
for epoch in range(epochs):
  for i, batch in enumerate(train_loader):
    # Pull the data and labels from the batch
    data, label = batch
    # If available push data and label to GPU
    if train_on_gpu:
      data, label = data.to(device), label.to(device)
    # Compute the logit
    logit = resnet(data)
    # Compute loss
    loss = criterion(logit, label)
    # Clearing the gradient
    resnet.zero_grad()
    # Backpropagate the gradients (accumulate the partial derivatives of the loss)
    loss.backward()
    # Update the parameters by stepping in the direction opposite to the gradient
    optimizer.step()
    # Store the losses of each batch
    # loss.item() separates the loss from the computation graph
    tr_losses.append(loss.item())
    # Detach and store the average accuracy of each batch
    tr_accuracy.append(label.eq(logit.argmax(dim=1)).float().mean())
    # Print the rolling batch training loss every 40 batches
    if i % 40 == 0 and not i == 1:
      print(f'Batch No: {i} \tAverage Training Batch Loss: {torch.tensor(tr_losses).mean():.2f}')
  # Print the average loss for each epoch
  print(f'\nEpoch No: {epoch + 1},Training Loss: {torch.tensor(tr_losses).mean():.2f}')
  # Print the average accuracy for each epoch
  print(f'Epoch No: {epoch + 1}, Training Accuracy: {torch.tensor(tr_accuracy).mean():.2f}\n')
  # Store the avg epoch loss for plotting
  avg_epoch_tr_loss.append(torch.tensor(tr_losses).mean())


  resnet.eval()
  for i, batch in enumerate(val_loader):
    # Pull the data and labels from the batch
    data, label = batch
    # If available push data and label to GPU
    if train_on_gpu:
      data, label = data.to(device), label.to(device)
    # Compute the logits without computing the gradients
    with torch.no_grad():
      logit = resnet(data)
    # Compute loss
    loss = criterion(logit, label)
    # Store test loss
    val_losses.append(loss.item())
    # Store the accuracy for each batch
    val_accuracy.append(label.eq(logit.argmax(dim=1)).float().mean())
    if i % 20 == 0 and not i == 1:
      print(f'Batch No: {i+1} \tAverage Val Batch Loss: {torch.tensor(val_losses).mean():.2f}')
  # Print the average loss for each epoch
  print(f'\nEpoch No: {epoch + 1}, Epoch Val Loss: {torch.tensor(val_losses).mean():.2f}')
  # Print the average accuracy for each epoch    
  print(f'Epoch No: {epoch + 1}, Epoch Val Accuracy: {torch.tensor(val_accuracy).mean():.2f}\n')
  # Store the avg epoch loss for plotting
  avg_epoch_val_loss.append(torch.tensor(val_losses).mean())

  # Checkpointing the model using val loss threshold
  if torch.tensor(val_losses).float().mean() <= val_loss_min:
    print("Epoch Val Loss Decreased... Saving model")
    # save current model
    torch.save(resnet.state_dict(), '/content/drive/MyDrive/1. Full Projects/Intel Image Classification/model_state.pt')
    val_loss_min = torch.tensor(val_losses).mean()
  # Step the scheduler for the next epoch
  scheduler.step()
  # Print the updated learning rate
  print('Learning Rate Set To: {:.5f}'.format(optimizer.state_dict()['param_groups'][0]['lr']),'\n')

The model starts to train, but during the second epoch the loss becomes NaN:

Batch No: 0     Average Training Batch Loss: 9.51
Batch No: 40    Average Training Batch Loss: 1.71
Batch No: 80    Average Training Batch Loss: 1.15
Batch No: 120   Average Training Batch Loss: 0.94

Epoch No: 1,Training Loss: 0.83
Epoch No: 1, Training Accuracy: 0.78

Batch No: 1     Average Val Batch Loss: 0.39
Batch No: 21    Average Val Batch Loss: 0.56
Batch No: 41    Average Val Batch Loss: 0.54
Batch No: 61    Average Val Batch Loss: 0.54

Epoch No: 1, Epoch Val Loss: 0.55
Epoch No: 1, Epoch Val Accuracy: 0.81

Epoch Val Loss Decreased... Saving model
Learning Rate Set To: 0.01000 

Batch No: 0     Average Training Batch Loss: 0.83
Batch No: 40    Average Training Batch Loss: nan
Batch No: 80    Average Training Batch Loss: nan
1
Is it possible one of your training samples has a NaN in it? - Shai
You might find useful information in this answer - Shai
@Shai No NaNs in the data, it's just folders of images. I'll check the link too, thanks. - Digital Moniker
It might be that the transformations applied to the data create NaNs. Does the NaN always appear at the same time/iteration/epoch? What happens if you significantly reduce the learning rate? - Shai
@Shai I see that you're in Rehovot. I'm Irish but I live in Tel Aviv ;) - Digital Moniker
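
Following up on the comments above, here is a minimal sketch of how one might check whether any batch contains NaN or Inf values after the transforms have been applied. The helper name first_non_finite_batch and the toy tensors are hypothetical; in the question's setup the iterable would be train_loader.

```python
import torch


def first_non_finite_batch(batches):
    """Return the index of the first batch whose data has NaN/Inf, or None."""
    for i, (data, _label) in enumerate(batches):
        # torch.isfinite is False for NaN, +Inf and -Inf
        if not torch.isfinite(data).all():
            return i
    return None


# Toy stand-ins for (data, label) batches; real batches would come from a DataLoader
clean = (torch.zeros(2, 3), torch.tensor([0, 1]))
bad = (torch.tensor([[0.0, float("nan"), 1.0]]), torch.tensor([0]))

print(first_non_finite_batch([clean, bad]))
```

If this returns None for the whole training set, the inputs are probably not the source and the NaN is more likely produced during optimization.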

1 Answer

1
vote

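I see that resnet.zero_grad() is called after the forward pass logit = resnet(data). The usual order is to clear the gradients first and then run the forward pass, and I suspect this is related to the gradients exploding in your case.

Try reordering it like this: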

# Clearing the gradient
optimizer.zero_grad()
logit = resnet(data)

# Compute loss
loss = criterion(logit, label)
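
A few other details in the posted code may also be worth checking; treat these as possibilities rather than a confirmed diagnosis. The attribute PyTorch reads is requires_grad, so setting param.required_grad only creates an unused Python attribute and freezes nothing. torchvision's ResNet forward pass calls self.fc, so a head assigned to resnet.classifier is never used and the output stays at 1000 logits. resnet.train() is called once before the epoch loop, so after the first resnet.eval() every later epoch trains in eval mode. Finally, tr_losses and val_losses are never reset, so the printed averages are cumulative and stay NaN once a single NaN appears. The sketch below shows one way the loop could be arranged; TinyBackbone is a hypothetical stand-in for resnet50 and the data is synthetic, so that it runs without downloading weights.

```python
from collections import OrderedDict

import torch
from torch import nn

torch.manual_seed(0)


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for resnet50: feature layers plus an `fc` head."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU())
        self.fc = nn.Linear(16, 1000)

    def forward(self, x):
        return self.fc(self.features(x))


model = TinyBackbone()

# Freeze the pretrained part; note the spelling `requires_grad`
for param in model.parameters():
    param.requires_grad = False

# Replace the layer that forward() actually calls, so the output has 10 logits
model.fc = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(16, 32)),
    ("relu", nn.ReLU()),
    ("drop", nn.Dropout(0.05)),
    ("fc2", nn.Linear(32, 10)),
]))

criterion = nn.CrossEntropyLoss()
# Only the new head is passed to the optimizer
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.01)

# Synthetic (data, label) batches standing in for train_loader / val_loader
train_batches = [(torch.randn(4, 8), torch.randint(0, 10, (4,))) for _ in range(3)]
val_batches = [(torch.randn(4, 8), torch.randint(0, 10, (4,))) for _ in range(2)]

frozen_before = model.features[0].weight.clone()

for epoch in range(2):
    model.train()   # switch back to training mode at the start of every epoch
    tr_losses = []  # reset so the average covers this epoch only
    for data, label in train_batches:
        optimizer.zero_grad()
        loss = criterion(model(data), label)
        loss.backward()
        optimizer.step()
        tr_losses.append(loss.item())

    model.eval()    # evaluation mode: dropout off, batch norm uses running stats
    val_losses = []
    with torch.no_grad():
        for data, label in val_batches:
            val_losses.append(criterion(model(data), label).item())

    train_avg = sum(tr_losses) / len(tr_losses)
    val_avg = sum(val_losses) / len(val_losses)
    print(f"Epoch {epoch + 1}: train loss {train_avg:.2f}, val loss {val_avg:.2f}")
```

With the real model, the equivalent changes would be assigning the new head to resnet.fc and passing resnet.fc.parameters() to the optimizer. Whether this removes the NaNs in the original run is an assumption that would need to be verified.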