I am trying to learn the weights of a 3x3 conv2d layer that takes 3 input channels and produces 3 output channels. For this discussion, assume bias=0 in every case. However, the weights of the conv layer are learned indirectly. I have a two-layer multilayer perceptron (MLP) with 9 nodes in the first layer and 9 in the second, i.e. nn.Linear(9,9), and the weights of the conv2d layer are exactly the weights learned by this MLP. The sizes match: nn.Linear(9,9) holds 9 x 9 = 81 weights, and a 3x3 kernel with 3 input and 3 output channels also has 3 x 3 x 3 x 3 = 81 weights. I understand that in this case I have to use nn.functional.conv2d(input, weight), but it is not clear to me how to extract the weights from the MLP and use them for the convolution. The best I can think of is the following:
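import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.m = nn.Linear(9, 9)  # assign to self so the MLP is registered and trained
    def forward(self, x):
        # some operations involving MLP `self.m`
        w = self.m.weight.view(3, 3, 3, 3)  # 81 weights -> (out_ch, in_ch, kH, kW)
        return F.conv2d(x, w)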
Can someone provide a short dummy example in PyTorch that trains this configuration with backpropagation?
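A minimal sketch of such a training setup, under a few assumptions not stated in the question: the Linear layer is created with `bias=False` (following the "bias = 0" premise), the conv uses `padding=1` so height and width are preserved, and the data, targets, loss (`F.mse_loss`) and optimizer (`torch.optim.SGD`) are dummy placeholders. The helper name `conv_weight` is hypothetical. The key point is that `Tensor.view` reshapes without copying, so gradients computed for the conv kernel flow straight back into `self.m.weight`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Conv layer whose 3x3 kernel (3 in / 3 out channels) is the weight of an nn.Linear(9, 9)."""

    def __init__(self):
        super().__init__()
        # Assumption: bias=False, matching the "bias = 0" premise.
        self.m = nn.Linear(9, 9, bias=False)

    def conv_weight(self):
        # Hypothetical helper. view() (not a copy) keeps the autograd link to
        # self.m.weight. Row i of the 9x9 matrix becomes the 3x3 kernel for
        # output channel i // 3 and input channel i % 3 (row-major order).
        return self.m.weight.view(3, 3, 3, 3)  # (out_ch, in_ch, kH, kW)

    def forward(self, x):
        # bias=None encodes bias = 0; padding=1 (an assumption) keeps H and W unchanged
        return F.conv2d(x, self.conv_weight(), bias=None, padding=1)


torch.manual_seed(0)
net = Net()
opt = torch.optim.SGD(net.parameters(), lr=0.1)

# Dummy data: a batch of 4 three-channel 8x8 images and random targets of the same shape
x = torch.randn(4, 3, 8, 8)
target = torch.randn(4, 3, 8, 8)

w_before = net.m.weight.detach().clone()
for step in range(5):
    opt.zero_grad()
    loss = F.mse_loss(net(x), target)
    loss.backward()  # gradients reach net.m.weight through the view
    opt.step()

print(net(x).shape)                             # torch.Size([4, 3, 8, 8])
print(net.m.weight.grad.shape)                  # torch.Size([9, 9])
print(not torch.equal(w_before, net.m.weight))  # True: the MLP weight was updated
```

Because the only trainable parameter is `net.m.weight`, the optimizer updates the MLP, and the convolution always uses the current, reshaped MLP weights. Any other operations involving `self.m` in `forward` would contribute their gradients to the same parameter.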