3
votes

I have a 2D tensor A with shape [batch_size, D] and a 1D tensor B with shape [batch_size]. Each element of B is a column index into the corresponding row of A, i.e. B[i] is in [0, D).

What is the best way in TensorFlow to get the values A[i, B[i]] for every row i?

For example:

A = tf.constant([[0,1,2],
                 [3,4,5]])
B = tf.constant([2,1])

with desired output:

some_slice_func(A, B) -> [2,4]

There is one more constraint: in practice, batch_size is None, i.e. it is not known when the graph is built.

Thanks in advance!

5 Answers

3
votes

I was able to get it working using a linear index:

def vector_slice(A, B):
    """ For each row i of A, returns the value at column B[i]

    where A is a 2D Tensor with shape [None, D] 
    and B is a 1D Tensor with shape [None] 
    with type int32 elements in [0,D)

    Example:
      A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4]
          [3,4]]
    """
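    # Offset of the first element of each row in the flattened tensor: i * D.
    linear_index = (tf.shape(A)[1]
                   * tf.range(0,tf.shape(A)[0]))
    linear_A = tf.reshape(A, [-1])
    # Element (i, B[i]) of A sits at position i * D + B[i] of linear_A.
    return tf.gather(linear_A, B + linear_index)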

This feels slightly hacky though.
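
To see why the linear index picks the right elements, here is a plain-Python sketch of the same arithmetic, without TensorFlow. The helper name flat_pick is made up for illustration, and it assumes A is a non-empty list of equal-length rows:

```python
def flat_pick(A, B):
    """Hypothetical plain-Python helper mirroring the linear-index idea."""
    num_cols = len(A[0])                      # D, the length of each row
    flat_A = [x for row in A for x in row]    # row-major flattening, like tf.reshape(A, [-1])
    # Row i starts at offset i * D, so A[i][B[i]] sits at position i * D + B[i].
    flat_index = [i * num_cols + b for i, b in enumerate(B)]
    return flat_index, [flat_A[k] for k in flat_index]


flat_index, picked = flat_pick([[1, 2], [3, 4]], [0, 1])
print(flat_index)  # [0, 3]
print(picked)      # [1, 4]
```

The second row starts at offset 2 in the flattened list, so column 1 of that row lands at position 3, which holds the value 4.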

If anyone knows a better way (clearer or faster), please leave an answer! (I won't accept my own for a while.)

1
votes

Code for what @Eugene Brevdo said:

def vector_slice(A, B):
    """ For each row i of A, returns the value at column B[i]

    where A is a 2D Tensor with shape [None, D]
    and B is a 1D Tensor with shape [None]
    with type int32 elements in [0,D)

    Example:
      A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4]
          [3,4]]
    """
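    B = tf.expand_dims(B, 1)  # shape [batch_size, 1]
    # Row numbers 0 .. batch_size-1 as a column, shape [batch_size, 1].
    rows = tf.expand_dims(tf.range(tf.shape(B)[0]), 1)
    # Each row of ind is a (row, column) pair (i, B[i]); shape [batch_size, 2].
    ind = tf.concat([rows, B], 1)
    return tf.gather_nd(A, ind)

0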
votes

The least hacky way is probably to build a proper 2D index: concatenate range(batch_size) and B column-wise to get a [batch_size, 2] matrix of (row, column) pairs, then pass that matrix to tf.gather_nd.
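
As a rough picture of what that index matrix contains, here is a plain-Python sketch using the example from the question. Both helper names are made up for illustration, and gather_pairs is only a stand-in for how tf.gather_nd treats two-column indices on a 2D input:

```python
def build_index_pairs(B):
    """Hypothetical helper: pair each row number i with its column B[i]."""
    return [[i, b] for i, b in enumerate(B)]


def gather_pairs(A, pairs):
    """Plain-Python stand-in: look up A[i][j] for every (i, j) pair."""
    return [A[i][j] for i, j in pairs]


A = [[0, 1, 2],
     [3, 4, 5]]
B = [2, 1]

pairs = build_index_pairs(B)
print(pairs)                   # [[0, 2], [1, 1]]
print(gather_pairs(A, pairs))  # [2, 4]
```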

0
votes

The simplest approach is to do:

def tensor_slice(target_tensor, index_tensor):
    # Pair each row number i with index_tensor[i]; shape [batch_size, 2].
    indices = tf.stack([tf.range(tf.shape(index_tensor)[0]), index_tensor], 1)
    return tf.gather_nd(target_tensor, indices)

0
votes

Consider using tf.one_hot, tf.math.multiply and tf.reduce_sum to solve it.

For example:

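def vector_slice(inputs, inds, axis=None):
    inputs = tf.convert_to_tensor(inputs)
    inds = tf.convert_to_tensor(inds)
    # Default to the axis right after the index dimensions,
    # e.g. axis 1 (the columns) for 2D inputs and 1D inds.
    axis = axis if axis is not None else inds.shape.rank

    # One-hot mask over the selected axis, shape inds.shape + [D].
    mask = tf.one_hot(inds, inputs.shape[axis])
    # Add trailing singleton dimensions until the mask broadcasts against inputs.
    for _ in range(inputs.shape.rank - mask.shape.rank):
        mask = tf.expand_dims(mask, axis=-1)

    mask = tf.cast(mask, dtype=inputs.dtype)

    # Zero out every unselected entry, then sum along the selected axis.
    x = tf.multiply(inputs, mask)

    return tf.reduce_sum(x, axis=axis)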
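
To make the mask-and-sum idea concrete, here is a plain-Python sketch on the example from the question, without TensorFlow. The helper names one_hot_row and one_hot_pick are made up for illustration, and the sketch only covers the 2D case:

```python
def one_hot_row(index, depth):
    """Hypothetical helper: a list of `depth` zeros with a 1 at position `index`."""
    return [1 if j == index else 0 for j in range(depth)]


def one_hot_pick(A, B):
    """Select A[i][B[i]] by multiplying each row with a one-hot mask and summing."""
    depth = len(A[0])                              # D, the number of columns
    masks = [one_hot_row(b, depth) for b in B]
    # Only the selected column survives the multiplication, so the row sum equals it.
    picked = [sum(a * m for a, m in zip(row, mask)) for row, mask in zip(A, masks)]
    return masks, picked


masks, picked = one_hot_pick([[0, 1, 2], [3, 4, 5]], [2, 1])
print(masks)   # [[0, 0, 1], [0, 1, 0]]
print(picked)  # [2, 4]
```

Note that this approach multiplies and sums every element of A, whereas the gather-based answers only read the selected entries.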