I think you can emulate this with tf.gather_nd. You will just have to convert "your" indices to a representation that is suitable for tf.gather_nd, i.e. one full coordinate [k, l, m] into the input tensor for every output element. The following example is tied to your specific case, i.e. input tensors of shape (2, 2, 2), but I think it gives you an idea of how you could write the conversion for input tensors of arbitrary shape, although I am not sure how easy that would be to implement (I haven't thought about it for too long). Also, I'm not claiming that this is the easiest possible solution.
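To make the target representation concrete: for every output position, tf.gather_nd reads the element at a full coordinate such as [k, l, m]. The pure-NumPy sketch below mimics that behaviour for this full-coordinate case only. The helper name gather_nd_numpy is hypothetical (it is not part of TensorFlow), and the converted indices for your example are written out by hand.

```python
import numpy as np


def gather_nd_numpy(params, indices):
    """Hypothetical NumPy analogue of tf.gather_nd, restricted to the case
    where the last dimension of `indices` equals params.ndim (each innermost
    entry is a full coordinate into `params`)."""
    indices = np.asarray(indices)
    assert indices.shape[-1] == params.ndim
    # Split the last axis into one index array per dimension of `params`
    # and let NumPy advanced indexing pick the elements.
    return params[tuple(np.moveaxis(indices, -1, 0))]


values = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
# Each innermost triple [k, l, m] addresses values[k, l, m].
converted_idx = [[[0, 0, 1], [0, 1, 0]],
                 [[1, 0, 0], [1, 1, 0]]]
print(gather_nd_numpy(values, converted_idx))
```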
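import tensorflow as tf
import numpy as np

values = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
values_tf = tf.constant(values)
indices = np.array([[1, 0], [0, 0]])

# Turn each last-axis index indices[k][l] into a full coordinate
# [k, l, indices[k][l]] into `values`, as expected by tf.gather_nd.
converted_idx = []
for k in range(values.shape[0]):
    outer = []
    for l in range(values.shape[1]):
        inds = [k, l, indices[k][l]]
        outer.append(inds)
    converted_idx.append(outer)

# TF 1.x graph mode: build the gather op, then evaluate it in a session.
with tf.Session() as sess:
    result = tf.gather_nd(values_tf, converted_idx)
    print(sess.run(result))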
This prints:

[[1 2]
 [4 6]]
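As a sanity check that does not need TensorFlow, the expected values can be reproduced in plain NumPy with np.take_along_axis (available in NumPy 1.15 and later), which picks one element along the given axis for each position. This is only meant as a cross-check of the output above.

```python
import numpy as np

values = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
indices = np.array([[1, 0], [0, 0]])

# take_along_axis needs indices with the same number of dimensions as
# `values`, so add a trailing axis of length 1, then drop it again.
picked = np.take_along_axis(values, indices[..., np.newaxis], axis=-1)[..., 0]
print(picked)
```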
Edit: To handle arbitrary shapes, here is a recursive solution that should work (only tested on your example):
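def convert_idx(last_dim_vals, ori_indices, access_to_ori, depth):
    # Base case: access_to_ori addresses one entry of ori_indices, so append
    # that last-axis index to obtain a full coordinate into last_dim_vals.
    if depth == len(last_dim_vals.shape) - 1:
        inds = access_to_ori + [ori_indices[tuple(access_to_ori)]]
        return inds

    # Recursive case: iterate over the current axis and collect the results.
    outer = []
    for k in range(ori_indices.shape[depth]):
        inds = convert_idx(last_dim_vals, ori_indices, access_to_ori + [k], depth + 1)
        outer.append(inds)
    return outer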
You can use this in place of the loop in the original code I posted, like so:
converted_idx = convert_idx(values, indices, [], 0)
with tf.Session() as sess:
    result = tf.gather_nd(values_tf, converted_idx)
    print(sess.run(result))
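If you want to avoid the Python recursion, the same index structure can probably also be built in a vectorized way with np.indices. The helper convert_idx_vectorized below is a hypothetical sketch of mine (only checked against your example and a 1-D case); like convert_idx, it assumes the indices select along the last axis, so ori_indices has shape values.shape[:-1]. The resulting integer array should be usable as the indices argument of tf.gather_nd.

```python
import numpy as np


def convert_idx_vectorized(ori_indices):
    """Hypothetical helper: build tf.gather_nd-style indices from an array of
    last-axis indices.

    For ori_indices of shape S, returns an int array of shape S + (len(S) + 1,)
    whose entry at position p is [*p, ori_indices[p]].
    """
    ori_indices = np.asarray(ori_indices)
    # np.indices(S) has shape (len(S),) + S: one coordinate grid per axis.
    grid = np.indices(ori_indices.shape)
    # Stack the coordinate grids and the original indices along a new last axis.
    return np.stack(list(grid) + [ori_indices], axis=-1)


indices = np.array([[1, 0], [0, 0]])
converted_idx = convert_idx_vectorized(indices)
print(converted_idx.tolist())
# [[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [1, 1, 0]]]
```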