I have a custom gradient function that doubles the incoming gradients, and I'm applying it to tf.square via gradient_override_map:
import tensorflow as tf

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
    return grad * 2.0

c = tf.constant(3.)

# Square with the default registered gradient.
s1 = tf.square(c)
grad1 = tf.gradients(s1, c)[0]

# The same square, but with the "Square" gradient overridden by "CustomSquare".
g = tf.get_default_graph()
with g.gradient_override_map({"Square": "CustomSquare"}):
    s2 = tf.square(c)
    grad2 = tf.gradients(s2, c)[0]

with tf.Session() as sess:
    print(sess.run([c, s1, grad1]))
    print(sess.run([c, s2, grad2]))
The results I get are surprising:
[3.0, 9.0, 6.0]
[3.0, 9.0, 2.0]
I was expecting the second result to be [3.0, 9.0, 12.0]. What am I missing?
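For context, this is the arithmetic behind my expectation (I'm assuming the override doubles the standard Square gradient rather than replacing it, which may be the wrong assumption):

c_val = 3.0
standard_grad = 2.0 * c_val           # d(c**2)/dc = 2*c = 6.0, which matches grad1
expected_grad2 = 2.0 * standard_grad  # "doubled" gradient -> 12.0
print(expected_grad2)                 # 12.0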
Thanks.