I previously asked "PyTorch tensors: new tensor based on old tensor and indices".
I have the same problem now, but this time I need to use a 2D index tensor.
I have an index tensor `idx` of size [batch_size, k] with values between 0 and k-1:
```python
idx = tensor([[0, 1, 2, 0],
              [0, 3, 2, 2],
              ...])
```
and the following tensor `x` of size [batch_size, k, 2]:
```python
x = tensor([[[0, 9],
             [1, 8],
             [2, 3],
             [4, 9]],
            [[0, 0],
             [1, 2],
             [3, 4],
             [5, 6]]])
```
I want to create a new tensor that contains, for each batch `b`, the rows of `x[b]` selected by `idx[b]`, in that order. So I want:
```python
tensor([[[0, 9],
         [1, 8],
         [2, 3],
         [0, 9]],
        [[0, 0],
         [5, 6],
         [3, 4],
         [3, 4]]])
```
Currently I'm doing it like this:
```python
for i, batch in enumerate(x):
    x[i] = batch[idx[i]]
```
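For reference, a self-contained reproduction of the loop. It is a sketch that assumes `idx` contains only the two rows shown (so its batch size matches `x`), and it writes into a fresh tensor `out` (a name introduced here) instead of overwriting `x`, so the input stays available for comparison.

```python
import torch

# Data from the question; idx is assumed to have only the two rows shown.
x = torch.tensor([[[0, 9], [1, 8], [2, 3], [4, 9]],
                  [[0, 0], [1, 2], [3, 4], [5, 6]]])
idx = torch.tensor([[0, 1, 2, 0],
                    [0, 3, 2, 2]])

# Per-batch row selection with a Python loop: out[i] = x[i][idx[i]].
out = torch.empty_like(x)
for i, batch in enumerate(x):
    out[i] = batch[idx[i]]

print(out.tolist())
# [[[0, 9], [1, 8], [2, 3], [0, 9]], [[0, 0], [5, 6], [3, 4], [3, 4]]]
```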
How can I do it more efficiently?
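One possible way to drop the loop, sketched under the same assumption about `idx`, is batched advanced indexing. A column of batch indices from `torch.arange` broadcasts against `idx`, so that `result[b, j] = x[b, idx[b, j]]`. `torch.gather` along dimension 1 is an equivalent alternative, but its index must first be expanded to the full output shape. Both replace the per-batch Python loop with a single tensor operation. The variable names `batch_idx`, `index`, `by_indexing` and `by_gather` are illustrative only.

```python
import torch

# Data from the question; idx is assumed to have only the two rows shown.
x = torch.tensor([[[0, 9], [1, 8], [2, 3], [4, 9]],
                  [[0, 0], [1, 2], [3, 4], [5, 6]]])
idx = torch.tensor([[0, 1, 2, 0],
                    [0, 3, 2, 2]])

# Option 1: advanced indexing. batch_idx has shape [batch_size, 1] and
# broadcasts with idx ([batch_size, k]), giving shape [batch_size, k, 2].
batch_idx = torch.arange(x.size(0)).unsqueeze(1)
by_indexing = x[batch_idx, idx]

# Option 2: torch.gather along dim 1. The index must have as many dims as x,
# so idx is expanded over the last dimension: [batch_size, k, x.size(2)].
index = idx.unsqueeze(-1).expand(-1, -1, x.size(2))
by_gather = torch.gather(x, 1, index)

expected = torch.tensor([[[0, 9], [1, 8], [2, 3], [0, 9]],
                         [[0, 0], [5, 6], [3, 4], [3, 4]]])
print(torch.equal(by_indexing, expected))  # True
print(torch.equal(by_gather, expected))    # True
```

Both options return a new tensor. To keep the in-place behaviour of the original loop, assign the result back to `x`.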