I have a TensorFlow tensor T of shape (N, 9, 9) and two permutations Px and Py, which might look like [3 4 5 6 7 8 2 1 0] and [6 8 2 0 3 7 4 1 5].
I want to apply the permutation Px to axis 1 of T and Py to axis 2 (counting axes from 0, so axis 0 is the batch dimension of size N). That is, I want to compute a tensor S defined by
S_i,j,k = T_i,Px(j),Py(k)
To construct S with tf.gather_nd, I need an indices tensor of shape (N, 9, 9, 3) such that
indices[i,j,k,0] = i
indices[i,j,k,1] = Px(j)
indices[i,j,k,2] = Py(k)
What's the cleanest way to construct indices (in Python)?
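For concreteness, here is a minimal sketch of one way the indices tensor described above could be built, assuming TensorFlow 2.x with eager execution. The batch size N = 4 and the contents of T are made up for illustration. It uses tf.meshgrid with indexing="ij" and tf.stack; this is one possible approach, not necessarily the cleanest.

```python
import numpy as np
import tensorflow as tf

N = 4  # hypothetical batch size, chosen only for this example
# Dummy data: every entry of T is distinct, so the result is easy to check.
T = tf.constant(np.arange(N * 9 * 9).reshape(N, 9, 9))
Px = tf.constant([3, 4, 5, 6, 7, 8, 2, 1, 0])
Py = tf.constant([6, 8, 2, 0, 3, 7, 4, 1, 5])

# With indexing="ij", each output grid has shape (N, 9, 9) and
#   i[a, b, c] = a,  j[a, b, c] = Px[b],  k[a, b, c] = Py[c].
i, j, k = tf.meshgrid(tf.range(N), Px, Py, indexing="ij")

# Stack along a new last axis to get the (N, 9, 9, 3) indices tensor.
indices = tf.stack([i, j, k], axis=-1)

# S[a, b, c] = T[a, Px[b], Py[c]]
S = tf.gather_nd(T, indices)

# Alternative that avoids building indices: gather along each axis in turn.
S_alt = tf.gather(tf.gather(T, Px, axis=1), Py, axis=2)
```

If only S is needed and not indices itself, the two chained tf.gather calls compute the same tensor without materializing the (N, 9, 9, 3) index array.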