I'm trying to add quantization-aware training (QAT) to an existing TensorFlow + NVIDIA DALI pipeline for a dense-prediction computer vision task. My training loss first decreases as usual and then abruptly jumps to NaN. Using the tf.debugging API in my loss function, I found that the NaNs are in the 'target' tensor coming from the data pipeline. Curiously, I do not get any NaNs from the (identical) data pipeline when QAT is disabled.
def custom_loss(network_output, target_tensor, ...):
    # Raises InvalidArgumentError if target_tensor contains NaN or Inf.
    tf.debugging.check_numerics(target_tensor, "target tensor numeric error", name=None)
    return loss
My suspicion was that I'm simply getting overflows on some occasions where the target tensor has values outside the range that TensorFlow expects. This suspicion is strengthened by the fact that I do not get NaNs if I normalize the target tensors.
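To test the overflow suspicion, the value range and the number of non-finite entries in the target could be logged before the loss is computed. Below is a minimal sketch, assuming a float tensor in eager mode; `describe_target` and the `example` tensor are hypothetical names for illustration, not part of TensorFlow or of my pipeline.

```python
import tensorflow as tf

def describe_target(target_tensor):
    """Hypothetical helper: summarize the finite range of a float tensor."""
    finite_mask = tf.math.is_finite(target_tensor)
    # Count NaN/Inf entries.
    num_non_finite = tf.reduce_sum(
        tf.cast(tf.logical_not(finite_mask), tf.int32))
    # Min/max over the finite entries only, so one NaN does not hide the range.
    finite_values = tf.boolean_mask(target_tensor, finite_mask)
    return (tf.reduce_min(finite_values),
            tf.reduce_max(finite_values),
            num_non_finite)

# Made-up batch with one NaN and one large value.
example = tf.constant([[0.5, 1.0e4], [float("nan"), -3.0]], dtype=tf.float32)
lo, hi, bad = describe_target(example)
tf.print("min:", lo, "max:", hi, "non-finite:", bad)
```

Running this on a batch right after it leaves the data pipeline, and again inside the loss, should show whether the values are already non-finite before the model ever sees them.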
But then again, it seems that the model does not actually do any quantization during QAT. From the functional example in the quantization aware training comprehensive guide:
# For deployment purposes, the tool adds `QuantizeLayer` after `InputLayer` so that the
# quantized model can take in float inputs instead of only uint8.
"Note that the resulting model is quantization aware but not quantized (e.g. the weights are float32 instead of int8)."
I verified with tf.print(target_tensor.dtype) that my loss function is fed a float32 target tensor.
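As far as I understand the note, "quantization aware" means values are rounded onto a low-precision grid and immediately converted back to float (often called fake quantization), so dtypes stay float32 throughout. The NumPy sketch below is a simplified illustration under that assumption; `fake_quantize` is a hypothetical function, and it omits details of TensorFlow's own fake-quant ops such as adjusting the range so that zero is exactly representable.

```python
import numpy as np

def fake_quantize(x, x_min, x_max, num_bits=8):
    """Simplified quantize-dequantize: snap x to a uniform grid, return float32."""
    x = np.asarray(x, dtype=np.float32)
    levels = 2 ** num_bits - 1              # 255 steps for 8 bits
    scale = (x_max - x_min) / levels        # width of one quantization step
    clipped = np.clip(x, x_min, x_max)      # values outside the range saturate
    q = np.round((clipped - x_min) / scale) # integer grid index, stored as float
    return (q * scale + x_min).astype(np.float32)

values = np.linspace(-2.0, 2.0, 10_000)
quantized = fake_quantize(values, x_min=-1.0, x_max=1.0)
print(quantized.dtype)                      # still a float dtype
print(len(np.unique(quantized)))            # at most 256 distinct values
```

If this picture is right, a float32 dtype check passes whether or not fake quantization ran, so the dtype alone does not tell me if values are being clipped or rounded somewhere in the graph.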
Any thoughts?