I am new to PyTorch and am still wrapping my head around how to form a proper gather statement. I have a 4D input tensor of size (1,200,61,1632), where 1632 is the time dimension. I want to index it with a tensor idx of size (4,1632), where each column idx[:, i] holds the four coordinates of one value I want to extract from the input tensor. So the columns of idx look like:
[0,20,30,0]
[0,150,9,1]
[0,180,100,2]
...
so that the output has size 1632. In other words, I want to do this:
import torch

output = []
for i in range(1632):
    output.append(input[idx[0, i], idx[1, i], idx[2, i], idx[3, i]])
output = torch.stack(output)  # shape (1632,)
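For comparison, here is a minimal sketch of the same loop written with PyTorch advanced indexing. It assumes idx is an int64 tensor whose values are in range for each dimension. The data is random, hypothetical stand-in data with the shapes from the question, and the tensor is named inp to avoid shadowing Python's built-in input. Passing one 1-D index tensor per dimension picks exactly one element per column of idx.

```python
import torch

# Hypothetical stand-in data with the shapes from the question
inp = torch.randn(1, 200, 61, 1632)
T = inp.shape[-1]
idx = torch.stack([
    torch.zeros(T, dtype=torch.long),   # dim 0 has size 1
    torch.randint(0, 200, (T,)),        # dim 1 coordinates
    torch.randint(0, 61, (T,)),         # dim 2 coordinates
    torch.arange(T),                    # one value per time step
])  # shape (4, 1632)

# Loop version from the question
looped = torch.stack([inp[idx[0, i], idx[1, i], idx[2, i], idx[3, i]] for i in range(T)])

# Vectorized version: one index tensor per dimension, all of shape (1632,)
vectorized = inp[idx[0], idx[1], idx[2], idx[3]]  # equivalently inp[tuple(idx)]

print(vectorized.shape)  # torch.Size([1632])
```

Because the four index tensors share the shape (1632,), the result has that shape too, with no Python-level loop.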
Is this an appropriate use case for torch.gather? Looking at the documentation for gather, it says the input and index tensors must have the same number of dimensions, which my (4,1632) index does not.
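torch.gather requires index to have the same number of dimensions as input (and index.size(d) <= input.size(d) for every d other than dim), so a (4,1632) coordinate table does not plug in directly. One possible workaround, sketched here under the same assumptions as above, is to flatten the input and convert each column of coordinates into a row-major linear index. The helper name gather_points and the random data are hypothetical.

```python
import torch

def gather_points(inp: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Return inp[idx[0, i], ..., idx[-1, i]] for every column i, via torch.gather."""
    # Build a row-major (C-order) linear index from each column of coordinates
    lin = torch.zeros(idx.shape[1], dtype=torch.long, device=idx.device)
    for d, size in enumerate(inp.shape):
        lin = lin * size + idx[d]
    # Both the flattened input and lin are 1-D, so gather's ndim rule is satisfied
    return torch.gather(inp.reshape(-1), 0, lin)

# Hypothetical usage with the shapes from the question
inp = torch.randn(1, 200, 61, 1632)
T = inp.shape[-1]
idx = torch.stack([
    torch.zeros(T, dtype=torch.long),
    torch.randint(0, 200, (T,)),
    torch.randint(0, 61, (T,)),
    torch.arange(T),
])
out = gather_points(inp, idx)  # shape (1632,)
```

reshape(-1) always lays elements out in row-major logical order (copying if inp is not contiguous), which is the same order the linear index assumes. For this particular problem, plain advanced indexing (inp[tuple(idx)]) is simpler and gives the same result.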