
Summarize the problem

I am working with high-dimensional tensors in PyTorch and I need to index one tensor with the argmax values from another tensor. Specifically, I need to index tensor y of dim [3, 4] with the result of the argmax of tensor x, also of dim [3, 4]. The tensors are:

import torch as T
# Tensor to get argmax from
# expected argmax: [2, 0, 1]
x = T.tensor([[1, 2, 8, 3],
              [6, 3, 3, 5],
              [2, 8, 1, 7]])

# Tensor to index with argmax from previous
# expected tensor to retrieve [2, 4, 9]
y = T.tensor([[0,  1,  2,  3],
              [4,  5,  6,  7],
              [8,  9, 10, 11]])
# argmax
x_max, x_argmax = T.max(x, dim=1)

I would like an operation that, given the argmax indexes of x (x_argmax), retrieves the values of tensor y at those same indexes, i.e. one value per row.

Describe what you’ve tried

This is what I have tried:

# What I have tried
print(y[x_argmax])
print(y[:, x_argmax])
print(y[..., x_argmax])
print(y[x_argmax.unsqueeze(1)])

I have been reading a lot about NumPy indexing: basic indexing, advanced indexing and combined indexing. I have been trying to use combined indexing (since I want a slice along the first dimension of the tensor and the index values along the second one), but I have not been able to come up with a solution for this use case.
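For reference, and if I am reading the outputs correctly, the attempts above select whole rows or columns rather than one element per row:

print(y[x_argmax])     # rows 2, 0, 1 of y -> shape [3, 4]
# tensor([[ 8,  9, 10, 11],
#         [ 0,  1,  2,  3],
#         [ 4,  5,  6,  7]])
print(y[:, x_argmax])  # columns 2, 0, 1 of every row -> shape [3, 3]
# tensor([[ 2,  0,  1],
#         [ 6,  4,  5],
#         [10,  8,  9]])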


2 Answers

1 vote

You are looking for torch.gather:

import torch
idx = torch.argmax(x, dim=1, keepdim=True)  # get the argmax directly, no need for the max values
out = torch.gather(y, 1, idx)               # pick one element per row of y at those indexes

Resulting in:

tensor([[2],
        [4],
        [9]])
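If you would rather have the flat [2, 4, 9] from your example than a column vector, squeezing the extra dimension should do it:

out.squeeze(1)  # tensor([2, 4, 9])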
0 votes

How about y[T.arange(3), x_argmax]?

That does the job for me...

Explanation: invoking T.max(x, dim=1) reduces away the second dimension, so the row information has to be restored explicitly when you index y.
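The same idea should generalize to any number of rows by building the row index from the tensor's shape (a sketch, assuming y and x_argmax keep matching first dimensions):

rows = T.arange(y.shape[0])  # [0, 1, 2] -- one row index per element of x_argmax
print(y[rows, x_argmax])     # tensor([2, 4, 9])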