
I am training a PyTorch CNN to classify demented and non-demented individuals from MRI scans. During training, however, the loss stays constant and the accuracy on the 3-class problem stays at 0.333, which is chance level. I have tried many of the suggestions given in answers to similar questions, but none of them has worked for my task. These included changing the number of convolutional units in the model, trying different loss functions, training the model on the raw dataset before scaling up to a larger, augmented set of images, and altering hyperparameters such as the learning rate and batch size. My code and examples of the input imagery are below.

Image Examples

Healthy Brain

Mild Cognitive Impairment Brain

Alzheimer's Brain

Preprocessing Code

torch.cuda.set_device(0)
for f in final_MRI_data:
    path = os.path.join(final_MRI_dir, f)
    matrix = nib.load(path).get_fdata()  # load the NIfTI volume as a float array
    slice_ = matrix[90, :, :]  # take a single 2D slice from the volume
    img = Image.fromarray(slice_)
    img = img.crop((left, top, right, bottom))
    img = ImageOps.grayscale(img)
    data_matrices.append(img)

postda_data = []
for image in data_matrices:
    for i in range(30):
        transformed_img = transforms(image)
        transformed_img = np.asarray(transformed_img)
        postda_data.append(transformed_img)

final_MRI_labels = list(itertools.chain.from_iterable(itertools.repeat(x, 30) for x in final_MRI_labels))

X = torch.Tensor(np.asarray([i for i in postda_data])).view(-1, 145, 200)
print(X.size())

y = torch.Tensor([i for i in final_MRI_labels]) #Target labels for cross entropy loss function

z = []
for val in final_MRI_labels:
    z.append(np.eye(3)[val])
z = torch.Tensor(np.asarray(z)) #Target one-hot encoded matrices for model testing function

Network Class

class Hl_Model(nn.Module):

    torch.cuda.set_device(0)

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, 3, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 3, stride=2)

        x = torch.randn(145,200).view(-1,1,145,200)
        self._to_linear = None
        self.convs(x)
    
        self.fc1 = nn.Linear(self._to_linear, 128, bias=True)
        self.fc2 = nn.Linear(128, 3)

    def convs(self, x):

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(F.relu(self.conv4(x)), (2, 2), stride=2)

        if self._to_linear is None:
            self._to_linear = x[0].shape[0]*x[0].shape[1]*x[0].shape[2]
        return x


    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1)

Training Function

def train(net, train_fold_x, train_fold_y):

    optimizer = optim.Adam(net.parameters(), lr=0.05)
    BATCH_SIZE = 5
    EPOCHS = 50
    for epoch in range(EPOCHS):
        for i in tqdm(range(0, len(train_fold_x), BATCH_SIZE)):

            batch_x = train_fold_x[i:i+BATCH_SIZE].view(-1, 1, 145, 200)
            batch_y = train_fold_y[i:i+BATCH_SIZE]
        
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        
            optimizer.zero_grad()
            outputs = net(batch_x)

            batch_y = batch_y.long()
            loss = loss_func(outputs, batch_y)

            loss.backward()
            optimizer.step()
        
        print(f"Epoch: {epoch} Loss: {loss}")

Testing Function

def test(net, test_fold_x, test_fold_y):

    test_fold_x.to(device)
    test_fold_y.to(device)

    correct = 0
    total = 0
    with torch.no_grad():
        for i in tqdm(range(len(test_fold_x))):
            real_class = torch.argmax(test_fold_y[i]).to(device)
            net_out = net(test_fold_x[i].view(-1, 1, 145, 200).to(device))
            pred_class = torch.argmax(net_out)

            if pred_class == real_class:
                correct += 1
            total +=1

Cross-Validation Loop

for i in range(6):
    result = next(skf.split(X, y))
    X_train = X[result[0]]
    X_test = X[result[1]]
    y_train = y[result[0]]
    y_test = z[result[1]]
    train(hl_model, X_train, y_train)
    test(hl_model, X_test, y_test)

Output During Training:

  0%|          | 0/188 [00:00<?, ?it/s]
  1%|          | 1/188 [00:01<05:35,  1.79s/it]
  5%|4         | 9/188 [00:01<03:45,  1.26s/it]
  9%|8         | 16/188 [00:02<02:32,  1.13it/s]
 12%|#2        | 23/188 [00:02<01:43,  1.60it/s]
 16%|#6        | 31/188 [00:02<01:09,  2.27it/s]
 21%|##        | 39/188 [00:02<00:46,  3.20it/s]
 25%|##5       | 47/188 [00:02<00:31,  4.49it/s]
 30%|##9       | 56/188 [00:02<00:21,  6.26it/s]
 35%|###4      | 65/188 [00:02<00:14,  8.67it/s]
 39%|###9      | 74/188 [00:02<00:09, 11.85it/s]
 44%|####4     | 83/188 [00:02<00:06, 15.92it/s]
 49%|####8     | 92/188 [00:02<00:04, 21.01it/s]
 54%|#####3    | 101/188 [00:03<00:03, 27.03it/s]
 59%|#####8    | 110/188 [00:03<00:02, 33.91it/s]
 63%|######3   | 119/188 [00:03<00:01, 41.21it/s]
 68%|######8   | 128/188 [00:03<00:01, 48.36it/s]
 73%|#######2  | 137/188 [00:03<00:00, 55.36it/s]
 78%|#######7  | 146/188 [00:03<00:00, 61.09it/s]
 82%|########2 | 155/188 [00:03<00:00, 65.87it/s]
 87%|########7 | 164/188 [00:03<00:00, 69.85it/s]
 92%|#########2| 173/188 [00:03<00:00, 72.93it/s]
 97%|#########6| 182/188 [00:04<00:00, 74.88it/s]
100%|##########| 188/188 [00:04<00:00, 45.32it/s]
Epoch: 0 Loss: 1.5514447689056396

  0%|          | 0/188 [00:00<?, ?it/s]
  5%|4         | 9/188 [00:00<00:02, 85.13it/s]
 10%|9         | 18/188 [00:00<00:02, 84.42it/s]
 14%|#4        | 27/188 [00:00<00:01, 83.22it/s]
 19%|#9        | 36/188 [00:00<00:01, 82.64it/s]
 24%|##3       | 45/188 [00:00<00:01, 82.23it/s]
 29%|##8       | 54/188 [00:00<00:01, 82.17it/s]
 34%|###3      | 63/188 [00:00<00:01, 82.13it/s]
 38%|###8      | 72/188 [00:00<00:01, 81.66it/s]
 43%|####2     | 80/188 [00:00<00:01, 79.76it/s]
 47%|####6     | 88/188 [00:01<00:01, 79.66it/s]
 52%|#####1    | 97/188 [00:01<00:01, 80.58it/s]
 56%|#####6    | 106/188 [00:01<00:01, 80.36it/s]
 61%|######1   | 115/188 [00:01<00:00, 80.64it/s]
 66%|######5   | 124/188 [00:01<00:00, 80.84it/s]
 71%|#######   | 133/188 [00:01<00:00, 80.54it/s]
 76%|#######5  | 142/188 [00:01<00:00, 80.98it/s]
 80%|########  | 151/188 [00:01<00:00, 80.86it/s]
 85%|########5 | 160/188 [00:01<00:00, 80.77it/s]
 90%|########9 | 169/188 [00:02<00:00, 78.81it/s]
 94%|#########4| 177/188 [00:02<00:00, 78.53it/s]
 98%|#########8| 185/188 [00:02<00:00, 77.88it/s]
100%|##########| 188/188 [00:02<00:00, 80.35it/s]
Epoch: 1 Loss: 1.5514447689056396

  0%|          | 0/188 [00:00<?, ?it/s]
  5%|4         | 9/188 [00:00<00:02, 83.56it/s]
 10%|9         | 18/188 [00:00<00:02, 82.41it/s]
 14%|#3        | 26/188 [00:00<00:01, 81.49it/s]
 19%|#8        | 35/188 [00:00<00:01, 81.65it/s]
 23%|##3       | 44/188 [00:00<00:01, 81.55it/s]
 28%|##7       | 52/188 [00:00<00:01, 80.41it/s]
 32%|###1      | 60/188 [00:00<00:01, 79.40it/s]
 37%|###6      | 69/188 [00:00<00:01, 80.17it/s]
 41%|####1     | 78/188 [00:00<00:01, 80.29it/s]
 46%|####6     | 87/188 [00:01<00:01, 80.81it/s]
 51%|#####1    | 96/188 [00:01<00:01, 80.95it/s]
 55%|#####5    | 104/188 [00:01<00:01, 80.24it/s]
 60%|######    | 113/188 [00:01<00:00, 80.56it/s]
 65%|######4   | 122/188 [00:01<00:00, 80.56it/s]
 70%|######9   | 131/188 [00:01<00:00, 80.78it/s]
 74%|#######4  | 140/188 [00:01<00:00, 79.65it/s]
 79%|#######9  | 149/188 [00:01<00:00, 80.14it/s]
 84%|########4 | 158/188 [00:01<00:00, 80.70it/s]
 89%|########8 | 167/188 [00:02<00:00, 80.88it/s]
 94%|#########3| 176/188 [00:02<00:00, 81.22it/s]
 98%|#########8| 185/188 [00:02<00:00, 81.03it/s]
100%|##########| 188/188 [00:02<00:00, 80.66it/s]
Epoch: 2 Loss: 1.5514447689056396

This output repeats all the way to "Epoch: 49 Loss: 1.5514447689056396"

Thanks in advance for any advice.

Could you provide the output during your training? In addition, I would like to see the definition of the loss function you were using in this case. - TQCH
@TQCH Loss Function: loss_func = nn.CrossEntropyLoss(); Output: Epoch: 1 Loss: 1.5514447689056396, Epoch: 2 Loss: 1.5514447689056396 ... Epoch: 49 Loss: 1.5514447689056396 - Jack Diskin

1 Answer


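The problem most likely comes from combining the softmax activation at the end of your model's forward with loss_func = nn.CrossEntropyLoss(). That loss expects raw logits and applies log-softmax internally, so your outputs are effectively passed through softmax twice. Return the output of self.fc2 directly from forward instead. Please check the official documentation: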

class torch.nn.CrossEntropyLoss(weight: Optional[torch.Tensor] = None, size_average=None, ignore_index: int = -100, reduce=None, reduction: str = 'mean')

and

This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class. The input is expected to contain raw, unnormalized scores for each class.
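
As a rough illustration of why the loss can get stuck at exactly that value, here is a minimal sketch. The logits and targets are made up for the example (they are not taken from your data), and it assumes the network has drifted into producing very large logits for one class, which is one plausible outcome of training with a high learning rate. Once the softmax saturates to [1, 0, 0], feeding it to nn.CrossEntropyLoss gives log(e + 2), roughly 1.5514, for every misclassified sample, and no gradient flows back through the saturated softmax. Passing the raw logits instead gives a usable gradient.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

loss_func = nn.CrossEntropyLoss()

# Hypothetical batch of 5 samples: class 0 wins by a huge margin,
# while the true class is 1 for every sample.
logits = torch.tensor([[200.0, 0.0, 0.0]] * 5, requires_grad=True)
target = torch.tensor([1, 1, 1, 1, 1])

# What the posted forward() does: softmax first, then CrossEntropyLoss.
probs = F.softmax(logits, dim=1)  # saturates to [1, 0, 0] in float32
loss_double = loss_func(probs, target)
loss_double.backward()

print(loss_double.item())              # approximately 1.5514
print(math.log(math.e + 2))            # the same value in closed form
print(logits.grad.abs().max().item())  # 0.0: nothing to learn from

# Suggested change: give CrossEntropyLoss the raw logits.
logits2 = logits.detach().clone().requires_grad_(True)
loss_logits = loss_func(logits2, target)
loss_logits.backward()

print(loss_logits.item())               # large, because the prediction is wrong
print(logits2.grad.abs().max().item())  # non-zero gradient
```

In your model this amounts to ending forward with return x rather than return F.softmax(x, dim=1). torch.argmax over the logits picks the same class as torch.argmax over the softmax output, so the testing function does not need to change.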