I need to get values in a ragged tensor by indexing along the ragged dimension. Some indexing works ([:, :x], [:, -x:] or [:, x:y]), but not direct indexing ([:, x]):

import tensorflow as tf

R = tf.RaggedTensor.from_tensor([[1, 2, 3], [4, 5, 6]])
print(R[:, :2]) # RaggedTensor([[1, 2], [4, 5]])
print(R[:, 1:2]) # RaggedTensor([[2], [5]])
print(R[:, 1])  # ValueError: Cannot index into an inner ragged dimension.

The documentation explains why this fails:

RaggedTensors supports multidimensional indexing and slicing, with one restriction: indexing into a ragged dimension is not allowed. This case is problematic because the indicated value may exist in some rows but not others. In such cases, it's not obvious whether we should (1) raise an IndexError; (2) use a default value; or (3) skip that value and return a tensor with fewer rows than we started with. Following the guiding principles of Python ("In the face of ambiguity, refuse the temptation to guess" ), we currently disallow this operation.

This makes sense, but how do I actually implement options 1, 2 and 3? Must I convert the ragged array into a Python array of Tensors, and manually iterate over them? Is there a more efficient solution? One that would work 100% in a TensorFlow graph, without going through the Python interpreter?

1 Answer

If you have a 2D RaggedTensor, then you can get behavior (3) with:

def get_column_slice_v3(rt, column):
  assert column >= 0  # Negative column index not supported
  slice = rt[:, column:column+1]
  return slice.flat_values
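
As a quick check of what this returns, here is a small sketch on a made-up input whose rows have different lengths (the tensor `rt` and its values are only for illustration). Rows that are too short produce an empty slice, so they contribute nothing to `flat_values`:

```python
import tensorflow as tf

# Hypothetical input: the middle row has no element at column 1.
rt = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])

# Slicing (rather than indexing) the ragged dimension is allowed.
column_slice = rt[:, 1:2]
print(column_slice.to_list())                     # [[2], [], [6]]

# Flattening drops the empty row, giving fewer values than rows.
print(column_slice.flat_values.numpy().tolist())  # [2, 6]
```

Note that the result no longer tells you which rows the values came from, which is exactly the trade-off of behavior (3).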

And you can get behavior (1) by adding an assertion that rt.nrows() == tf.size(slice.flat_values):

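def get_column_slice_v1(rt, column):
  assert column >= 0  # Negative column index not supported
  slice = rt[:, column:column+1]
  # tf.size defaults to int32, so request the same dtype that nrows() returns.
  num_values = tf.size(slice.flat_values, out_type=rt.row_splits.dtype)
  with tf.control_dependencies([tf.debugging.assert_equal(rt.nrows(), num_values)]):
    return tf.identity(slice.flat_values)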

To get behavior (2), I think the easiest way is probably to concatenate a vector of default values and then slice again:

def get_column_slice_v2(rt, column, default=None):
  assert column >= 0  # Negative column index not supported
  slice = rt[:, column:column+1]
  if default is None:
    defaults = tf.zeros([slice.nrows(), 1], slice.dtype)
  else:
    defaults = tf.fill([slice.nrows(), 1], tf.cast(default, slice.dtype))
  # Append one default per row, so every row has at least one value.
  slice_plus_default = tf.concat([slice, defaults], axis=1)
  # Keep the first value of each row: the real value if present, else the default.
  slice2 = slice_plus_default[:, :1]
  return slice2.flat_values

It's possible to extend these to support higher-dimensional ragged tensors, but the logic gets a little more complicated. Also it should be possible to extend these to support negative column indices.
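
One way to handle negative column indices is sketched below. It does not use the slicing approach from the functions above. Instead it computes each row's position in the flat values from `row_starts()` and `row_lengths()` and gathers from there. The function name `get_column_with_default` and the example values are hypothetical, and the sketch assumes a 2D RaggedTensor with a single ragged dimension:

```python
import tensorflow as tf

def get_column_with_default(rt, column, default=0):
  """Returns (values, valid) for rt[:, column]; `column` may be negative."""
  lengths = rt.row_lengths()                      # number of values in each row
  column = tf.cast(column, lengths.dtype)
  # Resolve negative indices per row, as Python does (-1 is the last element).
  offset = tf.where(column < 0, lengths + column, tf.zeros_like(lengths) + column)
  valid = tf.logical_and(offset >= 0, offset < lengths)
  values = rt.values
  default = tf.convert_to_tensor(default, dtype=values.dtype)
  # Append the default as one extra entry and point every invalid row at it,
  # so the gather below never reads out of bounds.
  padded = tf.concat([values, tf.reshape(default, [1])], axis=0)
  sentinel = tf.size(values, out_type=lengths.dtype)
  index = tf.where(valid, rt.row_starts() + offset, sentinel)
  return tf.gather(padded, index), valid

rt = tf.ragged.constant([[1, 2, 3], [4], [], [5, 6]])
values, valid = get_column_with_default(rt, -1, default=-1)
print(values.numpy().tolist())                           # [3, 4, -1, 6]
print(tf.boolean_mask(values, valid).numpy().tolist())   # [3, 4, 6]
```

The returned `values` correspond to behavior (2). Masking with `valid` gives behavior (3), and asserting that every entry of `valid` is true before returning would give behavior (1).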