I am converting some code from TensorFlow to PyTorch and have run into a problem with tf.gather: I cannot find a direct equivalent in PyTorch.
What I am trying to do is basically indexing. I have two tensors: a feature tensor A of shape [minibatch, 60, 2] and an index tensor B of shape [minibatch, 8].
In TensorFlow, this is done directly with tf.gather(A, B, batch_dims=1).
How do I achieve this in pytorch?
I have tried A[B] indexing, but it does not give the result I want.
Indexing a single sample with A[0][B[0]] works, but the output shape is [8, 2].
I need an output of shape [minibatch, 8, 2].
It would probably work if I stacked the per-sample results into a [minibatch, 8, 2] tensor, but I have no idea how to do that.
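To make the stacking idea concrete, here is a minimal sketch of what I have in mind. The function name gather_by_stacking is just a placeholder I made up, and I am assuming B holds integer indices into dimension 1 of A. It loops over the batch in Python, so it is probably not the efficient way to do it:

```python
import torch

def gather_by_stacking(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # A: [minibatch, 60, 2], B: [minibatch, 8] (integer indices into dim 1 of A)
    B = B.long()  # make sure the indices are int64
    # Index each sample separately; every element of the list has shape [8, 2].
    rows = [A[i][B[i]] for i in range(A.shape[0])]
    # Stack the per-sample results along a new leading batch dimension.
    return torch.stack(rows, dim=0)  # [minibatch, 8, 2]
```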
TensorFlow:
out = tf.gather(logits, indices, batch_dims=1)  # logits is A, indices is B
PyTorch:
out = A[B]  # something like this would be great
Expected output shape: [minibatch, 8, 2]
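For reference, two loop-free candidates that I believe reproduce tf.gather(A, B, batch_dims=1) for these shapes are sketched below. The function names are hypothetical, and I am assuming A is [minibatch, 60, 2] and B is [minibatch, 8] with integer values in [0, 60). The first pairs each row of B with its own batch index through broadcasting; the second expands B over the last dimension and calls torch.gather along dim 1:

```python
import torch

def batched_gather_indexing(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Advanced indexing: batch_idx is [minibatch, 1] and broadcasts against
    # B ([minibatch, 8]), so out[i, j] = A[i, B[i, j]].
    B = B.long()
    batch_idx = torch.arange(A.shape[0], device=A.device).unsqueeze(1)
    return A[batch_idx, B]  # [minibatch, 8, 2]

def batched_gather_torch_gather(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # torch.gather needs an index with the same number of dims as A,
    # so repeat B across the trailing feature dimension.
    index = B.long().unsqueeze(-1).expand(-1, -1, A.shape[-1])  # [minibatch, 8, 2]
    return torch.gather(A, 1, index)  # out[i, j, k] = A[i, B[i, j], k]
```

Both should return the same tensor as indexing each sample with A[i][B[i]] and stacking the results.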