I have three tensors, A, B and C, in TensorFlow. A and B are both of shape (m, n, r), and C is a binary tensor of shape (m, n, 1).
I want to select elements from either A or B based on the value of C. The obvious tool is tf.select; however, it does not have broadcasting semantics, so I first need to explicitly broadcast C to the same shape as A and B.
This is my first attempt, but TensorFlow doesn't like me mixing a tensor (tf.shape(A)[2]) into the list of multiples passed to tf.tile.
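import tensorflow as tf

A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])
# Turn C into a boolean mask of shape (20, 100, 1).
C = tf.greater_equal(C, tf.zeros_like(C))
# This is the line that fails: the list mixes Python ints with a tensor.
C = tf.tile(C, [1, 1, tf.shape(A)[2]])
D = tf.select(C, A, B)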
What's the correct approach here?
Expander = tf.ones_like(B), then C = Expander*C – wxs
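For illustration, here is a minimal sketch of a few ways the mask could be expanded before selecting. It assumes a TensorFlow 2.x release running eagerly, which is not what the question uses: tf.select and tf.random_normal are replaced by their later names tf.where and tf.random.normal. The names multiples, C_tiled, C_full and the D_* results are made up for this example.

```python
import tensorflow as tf

A = tf.random.normal([20, 100, 10])
B = tf.random.normal([20, 100, 10])
# Boolean mask of shape (20, 100, 1), built as in the question.
C = tf.greater_equal(tf.random.normal([20, 100, 1]), 0.0)

# Option 1: pack the Python ints and the scalar tensor into a single
# 1-D integer tensor with tf.stack, then tile along the last axis.
multiples = tf.stack([1, 1, tf.shape(A)[2]])
C_tiled = tf.tile(C, multiples)
D_tiled = tf.where(C_tiled, A, B)

# Option 2: broadcast the mask to the full shape of A explicitly.
C_full = tf.broadcast_to(C, tf.shape(A))
D_full = tf.where(C_full, A, B)

# Option 3 (assumes TF 2.x, where tf.where broadcasts its arguments):
# pass the (20, 100, 1) mask directly.
D_direct = tf.where(C, A, B)
```

All three results should have shape (20, 100, 10) and hold the same values. On older releases where the select op does not broadcast, only the first two options would apply, with the op names adjusted to that release.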