I'm trying to add quantization aware training (QAT) to an existing TensorFlow + NVIDIA DALI pipeline for a dense prediction computer vision task. My training loss first decreases as usual and then suddenly jumps to NaN. Using the tf.debugging API in my loss function, I figured out that the NaNs come from the 'target' tensor produced by the data pipeline. Curiously, I do not get any NaNs from the (identical) data pipeline when QAT is disabled.
def custom_loss(network_output, target_tensor, ...):
    # Raises an InvalidArgumentError as soon as the target contains NaN or Inf
    tf.debugging.check_numerics(target_tensor, "target tensor numeric error", name=None)
    ...
    return loss
My suspicion was that I'm simply getting overflows on some occasions where the target tensor has values outside the range TensorFlow expects. This suspicion is strengthened by the fact that the NaNs disappear if I normalize the target tensors.
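For reference, the normalization I mean is roughly the following (the clip range and the [0, 1] rescaling are just placeholders, not my real preprocessing):

import tensorflow as tf

def normalize_target(target_tensor, max_abs=1000.0):
    # max_abs is a placeholder for whatever range the labels should physically span.
    clipped = tf.clip_by_value(target_tensor, -max_abs, max_abs)
    t_min = tf.reduce_min(clipped)
    t_max = tf.reduce_max(clipped)
    # Rescale to [0, 1]; the small epsilon guards against a constant target.
    return (clipped - t_min) / (t_max - t_min + 1e-8)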
But then again, it seems that the model does not actually do any quantization during QAT. From the functional example in the quantization aware training comprehensive guide:
# For deployment purposes, the tool adds `QuantizeLayer` after `InputLayer` so that the
# quantized model can take in float inputs instead of only uint8.
And
"Note that the resulting model is quantization aware but not quantized (e.g. the weights are float32 instead of int8)."
I verified that my loss function gets fed a float32 target tensor with tf.print(target_tensor.dtype).
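The next diagnostic I'm planning to add is something like this, to see whether extreme values arrive only from certain batches (check_target is just an illustrative helper, not part of my pipeline yet):

import tensorflow as tf

def check_target(target_tensor):
    # Print the dtype and value range alongside the NaN check, so the
    # overflow can be localized to specific batches.
    tf.print("target dtype:", target_tensor.dtype,
             "min:", tf.reduce_min(target_tensor),
             "max:", tf.reduce_max(target_tensor))
    return tf.debugging.check_numerics(target_tensor, "target tensor numeric error")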
Any thoughts?