
Suppose I have the following tensors:

N = 2
k = 3
d = 2

L = torch.arange(N * k * d * d).view(N, k, d, d)
L
tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]],

         [[ 8,  9],
          [10, 11]]],


        [[[12, 13],
          [14, 15]],

         [[16, 17],
          [18, 19]],

         [[20, 21],
          [22, 23]]]])


index = torch.Tensor([0,1,0,0]).view(N,-1)
index
tensor([[0., 1.],
        [0., 0.]])

I would now like to use the index tensor to pick out the corresponding matrices along the second dimension, i.e. I would like to get something like:

tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]]],


        [[[12, 13],
          [14, 15]],

         [[12, 13],
          [14, 15]]]])

Any idea how I could achieve this? Thank you so much!


1 Answer


Tensors can be indexed with multiple index tensors specified across different dimensions (tuples of tensors), where the i-th element of each index tensor is combined into one tuple of indices: data[indices_dim0, indices_dim1] results in indexing data[indices_dim0[0], indices_dim1[0]], data[indices_dim0[1], indices_dim1[1]], and so on. The index tensors must have the same length, len(indices_dim0) == len(indices_dim1).
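
As a small illustration of this pairwise behaviour, here is a minimal sketch on a made-up 2D tensor:

import torch

data = torch.arange(6).view(3, 2)
# => tensor([[0, 1],
#            [2, 3],
#            [4, 5]])

# Picks out data[0, 1] and data[2, 0]
data[torch.tensor([0, 2]), torch.tensor([1, 0])]
# => tensor([1, 4])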

Let's use the flat version of index (before you applied the view). Each element needs to be matched to the appropriate batch index, which here is [0, 0, 1, 1]. index also needs to have type torch.long, because floats cannot be used as indices. torch.tensor should be preferred over torch.Tensor for creating tensors from existing data: torch.Tensor is an alias for the default tensor type (torch.FloatTensor), whereas torch.tensor automatically infers the data type from the given values, also supports the dtype argument to set the type manually, and is generally more versatile.
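
A quick check of that dtype difference, for illustration:

torch.Tensor([0, 1, 0, 0]).dtype
# => torch.float32

torch.tensor([0, 1, 0, 0]).dtype
# => torch.int64 (torch.long)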

# Type torch.long is inferred
index = torch.tensor([0, 1, 0, 0])

# Same, but explicitly setting the type
index = torch.tensor([0, 1, 0, 0], dtype=torch.long)

batch_index = torch.tensor([0, 0, 1, 1])

L[batch_index, index]
# => tensor([[[ 0,  1],
#             [ 2,  3]],
#
#            [[ 4,  5],
#             [ 6,  7]],
#
#            [[12, 13],
#             [14, 15]],
#
#            [[12, 13],
#             [14, 15]]])
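
If you want the batched (N, 2, d, d) layout you described rather than this flat version, you could reshape the result afterwards; a minimal sketch:

# Groups the picked matrices back per batch
L[batch_index, index].view(N, -1, d, d)
# => shape (2, 2, 2, 2), matching the desired output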

The indices are not limited to 1D tensors, but they all need to have the same size, and each element is used as one index; for example with 2D index tensors the indexing happens as data[indices_dim0[i][j], indices_dim1[i][j]]. With 2D tensors it also happens to be much simpler to create the batch indices, without having to write them out manually.

index = torch.tensor([0, 1, 0, 0]).view(N, -1)
# => tensor([[0, 1],
#            [0, 0]])

# Every batch gets its index and is repeated across dim=1
batch_index = torch.arange(N).view(N, 1).expand_as(index)
# => tensor([[0, 0],
#            [1, 1]])

L[batch_index, index]
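# => tensor([[[[ 0,  1],
#              [ 2,  3]],
#
#             [[ 4,  5],
#              [ 6,  7]]],
#
#
#            [[[12, 13],
#              [14, 15]],
#
#             [[12, 13],
#              [14, 15]]]])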