Summarize the problem
I am working with high-dimensional tensors in PyTorch and I need to index one tensor with the argmax values from another tensor. Concretely, I need to index tensor y (shape [3, 4]) with the result of the argmax of tensor x (also shape [3, 4]), taken along dim 1. If the tensors are:
import torch as T
# Tensor to get the argmax from
# expected argmax: [2, 0, 1]
x = T.tensor([[1, 2, 8, 3],
              [6, 3, 3, 5],
              [2, 8, 1, 7]])
# Tensor to index with the argmax from the previous one
# expected values to retrieve: [2, 4, 9]
y = T.tensor([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9, 10, 11]])
# argmax along dim 1 (one column index per row)
x_max, x_argmax = T.max(x, dim=1)
I would like an operation that, given the argmax indices of x (x_argmax), retrieves the values of y at those same indices, i.e. y[i, x_argmax[i]] for each row i.
Describe what you’ve tried
This is what I have tried:
# What I have tried
print(y[:, x_argmax])
print(y[..., x_argmax])
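For reference, both attempts above only index the columns: the slice `:` (or `...`) keeps every row, and x_argmax selects columns 2, 0, 1 from each of those rows, so the result is a [3, 3] tensor instead of the three values wanted. A small check, with x_argmax hard-coded to the expected [2, 0, 1] as an assumption:

```python
import torch as T

y = T.tensor([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9, 10, 11]])
x_argmax = T.tensor([2, 0, 1])  # assumed: argmax of x along dim=1

attempt = y[:, x_argmax]       # every row, columns [2, 0, 1]
attempt_dots = y[..., x_argmax]
print(attempt.shape)           # torch.Size([3, 3])
print(attempt.tolist())        # [[2, 0, 1], [6, 4, 5], [10, 8, 9]]
print(T.equal(attempt, attempt_dots))  # True

# The wanted values sit on the diagonal of this [3, 3] result
print(T.diagonal(attempt).tolist())  # [2, 4, 9]
```

Taking the diagonal recovers the right values but builds an n×n intermediate, so it only illustrates what the slice-based indexing is doing.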
I have read a lot about NumPy indexing: basic, advanced, and combined indexing. I have been trying combined indexing (a slice in the first dimension and the index values in the second), but I have not been able to find a solution for this use case.
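One approach that may fit, sketched under the assumption that PyTorch follows NumPy's advanced-indexing rules here: replace the slice with an explicit row-index tensor, so both dimensions are indexed by integer tensors that broadcast together and pair element-wise (row i with column x_argmax[i]) rather than being combined as slice × index. For higher-dimensional tensors, `torch.gather` along the reduced dimension performs the same pairing. The helper name `take_at_argmax` is hypothetical, and it assumes x and y have the same shape.

```python
import torch as T

x = T.tensor([[1, 2, 8, 3],
              [6, 3, 3, 5],
              [2, 8, 1, 7]])
y = T.tensor([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9, 10, 11]])
_, x_argmax = T.max(x, dim=1)

# Pure advanced indexing: both index tensors have shape [3],
# so element i of the result is y[rows[i], x_argmax[i]].
rows = T.arange(y.shape[0])
picked = y[rows, x_argmax]
print(picked.tolist())  # [2, 4, 9]

def take_at_argmax(x, y, dim):
    """Hypothetical helper: pick y's values where x is maximal along `dim`.

    Assumes x and y have the same shape.
    """
    idx = T.argmax(x, dim=dim, keepdim=True)  # size 1 along `dim`
    return T.gather(y, dim, idx).squeeze(dim)  # drop the size-1 dim

print(take_at_argmax(x, y, dim=1).tolist())  # [2, 4, 9]

# Same helper on a 3-D example, reducing over the last dimension
x3 = T.randn(2, 3, 4)
y3 = T.arange(24).reshape(2, 3, 4)
out = take_at_argmax(x3, y3, dim=-1)
print(out.shape)  # torch.Size([2, 3])
```

The `gather` route keeps the reduced dimension via `keepdim=True` so the index tensor has as many dimensions as y, which `torch.gather` requires.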