Suppose I have a nested tensor A:
import numpy as np
import torch
X = np.array([[1, 3, 2], [2, 3, 5], [1, 2, 3]])
X = torch.DoubleTensor(X)
rows = X.shape[0]
cols = X.shape[1]
A = torch.matmul(X.view(rows, cols, 1),
X.view(rows, 1, cols))
A
Output:
tensor([[[ 1., 3., 2.],
[ 3., 9., 6.],
[ 2., 6., 4.]],
[[ 4., 6., 10.],
[ 6., 9., 15.],
[10., 15., 25.]],
[[ 1., 2., 3.],
[ 2., 4., 6.],
[ 3., 6., 9.]]], dtype=torch.float64)
And I have another tensor B:
B = torch.DoubleTensor([[11., 21, 31], [31, 51, 31], [41, 51, 21]])
B
Output:
tensor([[11., 21., 31.],
[31., 51., 31.],
[41., 51., 21.]], dtype=torch.float64)
How do I use torch.einsum() to find the trace of the matrix (dot) product of each nested tensor in A with tensor B? For example, the trace of the product of the 1st nested tensor in A:
[[ 1., 3., 2.],
[ 3., 9., 6.],
[ 2., 6., 4.]]
and B:
tensor([[11., 21., 31.],
[31., 51., 31.],
[41., 51., 21.]])
and similarly with the other 2 nested tensors in A.
My result will be a tensor with just 3 trace values. Is there a way to do this without looping over each nested tensor in A (with, say, a for loop)?
P.S.
I know the code to find the trace of the matrix product of 2 tensors is:
torch.einsum('ij,ji->', X, Y).item()
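For reference, here is a minimal sketch of how that two-tensor pattern might extend to a batch, assuming the goal is trace(A[k] @ B) for each k. The batch index label `b` and the names `traces` and `loop_traces` are only illustrative; the loop is there purely as a cross-check.

```python
import torch

# Same data as in the question, built directly in float64.
X = torch.tensor([[1, 3, 2], [2, 3, 5], [1, 2, 3]], dtype=torch.float64)
A = torch.matmul(X.unsqueeze(2), X.unsqueeze(1))  # shape (3, 3, 3): one outer product per row of X
B = torch.tensor([[11, 21, 31], [31, 51, 31], [41, 51, 21]], dtype=torch.float64)

# 'ij,ji->' sums A[i, j] * B[j, i] for one matrix pair.
# Adding a batch label 'b' that is kept in the output gives one trace per matrix in A.
traces = torch.einsum('bij,ji->b', A, B)

# Cross-check against an explicit loop (illustration only).
loop_traces = torch.stack([torch.trace(A[k] @ B) for k in range(A.shape[0])])

print(traces)       # values: 1346., 3290., 1216.
print(loop_traces)  # same values
```

The only change from the two-tensor version is the extra `b` on the first operand and on the right-hand side of `->`, so `i` and `j` are still summed away while `b` is kept.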
If you know how to do this with numpy.einsum(), please let me know too. I might just need to tweak numpy.einsum() a little to make it work for PyTorch tensors.
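As a rough illustration of what the NumPy side could look like, under the assumption that the same subscript string carries over unchanged (the batch label `b` and the variable names are just placeholders):

```python
import numpy as np

# Same data as in the question, as NumPy arrays.
X = np.array([[1, 3, 2], [2, 3, 5], [1, 2, 3]], dtype=np.float64)
A = np.einsum('bi,bj->bij', X, X)  # shape (3, 3, 3): outer product of each row of X with itself
B = np.array([[11, 21, 31], [31, 51, 31], [41, 51, 21]], dtype=np.float64)

# Sum A[b, i, j] * B[j, i] over i and j, keeping the batch index b.
traces = np.einsum('bij,ji->b', A, B)

# Cross-check with an explicit loop (illustration only).
loop_traces = np.array([np.trace(A[k] @ B) for k in range(A.shape[0])])

print(traces)       # [1346. 3290. 1216.]
print(loop_traces)  # [1346. 3290. 1216.]
```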