I have a 2D tensor `A` with shape `[batch_size, D]` and a 1D tensor `B` with shape `[batch_size]`. Each element of `B` is a column index into the corresponding row of `A`, i.e. `B[i]` is in `[0, D)`.
What is the best way in TensorFlow to get the values `A[i, B[i]]` for every row `i`?
For example:
A = tf.constant([[0,1,2],
[3,4,5]])
B = tf.constant([2,1])
The desired output is:
some_slice_func(A, B) -> [2,4]
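One possible sketch, assuming TensorFlow 2.x in eager mode, uses `tf.gather` with `batch_dims=1`: the first dimension is treated as a batch shared by `A` and `B`, and within each row the index `B[i]` is looked up along `axis=1`. This is one option, not necessarily the "best" one.

```python
import tensorflow as tf

A = tf.constant([[0, 1, 2],
                 [3, 4, 5]])
B = tf.constant([2, 1])

# batch_dims=1: dimension 0 of A and B is a shared batch dimension.
# axis=1: for each row i, gather column B[i] of A.
# Result shape is [batch_size].
result = tf.gather(A, B, axis=1, batch_dims=1)
print(result.numpy())  # [2 4]
```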
There is another constraint: in practice, `batch_size` is actually `None`, so the static size of the batch dimension is unknown.
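To check that an approach works when the batch dimension is `None`, one sketch is to wrap it in a `tf.function` whose `input_signature` declares the leading dimension as `None`. It builds `[row, column]` index pairs from `tf.range(tf.shape(b)[0])`, which reads the batch size at run time, and passes them to `tf.gather_nd`. The function name `select_columns` is hypothetical, and the `None` column dimension in the signature is an assumption made for generality.

```python
import tensorflow as tf

# Both the batch dimension and D are declared as None, so the traced graph
# cannot rely on a static batch size.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None], dtype=tf.int32),
    tf.TensorSpec(shape=[None], dtype=tf.int32),
])
def select_columns(a, b):  # hypothetical helper name
    # tf.shape(b)[0] is the dynamic batch size; b.shape[0] would be None here.
    rows = tf.range(tf.shape(b)[0])
    # Pair each row index with its column index: [[0, b[0]], [1, b[1]], ...]
    indices = tf.stack([rows, b], axis=1)
    # Pick a[row, col] for every pair; result shape is [batch_size].
    return tf.gather_nd(a, indices)

out1 = select_columns(tf.constant([[0, 1, 2], [3, 4, 5]]), tf.constant([2, 1]))
out2 = select_columns(tf.constant([[7, 8], [9, 10], [11, 12]]), tf.constant([0, 1, 1]))
print(out1.numpy())  # [2 4]
print(out2.numpy())  # [ 7 10 12]
```

The same traced function handles both batch sizes (2 and 3), which is the property the `None` constraint requires.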
Thanks in advance!