
I am new to PyTorch and am still wrapping my head around how to form a proper gather statement. I have a 4D input tensor of size (1, 200, 61, 1632), where 1632 is the time dimension. I want to index it with a tensor idx of size (4, 1632), where each column of idx holds the four indices of one value I want to extract from the input tensor. So the columns of idx look like:

[0,20,30,0]
[0,150,9,1]
[0,180,100,2]
...

So that the output has size 1632. In other words I want to do this:

output = []
for i in range(1632):
  output.append(input[idx[0,i], idx[1,i], idx[2,i], idx[3,i]])

Is this an appropriate use case for torch.gather? Looking at the documentation for gather, it says the index tensor must have the same number of dimensions as the input tensor, which doesn't match my situation.
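For reference, here is a minimal sketch of what torch.gather actually does (a made-up 2-D example, not the asker's tensors): it picks values along a single dimension, and the output has the same shape as the index tensor.

```python
import torch

# gather indexes along ONE dimension; index has the same ndim as input,
# and the output takes the shape of the index tensor.
x = torch.tensor([[10, 20, 30],
                  [40, 50, 60]])
cols = torch.tensor([[2, 0],
                     [1, 2]])

# picked[i, j] = x[i, cols[i, j]]
picked = torch.gather(x, 1, cols)
# -> tensor([[30, 10],
#            [50, 60]])
```

Because gather only selects along one dimension at a time, it is not a natural fit for picking one element per full 4-tuple of indices.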


1 Answer


You don't need torch.gather for this. Since PyTorch doesn't offer an implementation of ravel_multi_index, the direct way is advanced (fancy) indexing, which performs exactly the loop you wrote in one vectorized operation:

output = input[idx[0, :], idx[1, :], idx[2, :], idx[3, :]]
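A minimal sketch verifying that this matches the explicit loop, with the shapes shrunk for illustration (a (2, 3, 4, 5) tensor standing in for the asker's (1, 200, 61, 1632) one):

```python
import torch

# Small stand-in for the real input tensor
inp = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# idx has shape (4, N): each column is one (d0, d1, d2, d3) index tuple
idx = torch.tensor([[0, 1, 0],
                    [2, 0, 1],
                    [3, 3, 0],
                    [4, 0, 2]])

# Advanced indexing picks one element per column of idx
out = inp[idx[0], idx[1], idx[2], idx[3]]

# Equivalent to the explicit loop from the question
loop_out = torch.tensor([inp[idx[0, i], idx[1, i], idx[2, i], idx[3, i]]
                         for i in range(idx.shape[1])])
assert torch.equal(out, loop_out)
```

Note that `inp[tuple(idx)]` is an equivalent spelling that works for any number of dimensions.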

In NumPy, ravel_multi_index converts the (4, N) multi-indices into flat offsets, so you could do it this way:

output = np.take(input, np.ravel_multi_index(idx, input.shape))
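A runnable sketch of the NumPy route, using the same shrunken shapes as above for illustration:

```python
import numpy as np

inp = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
idx = np.array([[0, 1, 0],
                [2, 0, 1],
                [3, 3, 0],
                [4, 0, 2]])

# ravel_multi_index turns each (4,) index column into a flat offset
flat = np.ravel_multi_index(idx, inp.shape)
out = np.take(inp, flat)

# Same result as NumPy's advanced indexing
assert (out == inp[tuple(idx)]).all()
```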