
I am converting some code from TensorFlow to PyTorch and ran into a problem with tf.gather: there is no directly equivalent function in PyTorch.

What I am trying to do is basically indexing. I have two tensors: a feature tensor of shape [minibatch, 60, 2] and an index tensor of shape [minibatch, 8]. Call the first one A and the second one B.

In TensorFlow this is done directly with tf.gather(A, B, batch_dims=1).
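Element by element, with the shapes above (A: [minibatch, 60, 2], B: [minibatch, 8]), batch_dims=1 gathers along axis 1 and applies each sample's indices only to that same sample's rows:

$$
\text{out}[i, j, k] = A\big[i,\; B[i, j],\; k\big], \qquad 0 \le i < \text{minibatch},\quad 0 \le j < 8,\quad 0 \le k < 2
$$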

How do I achieve this in pytorch?

I have tried indexing with A[B], but this does not seem to work.
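A likely reason A[B] does not work: in PyTorch, indexing with a single integer tensor selects along the first dimension only, so every value in B is treated as a batch index rather than as a row index inside each sample. A small sketch with made-up sizes (minibatch = 4, and B deliberately drawn from [0, 4) so the lookup stays in range) shows the resulting shape:

```python
import torch

minibatch = 4
A = torch.randn(minibatch, 60, 2)  # feature tensor [minibatch, 60, 2]

# Indices kept below `minibatch` on purpose so A[B] does not fail.
# With real row indices in [0, 60), any value >= minibatch raises IndexError.
B = torch.randint(0, minibatch, (minibatch, 8))

wrong = A[B]        # B indexes dim 0 (the batch), not dim 1 (the rows)
print(wrong.shape)  # torch.Size([4, 8, 60, 2])
```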

A[0][B[0]] works, but its output has shape [8, 2].

I need an output of shape [minibatch, 8, 2].

It would probably work if I stacked the per-sample results into a tensor of shape [minibatch, 8, 2], but I don't know how to do that.
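The stacking idea does work: index each sample separately (A[i][B[i]] has shape [8, 2]) and stack the per-sample results along a new batch dimension with torch.stack. A minimal sketch, where the helper name gather_by_stacking and the sizes are hypothetical; because it loops in Python, it will be slower than a single vectorised call for large batches:

```python
import torch

def gather_by_stacking(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-sample row lookup, then stack.

    A: [minibatch, 60, 2] features, B: [minibatch, 8] integer row indices.
    Returns a tensor of shape [minibatch, 8, 2].
    """
    # A[i][B[i]] picks 8 rows of sample i -> [8, 2]; stacking adds dim 0 back.
    return torch.stack([A[i][B[i]] for i in range(A.shape[0])], dim=0)

A = torch.randn(4, 60, 2)
B = torch.randint(0, 60, (4, 8))
out = gather_by_stacking(A, B)
print(out.shape)  # torch.Size([4, 8, 2])
```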

TensorFlow:
out = tf.gather(logits, indices, batch_dims=1)

PyTorch (something like this would be great):
out = A[B]

Desired output shape: [minibatch, 8, 2]

Answer

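I think you are looking for torch.gather. Its index argument must have the same number of dimensions as the input, so B first has to be expanded from [minibatch, 8] to [minibatch, 8, 2]: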

out = torch.gather(A, 1, B[..., None].expand(*B.shape, A.shape[-1]))
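A self-contained check of the torch.gather call above, with assumed sizes (minibatch = 4). torch.gather expects an int64 index tensor of the same rank as A; expanding B repeats each row index for both feature columns. The sketch also shows an advanced-indexing equivalent that is close to the A[B] form asked for: a column of batch ids from torch.arange broadcasts against B, so each index is paired with its own sample.

```python
import torch

minibatch, n_rows, n_feat, n_idx = 4, 60, 2, 8
A = torch.randn(minibatch, n_rows, n_feat)        # [minibatch, 60, 2]
B = torch.randint(0, n_rows, (minibatch, n_idx))  # [minibatch, 8], int64 by default

# torch.gather: index must have A's rank, so broadcast B over the last dim.
index = B[..., None].expand(*B.shape, A.shape[-1])  # [minibatch, 8, 2]
out_gather = torch.gather(A, 1, index)              # out[i, j, k] = A[i, index[i, j, k], k]

# Advanced indexing: batch ids [minibatch, 1] broadcast against B [minibatch, 8].
batch_ids = torch.arange(A.shape[0], device=A.device)[:, None]
out_index = A[batch_ids, B]                         # [minibatch, 8, 2]

print(out_gather.shape)                    # torch.Size([4, 8, 2])
print(torch.equal(out_gather, out_index))  # True
```

Both produce the same values; the advanced-indexing form avoids building the expanded index tensor, while torch.gather keeps the call closest to the original tf.gather.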