0
votes

I previously asked: PyTorch tensors: new tensor based on old tensor and indices

I have the same problem now but need to use a 2d index tensor.

I have a tensor idx of size [batch_size, k] with values between 0 and k-1:

idx = tensor([[0,1,2,0],
        [0,3,2,2],
        ...])

and the following matrix:

x = tensor([[[0, 9],
 [1, 8],
 [2, 3],
 [4, 9]],
 [[0, 0],
 [1, 2],
 [3, 4],
 [5, 6]]])

I want to create a new tensor that contains the rows specified in idx, in that order. So I want:

tensor([[[0, 9],
 [1, 8],
 [2, 3],
 [0, 9]],
 [[0, 0],
 [5, 6],
 [3, 4],
 [3, 4]]])

Currently I'm doing it like this:

for i, batch in enumerate(x):
    x[i] = batch[idx[i]]

How can I do it more efficiently?

2 Answers

2
votes

You should use torch.gather to achieve this. It would also work for the other question you linked, but this is left as an exercise to the reader :p

Let us call idx your first tensor and source the second one. Their respective dimensions are (B,N) and (B, K, p) (with p=2 in your example), and all values of idx are between 0 and K-1.

To use torch.gather, we first need to express your operation as a nested for loop. In your case, what you want to achieve is:

for b in range(B):
    for i in range(N):
        for j in range(p):
            # This kind of nested for loop is what torch.gather actually does
            target[b,i,j] = source[b, idx[b,i,j], j]

But that does not work because idx is a 2D tensor, not a 3D one. Well, no big deal, let's make it a 3D tensor. We want it to have shape (B, N, p) and be actually constant along the last dimension. Then we can replace the for loop with a call to gather:

reshaped_idx = idx.unsqueeze(-1).repeat(1, 1, source.size(-1))  # (B, N) -> (B, N, p)
target = source.gather(1, reshaped_idx)
# or : target = torch.gather(source, 1, reshaped_idx)
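
As a quick sanity check, here is a minimal self-contained sketch that runs the question's example through both the explicit loop and gather, then compares the results. The names via_loop and via_gather are illustrative only. It uses expand instead of repeat, on the assumption that a broadcast view is enough for an index that is only read:

```python
import torch

# Example data from the question: source is (B, K, p), idx is (B, N)
source = torch.tensor([[[0, 9], [1, 8], [2, 3], [4, 9]],
                       [[0, 0], [1, 2], [3, 4], [5, 6]]])
idx = torch.tensor([[0, 1, 2, 0],
                    [0, 3, 2, 2]])

B, N = idx.shape
p = source.size(-1)

# Reference result: the nested loop, written with the 2D idx
via_loop = torch.empty(B, N, p, dtype=source.dtype)
for b in range(B):
    for i in range(N):
        for j in range(p):
            via_loop[b, i, j] = source[b, idx[b, i], j]

# Same result with gather: make idx 3D and constant along the last dimension
reshaped_idx = idx.unsqueeze(-1).expand(-1, -1, p)
via_gather = source.gather(1, reshaped_idx)

assert torch.equal(via_loop, via_gather)
```
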
1
votes

You can gather the indices after a slight manipulation to get idx into a compatible shape (expand_as works here because idx and x have the same size along dim 1):

>>> new_idx = idx.unsqueeze(-1).expand_as(x)
>>> x.gather(1, new_idx)

tensor([[[0, 9],
         [1, 8],
         [2, 3],
         [0, 9]],

        [[0, 0],
         [5, 6],
         [3, 4],
         [3, 4]]])
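
For comparison, here is a sketch of a different technique that is not part of the answer above: integer-array (advanced) indexing with a broadcast batch index. The name batch_index is hypothetical. This variant does not require idx and x to have the same size along dim 1:

```python
import torch

# Example data from the question: x is (B, K, p), idx is (B, N)
x = torch.tensor([[[0, 9], [1, 8], [2, 3], [4, 9]],
                  [[0, 0], [1, 2], [3, 4], [5, 6]]])
idx = torch.tensor([[0, 1, 2, 0],
                    [0, 3, 2, 2]])

# Column vector of batch positions, shape (B, 1); it broadcasts against idx (B, N)
batch_index = torch.arange(x.size(0)).unsqueeze(1)

# For each (b, i), pick the row x[b, idx[b, i], :]; the result has shape (B, N, p)
out = x[batch_index, idx]

# Agrees with the gather-based approach on this example
assert torch.equal(out, x.gather(1, idx.unsqueeze(-1).expand_as(x)))
```
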