PyTorch tensors implement NumPy-style broadcasting semantics, which work for this problem.
It's not clear from the question whether you want matrix multiplication or element-wise multiplication. In the length-2 case you showed, the two are equivalent, but this is certainly not true for higher dimensionality! Thankfully the code is almost the same, so I'll give both options.
import torch

A = torch.FloatTensor([[1, 2], [3, 4]])
B = torch.FloatTensor([[0, 0], [1, 1], [2, 2]])
# matrix multiplication
C_mm = (A.T[:, None, :, None] @ B[None, :, None, :]).flatten(0, 1)
# element-wise multiplication
C_ew = (A.T[:, None, :, None] * B[None, :, None, :]).flatten(0, 1)
Code description. A.T transposes A, and the indexing with None inserts unitary dimensions, so A.T[:, None, :, None] has shape (2, 1, 2, 1) and B[None, :, None, :] has shape (1, 3, 1, 2). Since @ (matrix multiplication) operates on the last two dimensions of its operands and broadcasts the remaining dimensions, the result is the matrix product of each column of A with each row of B. In the element-wise case the broadcasting is performed on every dimension. Either way the result is a (2, 3, 2, 2) tensor; to turn it into a (6, 2, 2) tensor we just flatten the first two dimensions using Tensor.flatten.
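If you want to convince yourself that the broadcasting does what's described, here is a quick sanity check (my own sketch, reusing A, B, and C_mm from above; the loop and the expected variable are just for illustration) that compares the broadcasted result against an explicit loop over the columns of A and rows of B:

# sanity check: build the same result with an explicit double loop
expected = torch.stack([
    A.T[i, :, None] @ B[j, None, :]  # outer product of column i of A with row j of B
    for i in range(A.shape[1])
    for j in range(B.shape[0])
])
print(C_mm.shape)                      # torch.Size([6, 2, 2])
print(torch.allclose(C_mm, expected))  # True

The loop order (i outer, j inner) matches flatten(0, 1), which keeps the column index of A as the slower-varying index in the flattened result.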