I'm reimplementing the paper Learning Image Matching by Simply Watching Video using TensorFlow and I'm facing some serious performance issues when grabbing the gradients from the network. To quickly recap what they do in the paper: given the trained network, they do one forward prop to get the interpolated image, and then w*h/stride^2 backprops to obtain the gradient of the output w.r.t. the input for each output pixel. Due to the high number of backpropagations, this has to be done fairly efficiently to get the gradients in a reasonable amount of time (in the paper, about 8 minutes: 150 ms per backprop times 128*384/16 pixels, with a stride of 4 on both rows and columns). Since in TensorFlow the multiple backprops cannot be batched due to gradient aggregation (see for example this discussion), I need to do something like:
for i in range(0, h, stride):
    for j in range(0, w, stride):
        grad_output[0, i, j, :] = 1  # select current pixel
        grad.append(tf.gradients(predictions, images, grad_output))
        grad_output[grad_output != 0] = 0
to get the symbolic gradients for each pixel, where predictions is the output tensor of the network and images is the input, declared as a constant on the GPU:
with tf.device('/gpu:0'):
    images = tf.constant(inp, dtype=tf.float32)
where inp is the actual numpy array containing the data.
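To make the setup concrete, here is a stripped-down, self-contained sketch of what I'm building; the real trained network is replaced by a single dummy conv layer, and the shapes/channel count are just my guesses at the ones used in the paper, so treat it as illustrative rather than my actual code:

import numpy as np
import tensorflow as tf

h, w, stride = 128, 384, 4
inp = np.random.rand(1, h, w, 3).astype(np.float32)  # stand-in for the real input data

with tf.device('/gpu:0'):
    images = tf.constant(inp, dtype=tf.float32)
    # dummy stand-in for the real interpolation network
    kernel = tf.constant(np.random.rand(3, 3, 3, 3).astype(np.float32))
    predictions = tf.nn.conv2d(images, kernel, strides=[1, 1, 1, 1], padding='SAME')

# one-hot weighting of the output, same shape as predictions
grad_output = np.zeros((1, h, w, 3), dtype=np.float32)
grad = []
for i in range(0, h, stride):
    for j in range(0, w, stride):
        grad_output[0, i, j, :] = 1  # select current pixel
        grad.append(tf.gradients(predictions, images, grad_output))
        grad_output[grad_output != 0] = 0

sess = tf.Session()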
Every call to tf.gradients alone takes around 0.35 ms, which is already too much compared to what the authors report in the paper. But the largest amount of time is spent in evaluating the symbolic gradients, something like:
for i in range(0, len(grad)):
    res = sess.run(grad[i])
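The timings below come from simple wall-clock measurements around each sess.run call, roughly like this (illustrative timing code, not my exact script):

import time

for i in range(0, len(grad)):
    start = time.time()
    res = sess.run(grad[i])
    print('gradient %d evaluated in %.3f s' % (i, time.time() - start))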
This takes around 1.5 seconds per gradient, which is really slow. Now, subsequent calls to sess.run(grad[i]) with the same index i are really fast, around 100 ms, while running the for loop with i changing at every iteration takes around 1.5 seconds per iteration. After seeing this behavior, my guess is that there is a big overhead in moving stuff to the GPU. Is that possible? If so, how can I avoid it? I already switched the images tensor to a GPU constant instead of using a placeholder and relying on the feed_dict in sess.run, but that didn't have any visible impact on performance. Any ideas on how to speed up the evaluation of the symbolic gradients? I feel I'm missing something simple here, since a single backprop taking 1.5 seconds is far from any realistic scenario (training the network was able to process around 100 samples per second, for example, so I don't think it's an architecture problem).
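For completeness, the placeholder version I was using before switching to the GPU constant looked roughly like this (a sketch, not my exact code), and it performed about the same:

images = tf.placeholder(tf.float32, shape=[1, h, w, 3])
# ... build predictions and the grad list on top of the placeholder as above ...
for i in range(0, len(grad)):
    res = sess.run(grad[i], feed_dict={images: inp})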
Thanks!