I don't quite understand how the broadcasting mechanism works in TensorFlow. Assume we have the following code:
W1_shape = [5, 5, 1, 32]
b1_shape = [32]
x = tf.placeholder(tf.float32)
initial_W1 = tf.truncated_normal(shape=W1_shape, stddev=0.1)
W1 = tf.Variable(initial_W1)
initial_b1 = tf.constant(0.1, shape=b1_shape)
b1 = tf.Variable(initial_b1)
conv1 = tf.nn.conv2d(x, W1, strides=[1, 1, 1, 1], padding='SAME')
conv1_sum = conv1 + b1
y = tf.placeholder(tf.float32)
z = conv1 + y
sess = tf.Session()
# Run init ops
init = tf.global_variables_initializer()
sess.run(init)
while True:
    samples, labels, indices = dataset.get_next_batch(batch_size=1000)
    samples = samples.reshape((1000, MnistDataSet.MNIST_SIZE, MnistDataSet.MNIST_SIZE, 1))
    y_data = np.ones(shape=(1000, 32))
    conv1_res, conv1_sum_res, b1_res, z_res = \
        sess.run([conv1, conv1_sum, b1, z], feed_dict={x: samples, y: y_data})
    if dataset.isNewEpoch:
        break
So, I load the MNIST dataset, which consists of 28x28 images. The convolution uses 32 filters of size 5x5. With a batch size of 1000, the data tensor x has the shape (1000, 28, 28, 1), and the tf.nn.conv2d operation outputs a tensor of shape (1000, 28, 28, 32). y is a placeholder I added purely to test TensorFlow's broadcasting mechanism, by adding it to the (1000, 28, 28, 32)-shaped conv1 tensor. In the line y_data = np.ones(shape=(1000, 32)) I experiment with various shapes for y. The shapes (28, 28), (1000, 28) and (1000, 32) cannot be added to conv1; each attempt fails with an error of the form:
InvalidArgumentError (see above for traceback): Incompatible shapes: [1000,28,28,32] vs. [28,28]
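For comparison, the same shapes behave identically under plain NumPy broadcasting. Here is a quick check I can reproduce without TensorFlow; conv1_like is just a zero array standing in for the conv1 output:

```python
import numpy as np

# Stand-in for the (1000, 28, 28, 32) conv1 output.
conv1_like = np.zeros((1000, 28, 28, 32))

# The shapes that TensorFlow rejected also fail under NumPy broadcasting:
for shape in [(28, 28), (1000, 28), (1000, 32)]:
    try:
        conv1_like + np.ones(shape)
    except ValueError as e:
        print(shape, "->", e)  # "operands could not be broadcast together"

# The shapes that worked in TensorFlow broadcast fine here too:
for shape in [(28, 32), (28, 28, 32)]:
    print(shape, "->", (conv1_like + np.ones(shape)).shape)  # (1000, 28, 28, 32)
```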
The shapes (28, 32) and (28, 28, 32) work and broadcast correctly. But according to the broadcasting semantics explained at https://www.tensorflow.org/performance/xla/broadcasting, the first three shapes should work as well, since their dimensions can be matched, in order, against dimensions of the 4D conv1 tensor. For example, (28, 28) matches (1000, 28, 28, 32) in dimensions 1 and 2, and (1000, 32) matches in dimensions 0 and 3, just as described in the link. Am I missing or misunderstanding something here? What is the correct broadcasting behavior of TensorFlow in such cases?
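As a side note, a plain-NumPy sketch (assuming TensorFlow's element-wise add follows NumPy broadcasting, which is exactly the point I am unsure about) shows that inserting explicit size-1 axes does make those alignments work:

```python
import numpy as np

# Stand-in for the (1000, 28, 28, 32) conv1 output.
conv1_like = np.zeros((1000, 28, 28, 32))

# (28, 28) fails as-is, but reshaped to (1, 28, 28, 1) it aligns
# explicitly with dimensions 1 and 2 of conv1:
y1 = np.ones((28, 28)).reshape((1, 28, 28, 1))
print((conv1_like + y1).shape)  # (1000, 28, 28, 32)

# Likewise, (1000, 32) reshaped to (1000, 1, 1, 32) aligns with
# dimensions 0 and 3:
y2 = np.ones((1000, 32)).reshape((1000, 1, 1, 32))
print((conv1_like + y2).shape)  # (1000, 28, 28, 32)
```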