Here's some of the convolutional neural network sample code from PyTorch's examples directory on GitHub: https://github.com/pytorch/examples/blob/master/mnist/main.py

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)    # (in_channels, out_channels, kernel_size, stride)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)         # in_features = 9216 -- the number in question
        self.fc2 = nn.Linear(128, 10)

If I understand this correctly, we need to flatten the output from the last convolutional layer before we can pass it through a linear layer (fc1). So, looking at this code, we see that the input size of the first fully connected layer is 9216.

Where has this number (9216) come from?

1 Answer

You also need to look at the forward method and the network's input shape in order to compute the input size of the linear/fully-connected layer. In the case of MNIST, we have a single-channel 28x28 input image. Using the size formula from the docs (below), you can compute the output shape of each convolution operation; the max-pooling operation follows the same input/output relationship as the convolution layers.

Since the shape of the input before flattening is a 64-channel 12x12 feature map, the total number of features is 64 * 12 * 12 = 9216.

The input/output size relation for conv2d and max_pool2d operations (from the Conv2d and MaxPool2d docs) is, per spatial dimension:

out = floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
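
Applying that formula step by step reproduces the 9216 figure. Here is a quick sketch (the conv2d_out helper is my own name, just for illustration):

def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    # Output-size formula shared by Conv2d and MaxPool2d (per spatial dimension)
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

size = 28                               # MNIST input: 1 channel, 28x28
size = conv2d_out(size, 3)              # conv1 (3x3, stride 1): 28 -> 26
size = conv2d_out(size, 3)              # conv2 (3x3, stride 1): 26 -> 24
size = conv2d_out(size, 2, stride=2)    # max_pool2d(x, 2):      24 -> 12
print(64 * size * size)                 # 64 channels * 12 * 12 = 9216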

def forward(self, x):
    """ For each line which changes the feature shape additional comment
        indicates <input_shape> -> <output_shape> """
    x = self.conv1(x)                # [1, 28, 28] -> [32, 26, 26]
    x = F.relu(x)
    x = self.conv2(x)                # [32, 26, 26] -> [64, 24, 24]
    x = F.relu(x)
    x = F.max_pool2d(x, 2)           # [64, 24, 24] -> [64, 12, 12]
    x = self.dropout1(x)
    x = torch.flatten(x, 1)          # [64, 12, 12] -> [9216]
    x = self.fc1(x)                  # [9216] -> [128]
    x = F.relu(x)
    x = self.dropout2(x)
    x = self.fc2(x)                  # [128] -> [10]
    output = F.log_softmax(x, dim=1)
    return output
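
If you prefer to let PyTorch do the arithmetic, you can push a dummy batch through the convolutional part and check the shape just before flattening; a quick sketch, assuming the Net class from the question:

import torch
import torch.nn.functional as F

model = Net()
x = torch.randn(1, 1, 28, 28)           # dummy batch: one 1x28x28 image
x = F.relu(model.conv1(x))
x = F.relu(model.conv2(x))
x = F.max_pool2d(x, 2)
print(x.shape)                          # torch.Size([1, 64, 12, 12])
print(torch.flatten(x, 1).shape)        # torch.Size([1, 9216])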