
Suppose I have a nested tensor A:

import numpy as np
import torch

X = np.array([[1, 3, 2], [2, 3, 5], [1, 2, 3]])
X = torch.DoubleTensor(X)

rows = X.shape[0]
cols = X.shape[1]

A = torch.matmul(X.view(rows, cols, 1),
                 X.view(rows, 1, cols))

A

Output:

tensor([[[ 1.,  3.,  2.],
         [ 3.,  9.,  6.],
         [ 2.,  6.,  4.]],

        [[ 4.,  6., 10.],
         [ 6.,  9., 15.],
         [10., 15., 25.]],

        [[ 1.,  2.,  3.],
         [ 2.,  4.,  6.],
         [ 3.,  6.,  9.]]], dtype=torch.float64)

And I have another tensor B:

B = torch.DoubleTensor([[11., 21, 31], [31, 51, 31], [41, 51, 21]])
B

Output:

tensor([[11., 21., 31.],
        [31., 51., 31.],
        [41., 51., 21.]])

How do I use torch.einsum() to find the trace of the dot product between each nested tensor in A and tensor B? For example, the trace of the dot product between the 1st nested tensor in A:

[[ 1.,  3.,  2.],
 [ 3.,  9.,  6.],
 [ 2.,  6.,  4.]]

and B:

tensor([[11., 21., 31.],
        [31., 51., 31.],
        [41., 51., 21.]])

and similarly with the other 2 nested tensors in A.

My result will be a tensor with just 3 trace values. Is there a way to do this without looping over each of the nested tensors in A (with, say, a for loop)?

Ps:

I know the code to find the trace of the dot product of 2 tensors is:

torch.einsum('ij,ji->', X, Y).item()
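(As a quick sanity check, with arbitrary placeholder matrices, this contraction agrees with forming the matrix product explicitly and taking its trace:

```python
import torch

# placeholder matrices just for the check
X = torch.arange(9, dtype=torch.float64).reshape(3, 3)
Y = torch.arange(9, 0, -1, dtype=torch.float64).reshape(3, 3)

# 'ij,ji->' sums X[i, j] * Y[j, i], which is exactly trace(X @ Y)
assert torch.einsum('ij,ji->', X, Y).item() == torch.trace(X @ Y).item()
```

)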

If you know how to do this with numpy.einsum(), please let me know too. I might just need to tweak numpy.einsum() a little to make it work for PyTorch tensors.


1 Answer


It's quite simple: you just need to add a 'batch' dimension for A:

torch.einsum('bij,ji->b', A, B)

The output is

tensor([1346., 3290., 1216.])
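The exact same subscript string works with numpy.einsum. A short sketch that rebuilds your A and B as NumPy arrays and cross-checks the batched contraction against an explicit loop with np.trace:

```python
import numpy as np

X = np.array([[1, 3, 2], [2, 3, 5], [1, 2, 3]], dtype=np.float64)
A = X[:, :, None] * X[:, None, :]  # batch of outer products, shape (3, 3, 3)
B = np.array([[11, 21, 31], [31, 51, 31], [41, 51, 21]], dtype=np.float64)

# trace(A[b] @ B) = sum over i, j of A[b, i, j] * B[j, i]
traces = np.einsum('bij,ji->b', A, B)
print(traces)  # [1346. 3290. 1216.]

# cross-check against a plain Python loop
expected = np.array([np.trace(A[b] @ B) for b in range(A.shape[0])])
assert np.allclose(traces, expected)
```

The key idea in both libraries is that 'bij,ji->b' multiplies each A[b] elementwise with B transposed and sums over i and j, which is the trace of the product, while the uncontracted b index keeps one value per nested tensor.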