
I am pretty new to TensorFlow and I am struggling to get TensorBoard to display some of my custom metrics. The model I am working with is a tf.estimator.Estimator with an associated EstimatorSpec. The first new metric I am trying to log comes from my loss function, which has two components: a loss for an age prediction (tf.float32) and a loss for a class prediction (one-hot/multiclass). My model predicts both a class and an age, so I add the two components together to get the total loss. The total loss is output just fine during training and shows up in TensorBoard, but I would also like to track the age loss and the class loss individually.
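For concreteness, here is a minimal sketch of a two-part loss like the one described above. It is not the asker's actual code: the helper name combined_loss is hypothetical, and it assumes an absolute-error age loss and softmax cross-entropy for the one-hot class labels. Returning the two components alongside the total keeps them available for separate logging.

```python
import tensorflow as tf

def combined_loss(ages_pred, ages_true, class_logits, class_onehot):
    """Return (total_loss, age_loss, class_loss) for a model predicting age and class.

    Hypothetical helper: assumes an absolute-error age loss and softmax
    cross-entropy against one-hot class labels.
    """
    # Mean absolute difference between predicted and true ages (tf.float32).
    age_loss = tf.reduce_mean(tf.abs(ages_pred - ages_true))
    # Mean softmax cross-entropy between class logits and one-hot labels.
    class_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=class_onehot, logits=class_logits))
    # Total loss is the unweighted sum of the two components.
    return age_loss + class_loss, age_loss, class_loss
```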

I think one solution that is supposed to work is to add an eval_metric_ops argument to the EstimatorSpec, as described here (Custom eval_metric_ops in Estimator in Tensorflow). However, I have not been able to make this approach work. I defined a custom metric function that looks like this:

def age_loss_function(labels, ages_pred, ages_true):
  per_sample_age_loss = get_age_loss_per_sample(ages_pred, ages_true) ### works fine

  #### The error happens on this line:
  mean_abs_age_diff, age_loss_update_fn = tf.metrics.Mean(per_sample_age_loss)
  ######

  return mean_abs_age_diff, age_loss_update_fn

eval_metric_ops = {"age_loss": age_loss_function}  #### Want to use this in EstimatorSpec

The instructions seem to say that I need both the metric value and its update op, which should both be returned by the tf.metrics call, as in examples like the one I linked. However, this call fails for me with the following error:

tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

I am probably just misusing the APIs. If someone can guide me on the proper usage I would really appreciate it. Thanks!


1 Answer


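It looks like the problem came from a version change. I had updated to TensorFlow 2.0, while the instructions I was following were written for 1.x. In TensorFlow 2.0, tf.metrics refers to the Keras metrics (tf.keras.metrics), so tf.metrics.Mean is a class whose first constructor argument is a name rather than the values to average, and it does not return a (value, update_op) pair. Using tf.compat.v1.metrics.mean() instead gets past this problem.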
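A minimal sketch of that fix, assuming the per-sample age loss is an absolute difference (the question's get_age_loss_per_sample is not shown). The names age_loss_metric and run_metric_once are hypothetical, and the unused labels argument from the question is dropped. tf.compat.v1.metrics.mean only works in graph mode, which is how an Estimator builds its model_fn; run_metric_once builds a small graph by hand purely to show the update-then-read behavior outside an Estimator.

```python
import tensorflow as tf

def age_loss_metric(ages_pred, ages_true):
    """Return a (value, update_op) pair for the mean absolute age error.

    Assumes the per-sample loss is |pred - true|; substitute your own
    get_age_loss_per_sample if it differs.
    """
    per_sample_age_loss = tf.abs(ages_pred - ages_true)
    # TF1-style streaming metric: returns (mean_tensor, update_op).
    # Raises an error under eager execution, so it must run in graph mode.
    return tf.compat.v1.metrics.mean(per_sample_age_loss)

def run_metric_once(pred_values, true_values):
    """Evaluate the metric on one batch in a fresh graph (mimics one eval step)."""
    graph = tf.Graph()
    with graph.as_default():
        value, update_op = age_loss_metric(
            tf.constant(pred_values, dtype=tf.float32),
            tf.constant(true_values, dtype=tf.float32))
        with tf.compat.v1.Session(graph=graph) as sess:
            # The metric's running total and count are local variables.
            sess.run(tf.compat.v1.local_variables_initializer())
            sess.run(update_op)            # accumulate this batch
            return float(sess.run(value))  # read the running mean
```

Inside model_fn, call the helper and pass its result to the EstimatorSpec, for example eval_metric_ops={"age_loss": age_loss_metric(ages_pred, ages_true)}. The dictionary value is the (value, update_op) tuple returned by the call, not the function object itself as in the question's code.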