
I am trying to build a binary temporal image classifier by combining ResNet18 and an LSTM. However, I have never really used RNNs before and have been struggling to get the correct output shape.

I am using a batch size of 128 and a sequence size of 32. The images are 80x80 grayscale images.
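
For reference, a dummy input with exactly these dimensions (random placeholder data) would be:

import torch

# batch x seq-len x channels x height x width
x_3d = torch.rand(128, 32, 1, 80, 80)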

The current model is:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class CNNLSTM(nn.Module):
    def __init__(self):
        super(CNNLSTM, self).__init__()
        self.resnet = models.resnet18(pretrained=False)
        # accept 1-channel (grayscale) input instead of RGB
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)
        # replace the classification head so each frame becomes a 256-dim feature vector
        self.resnet.fc = nn.Linear(in_features=512, out_features=256)

        self.lstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=3)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x_3d):
        # x_3d: torch.Size([128, 32, 1, 80, 80]) = batch x seq-len x channels x H x W
        hidden = None
        toret = []
        for t in range(x_3d.size(1)):
            # CNN features for frame t of every sequence in the batch: (128, 256)
            x = self.resnet(x_3d[:, t, :, :, :])

            # unsqueeze(0) adds a seq-len dimension of 1: (1, 128, 256)
            out, hidden = self.lstm(x.unsqueeze(0), hidden)
            x = self.fc1(out[-1, :, :])   # (128, 128)
            x = F.relu(x)
            x = self.fc2(x)               # (128, 1)

            toret.append(x)
        # stacking along a new leading dim gives (32, 128, 1) = seq-len x batch x 1
        return torch.stack(toret)

This returns a tensor of shape torch.Size([32, 128, 1]), which, as I understand it, means the nth slice along the first dimension holds the output for the nth time step of every sequence in the batch (seq-len x batch x output).

How can I get output of shape 128x1x32 instead?

And is there a better way to do this?


1 Answer


You could permute the dimensions:

import torch

a = torch.rand(32, 128, 1)  # seq-len x batch x output, as returned above
a = a.permute(1, 2, 0)      # the arguments are the indices of the original dimensions
print(a.shape)
>> torch.Size([128, 1, 32])

But you could also set batch_first=True in the LSTM module:

self.lstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=3, batch_first=True)

The LSTM will then expect its input to have the shape batch-size x seq-len x features and will return the output tensor in the same batch-first layout. (Note that batch_first does not affect the hidden and cell states, which keep their num-layers x batch x hidden shape.)
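
As a minimal sketch with the question's dimensions (random data, shapes only), the batch-first layout looks like this:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=3, batch_first=True)

# one batch of per-frame CNN features: batch x seq-len x features
x = torch.rand(128, 32, 256)
out, (h_n, c_n) = lstm(x)

print(out.shape)  # torch.Size([128, 32, 256]) -- batch first, like the input
print(h_n.shape)  # torch.Size([3, 128, 256]) -- num-layers x batch x hidden, unchanged

With this layout you can also feed the whole sequence of CNN features to the LSTM in one call instead of looping over the time steps manually.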