
An example

Suppose I have a tensor values with shape (2,2,2)

values = [[[0, 1],[2, 3]],[[4, 5],[6, 7]]]

And a tensor indices with shape (2,2) that specifies which element to select along the innermost dimension

indices = [[1,0],[0,0]]

Then the result will be a (2,2) matrix with these values

result = [[1,2],[4,6]]
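For reference, the intended semantics can be reproduced in NumPy with np.take_along_axis (available since NumPy 1.15), which picks one element per position along a given axis. This is only a minimal sketch of what the operation should return, using the example arrays above; it does not answer the TensorFlow part of the question.

```python
import numpy as np

values = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])  # shape (2, 2, 2)
indices = np.array([[1, 0], [0, 0]])                     # shape (2, 2)

# take_along_axis needs indices with the same ndim as values, so add a
# trailing axis of length 1, gather along the last axis, then drop that axis.
result = np.take_along_axis(values, indices[..., np.newaxis], axis=-1)[..., 0]
print(result)
# [[1 2]
#  [4 6]]
```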

What is this operation called in TensorFlow, and how can I do it?

General

Note that the shape (2,2,2) above is only an example; the tensors can have any number of dimensions. The operation must satisfy these conditions:

  • ndim(values) - 1 == ndim(indices)
  • values.shape[:-1] == indices.shape == result.shape
  • indices.max() < values.shape[-1], i.e. every index is at most values.shape[-1] - 1
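To make the conditions concrete, here is a loop-based reference implementation for inputs of any dimension. It is a sketch, not a TensorFlow solution: the helper name select_last_axis is hypothetical, and the non-negativity check on indices is an added assumption that the question does not state explicitly.

```python
import numpy as np

def select_last_axis(values, indices):
    """Hypothetical reference helper: result[i...] = values[i..., indices[i...]]."""
    values = np.asarray(values)
    indices = np.asarray(indices)
    # Conditions from the question.
    if values.ndim - 1 != indices.ndim:
        raise ValueError("indices must have exactly one dimension fewer than values")
    if values.shape[:-1] != indices.shape:
        raise ValueError("indices.shape must equal values.shape[:-1]")
    # Assumption: indices are also required to be non-negative.
    if indices.size and (indices.min() < 0 or indices.max() >= values.shape[-1]):
        raise IndexError("every index must lie in [0, values.shape[-1] - 1]")

    result = np.empty(indices.shape, dtype=values.dtype)
    # Visit every position of the leading dimensions and pick one element
    # from the last axis at that position.
    for pos in np.ndindex(*indices.shape):
        result[pos] = values[pos + (indices[pos],)]
    return result

values = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
indices = np.array([[1, 0], [0, 0]])
print(select_last_axis(values, indices))
```

The explicit loop is slow for large arrays but makes the index arithmetic easy to check against any vectorized version.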

Answer

I think you can emulate this with tf.gather_nd. You just have to convert your indices into the representation that tf.gather_nd expects, in which every entry is a full coordinate into values. The example below is tied to your specific case, i.e. input tensors of shape (2, 2, 2), but it should give you an idea of how to write the conversion for inputs of arbitrary shape, although I am not sure how easy that would be to implement (I haven't thought about it for long). I'm also not claiming this is the simplest possible solution.
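tf.gather_nd reads the last axis of its indices argument as a full coordinate into params, so here each output position (i, j) needs the coordinate (i, j, indices[i, j]). As a sketch of a loop-free conversion for arbitrary shapes, assuming NumPy inputs, the leading coordinates can be generated with np.indices and stacked with the original indices. The helper name to_gather_nd_indices is hypothetical, and the gather itself is emulated in NumPy with advanced indexing so the snippet runs without TensorFlow.

```python
import numpy as np

def to_gather_nd_indices(indices):
    """Hypothetical helper: turn last-axis indices of shape S into
    tf.gather_nd-style indices of shape S + (len(S) + 1,)."""
    indices = np.asarray(indices)
    # np.indices(S) has shape (len(S),) + S: one coordinate grid per axis.
    grid = np.indices(indices.shape)
    # Append the selected last-axis index as the final coordinate.
    return np.stack([*grid, indices], axis=-1)

values = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
indices = np.array([[1, 0], [0, 0]])

nd_idx = to_gather_nd_indices(indices)
print(nd_idx.tolist())
# [[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [1, 1, 0]]]

# NumPy stand-in for tf.gather_nd(values, nd_idx): split the last axis
# into one index array per dimension and use advanced indexing.
result = values[tuple(np.moveaxis(nd_idx, -1, 0))]
print(result)
```

With TensorFlow, the same nd_idx array can be passed directly as the indices argument of tf.gather_nd.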

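import tensorflow as tf  # TensorFlow 1.x API (tf.Session)
import numpy as np

values = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
values_tf = tf.constant(values)
indices = np.array([[1, 0], [0, 0]])

# Build gather_nd indices: for every (k, j) position, the full
# coordinate [k, j, indices[k, j]] into values.
converted_idx = []
for k in range(values.shape[0]):
    outer = []
    for j in range(values.shape[1]):
        outer.append([k, j, int(indices[k][j])])
    converted_idx.append(outer)

# TF 1.x graph mode; in TF 2.x, drop the session and call
# tf.gather_nd(values_tf, converted_idx).numpy() directly.
with tf.Session() as sess:
    result = tf.gather_nd(values_tf, converted_idx)
    print(sess.run(result))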

This prints

[[1 2]
 [4 6]]

Edit: To handle arbitrary shapes, here is a recursive solution that should work (I have only tested it on your example):

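def convert_idx(last_dim_vals, ori_indices, access_to_ori, depth):
    """Recursively build tf.gather_nd indices.

    last_dim_vals: the values array (only its number of dimensions is used)
    ori_indices:   the original last-axis indices, shape values.shape[:-1]
    access_to_ori: coordinates chosen so far for the leading axes
    depth:         the axis currently being iterated over
    """
    # All leading axes are fixed: append the selected last-axis index.
    if depth == len(last_dim_vals.shape) - 1:
        return access_to_ori + [ori_indices[tuple(access_to_ori)]]

    # Otherwise iterate over the current axis and recurse into the next one.
    outer = []
    for k in range(ori_indices.shape[depth]):
        outer.append(convert_idx(last_dim_vals, ori_indices, access_to_ori + [k], depth + 1))
    return outer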

You can use this together with the original code I posted like so:

...
converted_idx = convert_idx(values, indices, [], 0)
with tf.Session() as sess:
    result = tf.gather_nd(values_tf, converted_idx)
    print(sess.run(result))
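Since the recursive version was only tested on the (2, 2, 2) example, here is a sketch of a check on a 4-D input. It copies convert_idx so the snippet is self-contained, emulates tf.gather_nd in NumPy via advanced indexing (an assumption made so it runs without TensorFlow), and compares the result against np.take_along_axis. The random shape (3, 2, 4, 5) is arbitrary.

```python
import numpy as np

def convert_idx(last_dim_vals, ori_indices, access_to_ori, depth):
    # Copy of the recursive helper from the answer above.
    if depth == len(last_dim_vals.shape) - 1:
        return access_to_ori + [ori_indices[tuple(access_to_ori)]]
    outer = []
    for k in range(ori_indices.shape[depth]):
        outer.append(convert_idx(last_dim_vals, ori_indices, access_to_ori + [k], depth + 1))
    return outer

rng = np.random.default_rng(0)
values = rng.integers(0, 100, size=(3, 2, 4, 5))
indices = rng.integers(0, values.shape[-1], size=values.shape[:-1])

nd_idx = np.array(convert_idx(values, indices, [], 0))
# NumPy stand-in for tf.gather_nd(values, nd_idx).
gathered = values[tuple(np.moveaxis(nd_idx, -1, 0))]
expected = np.take_along_axis(values, indices[..., np.newaxis], axis=-1)[..., 0]

print(nd_idx.shape)                        # (3, 2, 4, 4)
print(np.array_equal(gathered, expected))  # True
```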