2
votes

I have a tensor of shape (16, 4096, 3) and a tensor of indices of shape (16, 32768, 3). I want to gather values along dim=1. I originally did this in PyTorch with the gather function, as shown below:

# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)

Note that the output b has the same shape as idx. However, when I apply TensorFlow's gather function, the output shape is completely different:

b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)

I also tried tf.gather_nd, but in vain:

b = tf.gather_nd(a, idx)
# b.shape (16, 32768)

Why do I get tensors of different shapes? I want a tensor with the same shape as the one PyTorch produces.

How can I achieve the same result as PyTorch?

1
Any suggestions, please? - Ravi Joshi

1 Answer

0
votes

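If I understand you correctly, tf.gather_nd is what you are looking for, but it expects full (batch, row, channel) coordinates in the last axis of the index tensor rather than the raw idx used with PyTorch. If that is not what you mean, please clarify the question.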
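
For reference, PyTorch's a.gather(1, idx) computes b[i, j, k] = a[i, idx[i, j, k], k]. tf.gather(a, idx, axis=1) instead substitutes the whole idx tensor for axis 1, so the result has shape a.shape[:1] + idx.shape + a.shape[2:], which is (16, 16, 32768, 3, 3) here. tf.gather_nd(a, idx) reads the last axis of idx as one complete (i, j, k) coordinate, so the result has shape idx.shape[:-1], which is (16, 32768). The following is a minimal sketch of one way to reproduce the PyTorch behaviour, assuming TensorFlow 2.x and 3-D inputs; the helper name gather_dim1 is hypothetical and not part of either library.

```python
import tensorflow as tf


def gather_dim1(a, idx):
    """Sketch of PyTorch's a.gather(1, idx) for 3-D tensors.

    Computes b[i, j, k] = a[i, idx[i, j, k], k].
    gather_dim1 is an illustrative helper, not a TensorFlow API.
    """
    idx = tf.cast(idx, tf.int32)
    shape = tf.shape(idx)
    # Coordinate grids for the batch axis (i) and the last axis (k),
    # each with the same shape as idx.
    i, _, k = tf.meshgrid(
        tf.range(shape[0]), tf.range(shape[1]), tf.range(shape[2]),
        indexing="ij",
    )
    # Full coordinates (i, idx[i, j, k], k); shape is idx.shape + (3,).
    coords = tf.stack([i, idx, k], axis=-1)
    return tf.gather_nd(a, coords)


# Usage with the shapes from the question (random data, for illustration only).
a = tf.random.uniform((16, 4096, 3))
idx = tf.random.uniform((16, 32768, 3), maxval=4096, dtype=tf.int32)
b = gather_dim1(a, idx)
print(b.shape)  # (16, 32768, 3)
```

The output has the same shape as idx, as in PyTorch. The extra coords tensor is three times the size of idx, so for very large index tensors it may be worth measuring memory use before adopting this approach.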