16
votes

I have three tensors, A, B and C, in TensorFlow. A and B both have shape (m, n, r); C is a binary tensor of shape (m, n, 1).

I want to select elements from either A or B based on the value of C. The obvious tool is tf.select, but it does not have broadcasting semantics, so I first need to explicitly broadcast C to the same shape as A and B.

This is my first attempt, but TensorFlow does not accept a tensor (tf.shape(A)[2]) mixed into the shape list.

import tensorflow as tf
A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])
C = tf.greater_equal(C, tf.zeros_like(C))

C = tf.tile(C, [1,1,tf.shape(A)[2]])
D = tf.select(C, A, B)

What's the correct approach here?

4
One hack that works: I can use the broadcasting semantics of multiply and multiply by a ones tensor, thus: Expander = tf.ones_like(B), then C = Expander * C

4 Answers

15
votes

EDIT: In all versions of TensorFlow since 0.12rc0, the code in the question works directly. TensorFlow will automatically stack tensors and Python numbers into a tensor argument. The solution below using tf.pack() is only needed in versions prior to 0.12rc0. Note that tf.pack() was renamed to tf.stack() in TensorFlow 1.0.


Your solution is very close to working. You should replace the line:

C = tf.tile(C, [1,1,tf.shape(A)[2]])

...with the following:

C = tf.tile(C, tf.pack([1, 1, tf.shape(A)[2]]))

(The reason for the issue is that TensorFlow won't implicitly convert a list of tensors and Python literals into a tensor. tf.pack() takes a list of tensors, so it will convert each of the elements in its input (1, 1, and tf.shape(A)[2]) to a tensor. Since each element is a scalar, the result will be a vector.)

4
votes

Here's a dirty hack:

import tensorflow as tf

def broadcast(tensor, shape):
    return tensor + tf.zeros(shape, dtype=tensor.dtype)

A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])

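C = broadcast(C, A.shape)                  # still float, now shape (20, 100, 10)
C = tf.greater_equal(C, tf.zeros_like(C))  # tf.select requires a boolean condition
D = tf.select(C, A, B)                     # named tf.where in TensorFlow 1.0 and later
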
2
votes

In the newest TensorFlow version (2.0), you can use tf.broadcast_to as below:

import tensorflow as tf

A = tf.random.normal([20, 100, 10])  # tf.random_normal is not available in 2.0
B = tf.random.normal([20, 100, 10])
C = tf.random.normal([20, 100, 1])
C = tf.greater_equal(C, tf.zeros_like(C))
C = tf.broadcast_to(C, A.shape)

D = tf.where(C,A,B)
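
As a cross-check for any of the snippets in this thread, here is a dependency-free sketch of what D should contain, written with plain nested lists instead of tensors. The name select_broadcast is made up for illustration and is not a TensorFlow function; it only spells out the intended semantics, assuming C has shape (m, n, 1) and A and B have shape (m, n, r).

```python
def select_broadcast(cond, a, b):
    """Pick a[i][j][k] where cond[i][j][0] is true, else b[i][j][k].

    Hypothetical helper for illustration only (not part of TensorFlow).
    cond has shape (m, n, 1); a and b have shape (m, n, r).
    """
    return [
        [
            # The single flag in c_row applies to every element of the row.
            [x if c_row[0] else y for x, y in zip(a_row, b_row)]
            for c_row, a_row, b_row in zip(c_plane, a_plane, b_plane)
        ]
        for c_plane, a_plane, b_plane in zip(cond, a, b)
    ]


# Tiny example with m=1, n=2, r=3.
A = [[[1, 2, 3], [4, 5, 6]]]
B = [[[-1, -2, -3], [-4, -5, -6]]]
C = [[[True], [False]]]

D = select_broadcast(C, A, B)
print(D)  # [[[1, 2, 3], [-4, -5, -6]]]
```

The first row comes entirely from A and the second entirely from B, which is the behaviour that broadcasting C along the last axis is meant to produce.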
0
votes

import tensorflow as tf

def broadcast(x, shape):
     """Broadcasts ``x`` to have shape ``shape``.

     Uses ``tf.Assert`` statements to ensure that the broadcast is
     valid.

     First calculates the number of missing dimensions in 
     ``tf.shape(x)`` and left-pads the shape of ``x`` with that many 
     ones. Then identifies the dimensions of ``x`` that require
     tiling and tiles those dimensions appropriately.

     Args:
         x (tf.Tensor): The tensor to broadcast.
         shape (Union[tf.TensorShape, tf.Tensor, Sequence[int]]): 
             The shape to broadcast to.

     Returns:
         tf.Tensor: ``x``, reshaped and tiled to have shape ``shape``.

     """
     with tf.name_scope('broadcast') as scope:
         shape_x = tf.shape(x)
         rank_x = tf.shape(shape_x)[0]
         shape_t = tf.convert_to_tensor(shape, preferred_dtype=tf.int32)
         rank_t = tf.shape(shape_t)[0]

         with tf.control_dependencies([tf.Assert(
             rank_t >= rank_x,
             ['len(shape) must be >= tf.rank(x)', shape_x, shape_t],
             summarize=255
         )]):
             missing_dims = tf.ones(tf.stack([rank_t - rank_x], 0), tf.int32)

         shape_x_ = tf.concat([missing_dims, shape_x], 0)
         should_tile = tf.equal(shape_x_, 1)

         with tf.control_dependencies([tf.Assert(
             tf.reduce_all(tf.logical_or(tf.equal(shape_x_, shape_t), should_tile)),
             ['cannot broadcast shapes', shape_x, shape_t],
             summarize=255
         )]):
             multiples = tf.where(should_tile, shape_t, tf.ones_like(shape_t))
             out = tf.tile(tf.reshape(x, shape_x_), multiples, name=scope)

         try:
             # Best effort: attach the static shape when it can be determined.
             out.set_shape(shape)
         except Exception:
             pass

         return out

A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])

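C = broadcast(C, A.shape)                  # still float, now shape (20, 100, 10)
C = tf.greater_equal(C, tf.zeros_like(C))  # the condition must be boolean
D = tf.where(C, A, B)                      # tf.select was renamed tf.where in 1.0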
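
The shape arithmetic described in the docstring above (left-pad with ones, check compatibility, compute tile multiples) can be followed more easily without graph ops. Below is a minimal pure-Python sketch of that logic on plain lists of ints. The name tile_multiples is hypothetical, and the sketch assumes fully known static shapes, which the TensorFlow version does not require.

```python
def tile_multiples(shape_x, shape_t):
    """Return (padded_shape, multiples) needed to expand shape_x to shape_t.

    Hypothetical helper for illustration only; it mirrors the shape logic
    of broadcast() above using Python ints instead of tensors.
    Raises ValueError when the shapes are not broadcast-compatible.
    """
    shape_x, shape_t = list(shape_x), list(shape_t)
    if len(shape_t) < len(shape_x):
        raise ValueError("len(shape) must be >= rank of x")

    # Left-pad the source shape with ones so both shapes have equal rank.
    padded = [1] * (len(shape_t) - len(shape_x)) + shape_x

    multiples = []
    for dim_x, dim_t in zip(padded, shape_t):
        if dim_x == dim_t:
            multiples.append(1)       # already the right size
        elif dim_x == 1:
            multiples.append(dim_t)   # size-1 axis: repeat it dim_t times
        else:
            raise ValueError("cannot broadcast %r to %r" % (shape_x, shape_t))
    return padded, multiples


print(tile_multiples([20, 100, 1], [20, 100, 10]))  # ([20, 100, 1], [1, 1, 10])
print(tile_multiples([10], [20, 100, 10]))          # ([1, 1, 10], [20, 100, 1])
```

For the tensors in the question this gives multiples [1, 1, 10], which is the same list the question builds by hand as [1, 1, tf.shape(A)[2]].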