I'm new to tensors and having a headache over this problem:

I have an index tensor of size k with values between 0 and k-1:

tensor([0,1,2,0])

and the following tensor of shape (1, 4, 2):

tensor([[[0, 9],
         [1, 8],
         [2, 3],
         [4, 9]]])

I want to create a new tensor which contains the rows specified in index, in that order. So I want:

tensor([[[0, 9],
         [1, 8],
         [2, 3],
         [0, 9]]])

With plain Python lists, I'd do it roughly like this:

new_matrix = [matrix[i] for i in index]

How do I do something similar in PyTorch on tensors?


Answer


You can use advanced ("fancy") indexing:

from torch import tensor

index = tensor([0,1,2,0])
t = tensor([[[0, 9],
             [1, 8],
             [2, 3],
             [4, 9]]])

result = t[:, index, :]

to get

tensor([[[0, 9],
         [1, 8],
         [2, 3],
         [0, 9]]])

Note that t.shape == (1, 4, 2) and you want to select along the second axis (dim 1). So index goes in the second position of the subscript, while a plain : keeps the other axes whole, i.e. t[:, index, :].
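For comparison, here is a small sketch that checks the fancy-indexing result against torch.index_select, which takes the dimension as an explicit argument, and against a loop version of the list comprehension from the question. The names fancy, selected and looped are just illustrative. The loop has to go through t[0], because the rows live in dim 1 of this (1, 4, 2) tensor, not in dim 0.

```python
import torch

# Same data as in the question: shape (1, 4, 2)
index = torch.tensor([0, 1, 2, 0])  # integer (int64) index tensor
t = torch.tensor([[[0, 9],
                   [1, 8],
                   [2, 3],
                   [4, 9]]])

# Advanced indexing along dim 1, as in the answer
fancy = t[:, index, :]

# index_select names the dimension explicitly; index must be an integer tensor
selected = torch.index_select(t, 1, index)

# Loop version of the question's idea: take rows from the single batch entry t[0],
# stack them into shape (4, 2), then restore the leading dimension -> (1, 4, 2)
looped = torch.stack([t[0, i] for i in index.tolist()]).unsqueeze(0)

print(fancy)
print(torch.equal(fancy, selected), torch.equal(fancy, looped))
```

index_select can read more clearly when the dimension is held in a variable. The loop is only there to show the equivalence; it runs in Python, so it will generally be slower than either indexing call on large tensors.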