
We are facing a very strange issue. We tested the exact same model in two different execution settings. In the first, for a given number of epochs, we train on mini-batches for one epoch and then evaluate on the validation set using the same metrics, before moving on to the next epoch. Before each training epoch we call model.train(), and before validation we call model.eval().

Then we take the exact same model (same initialization, same dataset, same number of epochs, etc.) and just train it, without validating after each epoch.

Looking only at performance on the training set, we observed that, even though all seeds are fixed, the two training procedures evolve differently and produce quite different metrics (loss, accuracy, and so on). Specifically, the training-only procedure performs worse.

We also observe the following:

  • It is not a reproducibility issue: multiple runs of the same procedure produce exactly the same results (as intended);
  • Without dropout, the problem appears to vanish;
  • The BatchNorm1d layer, which also behaves differently in training and evaluation, seems to work properly;
  • The issue persists when we move training from TPUs to CPUs. We have tried PyTorch 1.6, PyTorch nightly, and XLA 1.6.
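The dropout/BatchNorm difference is consistent with which layers draw random numbers. A minimal sketch (the helper `consumes_global_rng` is hypothetical, not from the post) that checks whether a forward pass advances PyTorch's global CPU RNG: in training mode Dropout samples a random mask, while BatchNorm1d only computes batch statistics.

```python
import torch
import torch.nn as nn


def consumes_global_rng(layer: nn.Module, x: torch.Tensor) -> bool:
    """Return True if one forward pass of `layer` advances the global CPU RNG."""
    before = torch.get_rng_state()
    layer(x)
    after = torch.get_rng_state()
    return not torch.equal(before, after)


torch.manual_seed(42)
x = torch.randn(8, 4)

dropout = nn.Dropout(p=0.5)
batchnorm = nn.BatchNorm1d(4)

dropout.train()
batchnorm.train()
print("Dropout (train):", consumes_global_rng(dropout, x))        # True: samples a mask
print("BatchNorm1d (train):", consumes_global_rng(batchnorm, x))  # False: no sampling

dropout.eval()
print("Dropout (eval):", consumes_global_rng(dropout, x))         # False: identity in eval
```

So any extra RNG draw between epochs would change the dropout masks of later epochs, while leaving BatchNorm unaffected.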

We have lost almost a full day trying to tackle this issue (and no, we cannot avoid using dropout). Does anyone have an idea how to solve it?

Thank you very much!

P.S. Here is the code used for training (on CPU):

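import random
import time

import numpy as np
import torch
import torch.nn as nn

# BATCH_SIZE and Net are defined elsewhere in the original script (not shown).


def sigmoid(x):
    return 1 / (1 + torch.exp(-x))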


def _run(model, EPOCHS, training_data_in, validation_data_in=None):
    
    def train_fn(train_dataloader, model, optimizer, criterion):

        running_loss = 0.
        running_accuracy = 0.
        running_tp = 0.
        running_tn = 0.
        running_fp = 0.
        running_fn = 0.
        
        model.train()

        for batch_idx, (ecg, spo2, labels) in enumerate(train_dataloader, 1):

            optimizer.zero_grad() 
                
            outputs = model(ecg)

            loss = criterion(outputs, labels)
                        
            loss.backward() # calculate the gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step() # update the network weights
                                                
            running_loss += loss.item()
            predicted = torch.round(sigmoid(outputs.detach()))  # apply the sigmoid here; the model outputs logits
            
            running_accuracy += (predicted == labels).sum().item() / labels.size(0)   
            
            fp = ((predicted - labels) == 1.).sum().item() 
            fn = ((predicted - labels) == -1.).sum().item()
            tp = ((predicted + labels) == 2.).sum().item()
            tn = ((predicted + labels) == 0.).sum().item()
            running_tp += tp
            running_fp += fp
            running_tn += tn
            running_fn += fn
            
        retval = {'loss':running_loss / batch_idx,
                'accuracy':running_accuracy / batch_idx,
                'tp':running_tp,
                'tn':running_tn,
                'fp':running_fp,
                'fn':running_fn
                }
            
        return retval
            

        
    def valid_fn(valid_dataloader, model, criterion):

        running_loss = 0.
        running_accuracy = 0.
        running_tp = 0.
        running_tn = 0.
        running_fp = 0.
        running_fn = 0.

        model.eval()
        
        for batch_idx, (ecg, spo2, labels) in enumerate(valid_dataloader, 1):

            outputs = model(ecg)

            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            predicted = torch.round(sigmoid(outputs.detach()))  # apply the sigmoid here; the model outputs logits

            running_accuracy += (predicted == labels).sum().item() / labels.size(0)  
            
            fp = ((predicted - labels) == 1.).sum().item()
            fn = ((predicted - labels) == -1.).sum().item()
            tp = ((predicted + labels) == 2.).sum().item()
            tn = ((predicted + labels) == 0.).sum().item()
            running_tp += tp
            running_fp += fp
            running_tn += tn
            running_fn += fn
            
        retval = {'loss':running_loss / batch_idx,
                'accuracy':running_accuracy / batch_idx,
                'tp':running_tp,
                'tn':running_tn,
                'fp':running_fp,
                'fn':running_fn
                }
            
        return retval
    
    
    
    # Defining data loaders

    train_dataloader = torch.utils.data.DataLoader(training_data_in, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
    
    if validation_data_in is not None:
        validation_dataloader = torch.utils.data.DataLoader(validation_data_in, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)


    # Defining the loss function
    criterion = nn.BCEWithLogitsLoss()
    
    
    # Defining the optimizer
    import torch.optim as optim
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, amsgrad=False, eps=1e-07) 


    # Training code
    
    metrics_history = {"loss":[], "accuracy":[], "precision":[], "recall":[], "f1":[], "specificity":[], "accuracy_bis":[], "tp":[], "tn":[], "fp":[], "fn":[],
                    "val_loss":[], "val_accuracy":[], "val_precision":[], "val_recall":[], "val_f1":[], "val_specificity":[], "val_accuracy_bis":[], "val_tp":[], "val_tn":[], "val_fp":[], "val_fn":[],}
    
    train_begin = time.time()
    for epoch in range(EPOCHS):
        start = time.time()

        print("EPOCH:", epoch+1)

        train_metrics = train_fn(train_dataloader=train_dataloader, 
                                model=model,
                                optimizer=optimizer, 
                                criterion=criterion)
        
        metrics_history["loss"].append(train_metrics["loss"])
        metrics_history["accuracy"].append(train_metrics["accuracy"])
        metrics_history["tp"].append(train_metrics["tp"])
        metrics_history["tn"].append(train_metrics["tn"])
        metrics_history["fp"].append(train_metrics["fp"])
        metrics_history["fn"].append(train_metrics["fn"])
        
        precision = train_metrics["tp"] / (train_metrics["tp"] + train_metrics["fp"]) if train_metrics["tp"] > 0 else 0
        recall = train_metrics["tp"] / (train_metrics["tp"] + train_metrics["fn"]) if train_metrics["tp"] > 0 else 0
        specificity = train_metrics["tn"] / (train_metrics["tn"] + train_metrics["fp"]) if train_metrics["tn"] > 0 else 0
        f1 = 2*precision*recall / (precision + recall) if precision*recall > 0 else 0
        metrics_history["precision"].append(precision)
        metrics_history["recall"].append(recall)
        metrics_history["f1"].append(f1)
        metrics_history["specificity"].append(specificity)
        
        
        
        if validation_data_in is not None:
            # Calculate the metrics on the validation data, in the same way as done for training
            with torch.no_grad(): # don't keep track of the info necessary to calculate the gradients

                val_metrics = valid_fn(valid_dataloader=validation_dataloader, 
                                    model=model,
                                    criterion=criterion)

                metrics_history["val_loss"].append(val_metrics["loss"])
                metrics_history["val_accuracy"].append(val_metrics["accuracy"])
                metrics_history["val_tp"].append(val_metrics["tp"])
                metrics_history["val_tn"].append(val_metrics["tn"])
                metrics_history["val_fp"].append(val_metrics["fp"])
                metrics_history["val_fn"].append(val_metrics["fn"])

                val_precision = val_metrics["tp"] / (val_metrics["tp"] + val_metrics["fp"]) if val_metrics["tp"] > 0 else 0
                val_recall = val_metrics["tp"] / (val_metrics["tp"] + val_metrics["fn"]) if val_metrics["tp"] > 0 else 0
                val_specificity = val_metrics["tn"] / (val_metrics["tn"] + val_metrics["fp"]) if val_metrics["tn"] > 0 else 0
                val_f1 = 2*val_precision*val_recall / (val_precision + val_recall) if val_precision*val_recall > 0 else 0
                metrics_history["val_precision"].append(val_precision)
                metrics_history["val_recall"].append(val_recall)
                metrics_history["val_f1"].append(val_f1)
                metrics_history["val_specificity"].append(val_specificity)


            print("  > Training/validation loss:", round(train_metrics['loss'], 4), round(val_metrics['loss'], 4))
            print("  > Training/validation accuracy:", round(train_metrics['accuracy'], 4), round(val_metrics['accuracy'], 4))
            print("  > Training/validation precision:", round(precision, 4), round(val_precision, 4))
            print("  > Training/validation recall:", round(recall, 4), round(val_recall, 4))
            print("  > Training/validation f1:", round(f1, 4), round(val_f1, 4))
            print("  > Training/validation specificity:", round(specificity, 4), round(val_specificity, 4))
        else:
            print("  > Training loss:", round(train_metrics['loss'], 4))
            print("  > Training accuracy:", round(train_metrics['accuracy'], 4))
            print("  > Training precision:", round(precision, 4))
            print("  > Training recall:", round(recall, 4))
            print("  > Training f1:", round(f1, 4))
            print("  > Training specificity:", round(specificity, 4))


        print("Completed in:", round(time.time() - start, 1), "seconds \n")

    print("Training completed in:", round((time.time()- train_begin)/60, 1), "minutes")    

    
    
    # Save the model weights
    torch.save(model.state_dict(), './nnet_model.pt')
    
    
    # Save the metrics history
    torch.save(metrics_history, 'training_history')

And here is the function that initializes the model and sets the seeds; it is called before each execution of "_run":

def reinit_model():
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    net = Net() # the model
    return net
"even if we fixed all seeds". Please show the code where you set the seeds. (prosti)
Here you are (see the edit above). That said, I don't think the problem is related to the seeds, since both kinds of execution are reproducible. (Andrea)
To me, it does not seem to be an issue with weight initialization, since runs are reproducible within the same kind of execution (training only, or training+eval). (nsacco)

Answer


OK, I found the issue. Apparently, running the evaluation changes the state of the random number generators, and this affects the training phase.
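One plausible mechanism on CPU (an assumption; the answer does not name the exact call): every time a DataLoader iterator is created, PyTorch draws a base seed for its worker processes from the global RNG, even with shuffle=False. Running the validation loop therefore advances the RNG that the next training epoch uses for shuffling and dropout masks. A minimal sketch to check this (the helper `rng_advanced_by` is hypothetical):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def rng_advanced_by(loader: DataLoader) -> bool:
    """Iterate over `loader` once and report whether the global CPU RNG state changed."""
    before = torch.get_rng_state()
    for _ in loader:
        pass
    return not torch.equal(before, torch.get_rng_state())


dataset = TensorDataset(torch.arange(10.0))

# shuffle=False, like the validation loader in the question
validation_like_loader = DataLoader(dataset, batch_size=4, shuffle=False)
print(rng_advanced_by(validation_like_loader))  # True
```

If this prints True in your environment, a training-only run and a training+validation run consume the global RNG differently from the second epoch on.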

The solution is as follows:

  • at the beginning of "_run()", set all RNG states to the desired value (e.g., 42), then save those states to disk;
  • at the beginning of "train_fn()", read the RNG states from disk and restore them;
  • at the end of "train_fn()", save the RNG states to disk again.
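For the CPU code in the question, the same save/restore idea can be sketched with the standard state getters and setters of Python's `random`, NumPy, and PyTorch. This is a minimal in-memory sketch, assuming a single process (the helpers `capture_rng_states` and `restore_rng_states` are hypothetical; on GPU you would also capture `torch.cuda.get_rng_state_all()`):

```python
import random

import numpy as np
import torch


def capture_rng_states() -> dict:
    """Snapshot the Python, NumPy and PyTorch (CPU) RNG states."""
    return {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.get_rng_state(),
    }


def restore_rng_states(states: dict) -> None:
    """Restore RNG states previously returned by capture_rng_states()."""
    random.setstate(states["python"])
    np.random.set_state(states["numpy"])
    torch.set_rng_state(states["torch"])


# Beginning of _run(): seed everything once and take a snapshot.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
train_rng = capture_rng_states()

for epoch in range(2):
    restore_rng_states(train_rng)      # beginning of train_fn()
    _ = torch.rand(3)                  # stands in for shuffling / dropout during training
    train_rng = capture_rng_states()   # end of train_fn()
    _ = torch.rand(5)                  # stands in for RNG consumed by validation
```

Because training always resumes from its own snapshot, whatever validation draws in between no longer influences the training epochs. Keeping the snapshot in a variable avoids the disk round trip; saving it with torch.save, as the answer does, works the same way.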

For instance, when running on TPU with XLA, use the following instructions:

  • at the beginning of "_run()": xm.set_rng_state(42), then xm.save(xm.get_rng_state(), 'xm_seed');
  • at the beginning of "train_fn()": xm.set_rng_state(torch.load('xm_seed'), device=device) (for verification, you can also print the state here with xm.master_print(xm.get_rng_state()));
  • at the end of "train_fn()": xm.save(xm.get_rng_state(), 'xm_seed').
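On CPU, if the DataLoader's seed draw is indeed the only RNG use during validation (an assumption; it does not cover the XLA device RNG), a lighter alternative is to give the validation DataLoader its own `torch.Generator` through the `generator` argument available in recent PyTorch versions. A minimal sketch; the dataset and BATCH_SIZE below are hypothetical stand-ins for the question's objects:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the question's validation_data_in and BATCH_SIZE.
validation_data_in = TensorDataset(torch.randn(20, 3), torch.randint(0, 2, (20, 1)).float())
BATCH_SIZE = 8

# Dedicated generator: the loader draws its worker base seed from here
# instead of from the global RNG used by dropout and training shuffling.
val_generator = torch.Generator().manual_seed(0)
validation_dataloader = DataLoader(validation_data_in, batch_size=BATCH_SIZE,
                                   shuffle=False,
                                   num_workers=0,  # the question uses 1; 0 keeps this sketch single-process
                                   generator=val_generator)

torch.manual_seed(42)
before = torch.get_rng_state()
for inputs, labels in validation_dataloader:
    pass
print(torch.equal(before, torch.get_rng_state()))  # True: the global RNG was not advanced
```

If anything else in the evaluation path draws random numbers, the save/restore approach from the answer remains the more robust option.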