We are facing a very strange issue. We tested the exact same model into two different “execution” settings. In the first case, given a certain amount of epochs, we train using mini-batches for one epoch, and thereafter we test on the validation set following the same criteria. Then, we go for the next epoch. Clearly, before each training epoch, we use model.train(), and before validation we turn on model.eval().
Then we take the exact same model (same init, same dataset, same epochs, etc.) and we just train it without validation after each epoch.
Just looking at performance on training set, we observed that, even if we fixed all seeds, the two training procedures evolve differently and produce quite different metrics results (losses, accuracy, and so on). Specifically, the training-only procedure is less performing.
We also observe the following things:
- It is not a reproducibility issue, because multiple executions of the same procedure produce exactly the same results (and this is intended);
- Removing the dropout, it appears that the problem vanishes;
- Batchnorm1d layer, that still has different behaviours between training and evaluation, seems to work properly;
- The issue still happens if we move from training onto TPUs to CPUs. We are working and tried Pythorch 1.6, Pythorch nightly, XLA 1.6.
We quite lost one full day in trying to tackle this issue (and no, we cannot avoid using dropout). Does anyone have any idea about how to solve this fact?
Thank you very much!
p.s. Here the code employed for the training (on CPU).
def sigmoid(x):
return 1 / (1 + torch.exp(-x))
def _run(model, EPOCHS, training_data_in, validation_data_in=None):
def train_fn(train_dataloader, model, optimizer, criterion):
running_loss = 0.
running_accuracy = 0.
running_tp = 0.
running_tn = 0.
running_fp = 0.
running_fn = 0.
model.train()
for batch_idx, (ecg, spo2, labels) in enumerate(train_dataloader, 1):
optimizer.zero_grad()
outputs = model(ecg)
loss = criterion(outputs, labels)
loss.backward() # calculate the gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step() # update the network weights
running_loss += loss.item()
predicted = torch.round(sigmoid(outputs.data)) # here determining the sigmoid, not included in the model
running_accuracy += (predicted == labels).sum().item() / labels.size(0)
fp = ((predicted - labels) == 1.).sum().item()
fn = ((predicted - labels) == -1.).sum().item()
tp = ((predicted + labels) == 2.).sum().item()
tn = ((predicted + labels) == 0.).sum().item()
running_tp += tp
running_fp += fp
running_tn += tn
running_fn += fn
retval = {'loss':running_loss / batch_idx,
'accuracy':running_accuracy / batch_idx,
'tp':running_tp,
'tn':running_tn,
'fp':running_fp,
'fn':running_fn
}
return retval
def valid_fn(valid_dataloader, model, criterion):
running_loss = 0.
running_accuracy = 0.
running_tp = 0.
running_tn = 0.
running_fp = 0.
running_fn = 0.
model.eval()
for batch_idx, (ecg, spo2, labels) in enumerate(valid_dataloader, 1):
outputs = model(ecg)
loss = criterion(outputs, labels)
running_loss += loss.item()
predicted = torch.round(sigmoid(outputs.data)) # here determining the sigmoid, not included in the model
running_accuracy += (predicted == labels).sum().item() / labels.size(0)
fp = ((predicted - labels) == 1.).sum().item()
fn = ((predicted - labels) == -1.).sum().item()
tp = ((predicted + labels) == 2.).sum().item()
tn = ((predicted + labels) == 0.).sum().item()
running_tp += tp
running_fp += fp
running_tn += tn
running_fn += fn
retval = {'loss':running_loss / batch_idx,
'accuracy':running_accuracy / batch_idx,
'tp':running_tp,
'tn':running_tn,
'fp':running_fp,
'fn':running_fn
}
return retval
# Defining data loaders
train_dataloader = torch.utils.data.DataLoader(training_data_in, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
if validation_data_in != None:
validation_dataloader = torch.utils.data.DataLoader(validation_data_in, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)
# Defining the loss function
criterion = nn.BCEWithLogitsLoss()
# Defining the optimizer
import torch.optim as optim
optimizer = optim.AdamW(model.parameters(), lr=3e-4, amsgrad=False, eps=1e-07)
# Training code
metrics_history = {"loss":[], "accuracy":[], "precision":[], "recall":[], "f1":[], "specificity":[], "accuracy_bis":[], "tp":[], "tn":[], "fp":[], "fn":[],
"val_loss":[], "val_accuracy":[], "val_precision":[], "val_recall":[], "val_f1":[], "val_specificity":[], "val_accuracy_bis":[], "val_tp":[], "val_tn":[], "val_fp":[], "val_fn":[],}
train_begin = time.time()
for epoch in range(EPOCHS):
start = time.time()
print("EPOCH:", epoch+1)
train_metrics = train_fn(train_dataloader=train_dataloader,
model=model,
optimizer=optimizer,
criterion=criterion)
metrics_history["loss"].append(train_metrics["loss"])
metrics_history["accuracy"].append(train_metrics["accuracy"])
metrics_history["tp"].append(train_metrics["tp"])
metrics_history["tn"].append(train_metrics["tn"])
metrics_history["fp"].append(train_metrics["fp"])
metrics_history["fn"].append(train_metrics["fn"])
precision = train_metrics["tp"] / (train_metrics["tp"] + train_metrics["fp"]) if train_metrics["tp"] > 0 else 0
recall = train_metrics["tp"] / (train_metrics["tp"] + train_metrics["fn"]) if train_metrics["tp"] > 0 else 0
specificity = train_metrics["tn"] / (train_metrics["tn"] + train_metrics["fp"]) if train_metrics["tn"] > 0 else 0
f1 = 2*precision*recall / (precision + recall) if precision*recall > 0 else 0
metrics_history["precision"].append(precision)
metrics_history["recall"].append(recall)
metrics_history["f1"].append(f1)
metrics_history["specificity"].append(specificity)
if validation_data_in != None:
# Calculate the metrics on the validation data, in the same way as done for training
with torch.no_grad(): # don't keep track of the info necessary to calculate the gradients
val_metrics = valid_fn(valid_dataloader=validation_dataloader,
model=model,
criterion=criterion)
metrics_history["val_loss"].append(val_metrics["loss"])
metrics_history["val_accuracy"].append(val_metrics["accuracy"])
metrics_history["val_tp"].append(val_metrics["tp"])
metrics_history["val_tn"].append(val_metrics["tn"])
metrics_history["val_fp"].append(val_metrics["fp"])
metrics_history["val_fn"].append(val_metrics["fn"])
val_precision = val_metrics["tp"] / (val_metrics["tp"] + val_metrics["fp"]) if val_metrics["tp"] > 0 else 0
val_recall = val_metrics["tp"] / (val_metrics["tp"] + val_metrics["fn"]) if val_metrics["tp"] > 0 else 0
val_specificity = val_metrics["tn"] / (val_metrics["tn"] + val_metrics["fp"]) if val_metrics["tn"] > 0 else 0
val_f1 = 2*val_precision*val_recall / (val_precision + val_recall) if val_precision*val_recall > 0 else 0
metrics_history["val_precision"].append(val_precision)
metrics_history["val_recall"].append(val_recall)
metrics_history["val_f1"].append(val_f1)
metrics_history["val_specificity"].append(val_specificity)
print(" > Training/validation loss:", round(train_metrics['loss'], 4), round(val_metrics['loss'], 4))
print(" > Training/validation accuracy:", round(train_metrics['accuracy'], 4), round(val_metrics['accuracy'], 4))
print(" > Training/validation precision:", round(precision, 4), round(val_precision, 4))
print(" > Training/validation recall:", round(recall, 4), round(val_recall, 4))
print(" > Training/validation f1:", round(f1, 4), round(val_f1, 4))
print(" > Training/validation specificity:", round(specificity, 4), round(val_specificity, 4))
else:
print(" > Training loss:", round(train_metrics['loss'], 4))
print(" > Training accuracy:", round(train_metrics['accuracy'], 4))
print(" > Training precision:", round(precision, 4))
print(" > Training recall:", round(recall, 4))
print(" > Training f1:", round(f1, 4))
print(" > Training specificity:", round(specificity, 4))
print("Completed in:", round(time.time() - start, 1), "seconds \n")
print("Training completed in:", round((time.time()- train_begin)/60, 1), "minutes")
# Save the model weights
torch.save(model.state_dict(), './nnet_model.pt')
# Save the metrics history
torch.save(metrics_history, 'training_history')
And here is the function that initializes the model and set the seeds, called before each execution of the code of "_run":
def reinit_model():
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
net = Net() # the model
return net