I am currently attempting to train a PyTorch CNN to classify demented and non-demented individuals based on MRI scans. During training, however, the loss remains constant and the accuracy stays at 0.333, i.e., chance level for the 3 classes. I have tried many of the suggestions offered in answers to similar questions, but none of them have worked for my task: changing the number of convolutional filters, trying different loss functions, training on the raw dataset before scaling up to a larger augmented set of images, and altering hyperparameters such as the learning rate and batch size. My code and examples of the input imagery are below.
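For context, the snippets below reference a few globals defined earlier in my script (device, loss_func, skf, transforms, and the model instance hl_model). Roughly, that setup looks like the sketch below; the exact augmentation parameters shown here are illustrative, not my real values.

import torch
import torch.nn as nn
import torchvision.transforms as T
from sklearn.model_selection import StratifiedKFold

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss_func = nn.CrossEntropyLoss()                # 3-class cross-entropy
skf = StratifiedKFold(n_splits=6, shuffle=True)

# PIL-level augmentations, applied 30 times per slice (parameters illustrative)
transforms = T.Compose([
    T.RandomRotation(10),
    T.RandomAffine(degrees=0, translate=(0.05, 0.05)),
])

hl_model = Hl_Model().to(device)                 # Hl_Model is defined below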
Image Examples
[image: axial MRI slice of a Mild Cognitive Impairment brain]
Preprocessing Code
import os
import itertools

import nibabel as nib
import numpy as np
import torch
from PIL import Image, ImageOps

torch.cuda.set_device(0)

g = True
if g:
    for f in final_MRI_data:
        path = os.path.join(final_MRI_dir, f)
        matrix = nib.load(path).get_fdata()         # load the NIfTI volume as a numpy array
        slice_ = matrix[90, :, :]                   # take a single 2D slice from the volume
        img = Image.fromarray(slice_)
        img = img.crop((left, top, right, bottom))  # crop bounds are defined earlier
        img = ImageOps.grayscale(img)
        data_matrices.append(img)

# Apply the augmentation pipeline 30 times to each slice
postda_data = []
for image in data_matrices:
    for i in range(30):
        transformed_img = transforms(image)
        postda_data.append(np.asarray(transformed_img))

# Repeat each label 30 times so the labels stay aligned with the augmented images
final_MRI_labels = list(itertools.chain.from_iterable(
    itertools.repeat(x, 30) for x in final_MRI_labels))

X = torch.Tensor(np.asarray(postda_data)).view(-1, 145, 200)
print(X.size())
y = torch.Tensor(final_MRI_labels)  # target class indices for the cross-entropy loss function
z = torch.Tensor(np.asarray([np.eye(3)[val] for val in final_MRI_labels]))  # one-hot targets for the model testing function
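To clarify the two label representations: y holds integer class indices, which is what the cross-entropy loss expects, while z holds the matching one-hot rows that the testing function reduces back to indices with argmax. A quick demonstration with made-up labels:

import numpy as np
import torch

labels = [0, 2, 1]                        # made-up class indices
y_demo = torch.Tensor(labels)             # tensor([0., 2., 1.])
z_demo = torch.Tensor(np.eye(3)[labels])  # one-hot rows:
                                          #   [[1., 0., 0.],
                                          #    [0., 0., 1.],
                                          #    [0., 1., 0.]]
print(torch.argmax(z_demo, dim=1))        # tensor([0, 2, 1]) recovers the indices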
Network Class
class Hl_Model(nn.Module):
    torch.cuda.set_device(0)

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 3, stride=2)

        # Run a dummy input through the conv stack once to find the flattened size
        x = torch.randn(145, 200).view(-1, 1, 145, 200)
        self._to_linear = None
        self.convs(x)

        self.fc1 = nn.Linear(self._to_linear, 128, bias=True)
        self.fc2 = nn.Linear(128, 3)

    def convs(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(F.relu(self.conv4(x)), (2, 2), stride=2)
        if self._to_linear is None:
            self._to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
        return x

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1)
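As a sanity check of the dynamic flatten-size trick, instantiating the model and pushing a dummy batch through it (on CPU, for illustration) shows the shapes involved; with 145x200 inputs I get:

net = Hl_Model()
print(net._to_linear)          # 5120 = 256 channels * 4 * 5 after the conv stack
dummy = torch.randn(2, 1, 145, 200)
out = net(dummy)
print(out.shape)               # torch.Size([2, 3]): one row per input
print(out.sum(dim=1))          # each row sums to 1 because forward() applies softmax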
Training Function
def train(net, train_fold_x, train_fold_y):
    optimizer = optim.Adam(net.parameters(), lr=0.05)
    BATCH_SIZE = 5
    EPOCHS = 50
    for epoch in range(EPOCHS):
        for i in tqdm(range(0, len(train_fold_x), BATCH_SIZE)):
            batch_x = train_fold_x[i:i + BATCH_SIZE].view(-1, 1, 145, 200)
            batch_y = train_fold_y[i:i + BATCH_SIZE].long()  # class indices for cross-entropy
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            outputs = net(batch_x)
            loss = loss_func(outputs, batch_y)
            loss.backward()
            optimizer.step()
        print(f"Epoch: {epoch} Loss: {loss}")
Testing Function
def test(net, test_fold_x, test_fold_y):
    correct = 0
    total = 0
    with torch.no_grad():
        for i in tqdm(range(len(test_fold_x))):
            real_class = torch.argmax(test_fold_y[i]).to(device)  # index of the one-hot target
            net_out = net(test_fold_x[i].view(-1, 1, 145, 200).to(device))
            pred_class = torch.argmax(net_out)
            if pred_class == real_class:
                correct += 1
            total += 1
    print(f"Accuracy: {round(correct / total, 3)}")
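For a single test sample, net_out is one softmax row, so argmax picks the predicted class:

sample = test_fold_x[0].view(-1, 1, 145, 200).to(device)
net_out = hl_model(sample)          # shape (1, 3): one softmax row
pred_class = torch.argmax(net_out)  # index of the largest probability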
Cross-Validation Loop
for i in range(6):
    result = next(skf.split(X, y))
    X_train = X[result[0]]
    X_test = X[result[1]]
    y_train = y[result[0]]  # integer class indices for training
    y_test = z[result[1]]   # one-hot rows for the test function
    train(hl_model, X_train, y_train)
    test(hl_model, X_test, y_test)
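Each call to skf.split(X, y) constructs a fresh generator of (train_indices, test_indices) pairs, so result[0] and result[1] are disjoint integer index arrays into X, y, and z:

result = next(skf.split(X, y))        # first split from a newly created generator
train_idx, test_idx = result
print(len(train_idx), len(test_idx))  # roughly 5/6 and 1/6 of the dataset with 6 folds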
Output During Training (tqdm bars trimmed to their final line per epoch):
100%|##########| 188/188 [00:04<00:00, 45.32it/s]
Epoch: 0 Loss: 1.5514447689056396
100%|##########| 188/188 [00:02<00:00, 80.35it/s]
Epoch: 1 Loss: 1.5514447689056396
100%|##########| 188/188 [00:02<00:00, 80.66it/s]
Epoch: 2 Loss: 1.5514447689056396
This output repeats, with the loss never changing, all the way to "Epoch: 49 Loss: 1.5514447689056396".
Thanks in advance for any advice.