I am trying to use a 2-D tensor to index a 3-D tensor in TensorFlow. For example, I have x of shape [2, 3, 4]:
[[[ 0,  1,  2,  3],
  [ 4,  5,  6,  7],
  [ 8,  9, 10, 11]],
 [[12, 13, 14, 15],
  [16, 17, 18, 19],
  [20, 21, 22, 23]]]
and I want to index it with another tensor y of shape [2, 3], where each element y[i, j] selects a position along the last dimension of x. For example, if y is:
[[0, 2, 3],
 [1, 0, 2]]
the output should be of shape [2, 3], with output[i, j] = x[i, j, y[i, j]]:
[[ 0,  6, 11],
 [13, 16, 22]]
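A minimal sketch of two ways to do this, assuming TensorFlow 2.x with eager execution (the default). The first uses `tf.gather` with `batch_dims=2`, which treats the first two axes of x and y as shared batch axes and gathers along axis 2. The second, for versions where `batch_dims` is unavailable, builds explicit (i, j, y[i, j]) coordinates and passes them to `tf.gather_nd`. The names `out`, `out_nd`, `rows`, `cols` and `coords` are illustrative only.

```python
import tensorflow as tf

# x holds 0..23 reshaped to [2, 3, 4], matching the example above.
x = tf.reshape(tf.range(24), [2, 3, 4])
# y picks one position along the last axis of x for every (i, j) pair.
y = tf.constant([[0, 2, 3],
                 [1, 0, 2]])

# Option 1: the first two axes are batch axes; gather along axis 2 of x.
out = tf.gather(x, y, axis=2, batch_dims=2)

# Option 2: build coordinates [i, j, y[i, j]] for every (i, j) and use gather_nd.
rows, cols = tf.meshgrid(tf.range(tf.shape(y)[0]),
                         tf.range(tf.shape(y)[1]),
                         indexing="ij")
coords = tf.stack([rows, cols, y], axis=-1)  # shape [2, 3, 3]
out_nd = tf.gather_nd(x, coords)             # shape [2, 3]

print(out.numpy())
print(out_nd.numpy())
```

Both results have shape [2, 3]; the `tf.gather` form is shorter, while the `tf.gather_nd` form makes the per-element index explicit.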