I'm currently analyzing how gradients develop during the training of a CNN with TensorFlow 2.x. What I want to do is compare the gradient of each individual sample in a batch with the gradient of the whole batch. At the moment I use this simple code snippet for each training step:
```python
import tensorflow as tf

# [...]
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
# [...] (train_loss and train_accuracy metrics are assumed to be defined here)

# One training step
# x_train is a batch of input data, y_train the corresponding labels
def train_step(model, optimizer, x_train, y_train):
    # Process batch
    with tf.GradientTape() as tape:
        batch_predictions = model(x_train, training=True)
        batch_loss = loss_object(y_train, batch_predictions)
    batch_grads = tape.gradient(batch_loss, model.trainable_variables)
    # Do something with gradient of whole batch
    # ...

    # Process each data point in the current batch
    for index in range(len(x_train)):
        with tf.GradientTape() as single_tape:
            single_prediction = model(x_train[index:index+1], training=True)
            single_loss = loss_object(y_train[index:index+1], single_prediction)
        single_grad = single_tape.gradient(single_loss, model.trainable_variables)
        # Do something with gradient of single data input
        # ...

    # Use batch gradient to update network weights
    optimizer.apply_gradients(zip(batch_grads, model.trainable_variables))
    train_loss(batch_loss)
    train_accuracy(y_train, batch_predictions)
```
My main problem is that computation time explodes when I compute each of the per-sample gradients individually, even though TensorFlow should already have done these calculations while computing the batch gradient. The reason is that `GradientTape` as well as `compute_gradients` always return a single gradient (the gradient of the scalar batch loss), no matter whether one or several data points were given. So this computation has to be repeated, with its own forward and backward pass, for each data point.
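The behaviour described above can be checked with a minimal sketch. It assumes a toy linear classifier built from plain `tf.Variable` objects (`W` and `b` are hypothetical stand-ins for the CNN) and keeps one loss value per sample via `tf.nn.sparse_softmax_cross_entropy_with_logits`. Even with this unreduced loss vector as the target, `tape.gradient` returns one tensor per variable, equal to the gradient of the summed loss:

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model standing in for the CNN: logits = x @ W + b
tf.random.set_seed(0)
W = tf.Variable(tf.random.normal((4, 3)))
b = tf.Variable(tf.zeros((3,)))
variables = [W, b]

x = tf.random.normal((5, 4))        # batch of 5 samples, 4 features
y = tf.constant([0, 1, 2, 1, 0])    # integer class labels

with tf.GradientTape(persistent=True) as tape:
    logits = tf.matmul(x, W) + b
    # One loss value per sample, shape [5] (no reduction)
    per_sample_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y, logits=logits)
    summed_loss = tf.reduce_sum(per_sample_losses)

# A non-scalar target still yields ONE gradient per variable ...
vector_grads = tape.gradient(per_sample_losses, variables)
# ... which is the gradient of the summed loss
sum_grads = tape.gradient(summed_loss, variables)
del tape  # release resources held by the persistent tape

for v, g_vec, g_sum in zip(variables, vector_grads, sum_grads):
    assert g_vec.shape == v.shape  # no batch dimension left
    np.testing.assert_allclose(g_vec.numpy(), g_sum.numpy(), rtol=1e-5, atol=1e-6)
```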
I know that I could compute the batch gradient used to update the network from all the per-sample gradients (e.g. by averaging them), but this saves only a small fraction of the computation time.
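A minimal sketch of that relationship, under similar assumptions (a hypothetical toy linear model on plain `tf.Variable` objects that outputs logits, hence `from_logits=True`): with the default reduction of `SparseCategoricalCrossentropy`, which averages over the batch, the mean of the per-sample gradients from the loop reproduces the batch gradient.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model standing in for the CNN (no dropout, no batch norm)
tf.random.set_seed(0)
W = tf.Variable(tf.random.normal((4, 3)))
b = tf.Variable(tf.zeros((3,)))
variables = [W, b]

def model(inputs):
    return tf.matmul(inputs, W) + b  # returns logits

# Same loss class as in the question; the default reduction averages over the batch
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

x = tf.random.normal((5, 4))
y = tf.constant([0, 1, 2, 1, 0])

# Gradient of the mean batch loss
with tf.GradientTape() as tape:
    batch_loss = loss_object(y, model(x))
batch_grads = tape.gradient(batch_loss, variables)

# Per-sample gradients, one tape per data point (as in the loop of train_step)
single_grads = []
for i in range(len(x)):
    with tf.GradientTape() as single_tape:
        single_loss = loss_object(y[i:i+1], model(x[i:i+1]))
    single_grads.append(single_tape.gradient(single_loss, variables))

# Average each variable's gradient over the samples
averaged_grads = [tf.reduce_mean(tf.stack(g), axis=0) for g in zip(*single_grads)]
```

This equality only holds for layers that treat samples independently; with dropout or batch normalization in training mode, a single-sample forward pass differs from the corresponding part of the batch pass.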
Is there a more efficient way to compute per-sample gradients?
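One possible direction, sketched here only as an illustration under the same toy-model assumptions (hypothetical `W` and `b`, unreduced per-sample loss): `tf.GradientTape.jacobian` differentiates each entry of the loss vector separately and, with its default `experimental_use_pfor=True`, vectorizes those backward passes instead of running a Python loop, reusing a single forward pass. The result has a leading batch dimension, and averaging over it gives the batch gradient for a mean-reduced loss.

```python
import tensorflow as tf

# Hypothetical toy model standing in for the CNN
tf.random.set_seed(0)
W = tf.Variable(tf.random.normal((4, 3)))
b = tf.Variable(tf.zeros((3,)))
variables = [W, b]

x = tf.random.normal((5, 4))
y = tf.constant([0, 1, 2, 1, 0])

# Single forward pass over the whole batch
with tf.GradientTape() as tape:
    logits = tf.matmul(x, W) + b
    # Unreduced loss: one entry per sample, shape [5]
    per_sample_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y, logits=logits)

# Jacobian of the loss vector w.r.t. each variable:
# per_sample_grads[k] has shape [batch, *variables[k].shape]
per_sample_grads = tape.jacobian(per_sample_losses, variables)

# Mean over the batch dimension = gradient of the mean batch loss
batch_grads = [tf.reduce_mean(g, axis=0) for g in per_sample_grads]
```

Whether this is actually faster for a real CNN, and how much memory the stacked per-sample gradients need (batch size times the number of parameters), would have to be measured.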