I have a TensorFlow tensor T of shape (N, 9, 9), and permutations Px and Py, which might look like this: [3 4 5 6 7 8 2 1 0], [6 8 2 0 3 7 4 1 5].

I want to apply the permutation Px to axis 1 of T and Py to axis 2 (counting axes from 0, so the leading N axis is left alone). That is, I want to compute a tensor S defined by

S_i,j,k = T_i,Px(j),Py(k)

To construct S with tf.gather_nd, I need an indices tensor such that

indices[i,j,k,0] = i
indices[i,j,k,1] = Px(j)
indices[i,j,k,2] = Py(k)

What's the cleanest way to construct indices (in Python)?

1 Answer

If I understand your problem statement correctly, I believe this is what you need.

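import numpy as np
# assumes indices is a preallocated integer array of shape (N, 9, 9, 3)
indices[:,:,:,0] = np.arange(indices.shape[0])[:, None, None]  # i varies along axis 0
indices[:,:,:,1] = np.asarray(Px)[None, :, None]               # Px(j) varies along axis 1
indices[:,:,:,2] = np.asarray(Py)[None, None, :]               # Py(k) varies along axis 2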

It's hard to tell without a minimal reproducible example.
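
For what it's worth, here is a minimal sketch of one way to build the whole indices array in a single step, using np.meshgrid and np.stack. The helper name build_gather_indices and the batch size N = 4 are made up for illustration, and a plain NumPy array stands in for T so the result can be checked without TensorFlow.

```python
import numpy as np

def build_gather_indices(n, px, py):
    """Return an (n, len(px), len(py), 3) integer array with
    indices[i, j, k] = (i, px[j], py[k])."""
    px = np.asarray(px)
    py = np.asarray(py)
    # indexing="ij" keeps the output axes in (i, j, k) order.
    i, j, k = np.meshgrid(np.arange(n), px, py, indexing="ij")
    return np.stack([i, j, k], axis=-1)

Px = [3, 4, 5, 6, 7, 8, 2, 1, 0]
Py = [6, 8, 2, 0, 3, 7, 4, 1, 5]
N = 4  # hypothetical batch size, chosen only for this example

indices = build_gather_indices(N, Px, Py)

# Sanity check on a NumPy stand-in for T: gathering with the index
# triples should match permuting axis 1 with Px and axis 2 with Py.
T = np.arange(N * 9 * 9).reshape(N, 9, 9)
S = T[indices[..., 0], indices[..., 1], indices[..., 2]]
expected = T[:, Px, :][:, :, Py]
assert np.array_equal(S, expected)
```

The resulting array has the layout the question asks for, so it should be usable as the indices argument of tf.gather_nd(T, indices). If S is all you need, two calls to tf.gather, as in tf.gather(tf.gather(T, Px, axis=1), Py, axis=2), should give the same result without building indices at all.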