
I am trying to learn the weights of a 3x3 conv2d layer that accepts 3 channels and outputs 3 channels (assume bias=0 throughout). However, the weights of the conv layer are learned indirectly: I have a two-layer multilayer perceptron (MLP) with 9 nodes in the first layer and 9 in the second, i.e. nn.Linear(9, 9), and the weights of the conv2d layer should be precisely the weights learned by this MLP. I understand that in this case I will have to use nn.functional.conv2d(input, weight), but exactly how to extract the weights from the MLP and use them for the convolution is not clear to me. I can think of the following:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.m = nn.Linear(9, 9)

    def forward(self, x):
        # some operations involving the MLP `self.m`
        return nn.functional.conv2d(x, self.m.weight)

Can someone provide short, dummy code in PyTorch that achieves this training configuration and allows backpropagation?


1 Answer


A convolution from 3 input channels to 3 output channels with kernel_size=3 has 3·3·3·3 = 81 weights (not 9). You can reduce this number to 27 by using groups=3, which gives each output channel its own single-input-channel 3x3 kernel.
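To see these counts concretely (a quick sanity check, separate from the training code):

```python
import torch.nn as nn

# A full 3->3 conv with a 3x3 kernel: weight shape is (out_channels, in_channels, kH, kW)
full = nn.Conv2d(3, 3, kernel_size=3, bias=False)
print(tuple(full.weight.shape), full.weight.numel())  # (3, 3, 3, 3) 81

# groups=3 makes the conv depthwise: each output channel sees one input channel
grouped = nn.Conv2d(3, 3, kernel_size=3, groups=3, bias=False)
print(tuple(grouped.weight.shape), grouped.weight.numel())  # (3, 1, 3, 3) 27
```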

Since nn.Linear(9, 9) has a 9x9 weight matrix, it holds exactly the 81 parameters you need. You can do the following:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.hyper = nn.Linear(9, 9)  # its 9x9 weight matrix holds the required 81 parameters

    def forward(self, x):
        # do stuff with self.hyper(x)
        # reshape the Linear weights into a (out_channels, in_channels, kH, kW) kernel
        y = F.conv2d(x, self.hyper.weight.reshape(3, 3, 3, 3), padding=1)  # adjust padding/stride as needed
        return y
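A minimal end-to-end sketch showing that gradients flow back through the convolution into the Linear layer's weights. The input size, target, and MSE loss here are illustrative assumptions, not part of your setup:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.hyper = nn.Linear(9, 9)  # 9x9 = 81 weights, exactly the conv kernel size

    def forward(self, x):
        # use the Linear layer's weight matrix as the conv kernel
        w = self.hyper.weight.reshape(3, 3, 3, 3)
        return F.conv2d(x, w, padding=1)

net = Net()
opt = torch.optim.SGD(net.parameters(), lr=0.1)

x = torch.randn(2, 3, 8, 8)       # dummy batch
target = torch.randn(2, 3, 8, 8)  # dummy target (padding=1 keeps spatial size)

before = net.hyper.weight.clone()
loss = F.mse_loss(net(x), target)
loss.backward()
opt.step()

# False: the Linear weights were updated through the conv, so backprop works
print(torch.allclose(before, net.hyper.weight))
```

Because the conv kernel is just a reshaped view of `self.hyper.weight`, autograd routes the conv gradients straight into the Linear layer; no manual weight copying is needed.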