I'm following this tutorial to learn how to train a model on the MNIST dataset: https://www.tensorflow.org/tutorials/quickstart/beginner
Currently, the model only reports accuracy as a metric, but I want to compute the F1-score of the model (starting with precision and recall).
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)
Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2895 - accuracy: 0.9151
Epoch 2/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.1393 - accuracy: 0.9586
...
Apparently the model outputs log-odds scores (logits), which the tutorial converts into probabilities with a softmax.
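To make the logits-to-probabilities step concrete, here is a minimal NumPy sketch of a softmax. The `softmax` helper and the example values are made up for illustration and are not taken from the tutorial or from real model output.

```python
import numpy as np

def softmax(logits):
    """Convert each row of logits into probabilities that sum to 1."""
    # Subtracting the row-wise max improves numerical stability
    # and does not change the resulting probabilities.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

# Made-up logits for 2 samples and 3 classes (not real model output).
example_logits = np.array([[2.0, 1.0, 0.1],
                           [-1.0, 3.0, 0.5]])
example_probs = softmax(example_logits)
print(example_probs)
```

Softmax keeps the ordering of the scores within a row, so the most likely class is the same whether it is read from the logits or from the probabilities.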
This is my problem, though. I tried changing the metrics in model.compile to metrics=[tf.keras.metrics.Precision()], but I got the error ValueError: Shapes (32, 10) and (32, 1) are incompatible.
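The shapes in that error look like a batch of 32 outputs with 10 scores each being compared against 32 integer labels. A small NumPy sketch of that mismatch follows, under the assumption that the metric wants labels and predictions of the same shape. The random arrays are stand-ins, not data from the model.

```python
import numpy as np

num_classes = 10
batch_size = 32

# Hypothetical batch: integer labels shaped (32, 1), model outputs shaped (32, 10).
rng = np.random.default_rng(0)
sparse_labels = rng.integers(0, num_classes, size=(batch_size, 1))
model_outputs = rng.normal(size=(batch_size, num_classes))

# One-hot encoding turns each integer label into a row of length 10,
# so the labels end up with the same shape as the outputs.
one_hot_labels = np.eye(num_classes)[sparse_labels[:, 0]]

print(sparse_labels.shape, model_outputs.shape, one_hot_labels.shape)
```

Matching the shapes may not be enough on its own: Precision() uses a default threshold of 0.5, which is aimed at probabilities rather than raw logits.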
I also tried calculating the precision and recall with scikit-learn, but my predictions aren't in the same format as the true labels.
y_pred = model.predict(x_test)
print(y_pred)
precision_score(y_test, y_pred)
Output:
[[ -4.7507367 -7.4252934 -2.8428416 ... 8.855136 -5.937388
-2.1762638 ]
[ -5.0433793 5.554433 12.963128 ... -18.583 -1.6025407
-18.721622 ]
[ -7.623428 6.3951 -1.8510209 ... 0.37932196 -1.2399373
-6.59459 ]
...
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-44-a82c4d76f544> in <module>()
1 y_pred = model.predict(x_test)
2 print(y_pred)
----> 3 precision_score(y_test, y_pred)
ValueError: Classification metrics can't handle a mix of multiclass and continuous-multioutput targets
I think I might need to transform y_pred, but I'm not sure how. If there is a way to add precision and recall to the metrics, that would be even better. How can I get the precision and recall of this model?
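One possible transformation, sketched here with NumPy only, is to take the argmax of each row of scores to get a predicted class label and then compute per-class precision and recall from those labels. The helper `per_class_scores` and the toy arrays are hypothetical; with the real model, `toy_logits` would be replaced by `model.predict(x_test)` and `toy_true` by `y_test`.

```python
import numpy as np

def per_class_scores(y_true, y_pred_labels, num_classes):
    """Return per-class precision, recall and F1 as three arrays."""
    precision = np.zeros(num_classes)
    recall = np.zeros(num_classes)
    f1 = np.zeros(num_classes)
    for c in range(num_classes):
        true_pos = np.sum((y_pred_labels == c) & (y_true == c))
        predicted_c = np.sum(y_pred_labels == c)  # TP + FP
        actual_c = np.sum(y_true == c)            # TP + FN
        # Fall back to 0 when a class is never predicted or never present.
        precision[c] = true_pos / predicted_c if predicted_c > 0 else 0.0
        recall[c] = true_pos / actual_c if actual_c > 0 else 0.0
        denom = precision[c] + recall[c]
        f1[c] = 2 * precision[c] * recall[c] / denom if denom > 0 else 0.0
    return precision, recall, f1

# Toy logits for 4 samples and 3 classes (made up, not real model output).
toy_logits = np.array([[ 2.0, -1.0,  0.3],
                       [ 0.1,  1.5, -0.7],
                       [-0.2,  0.4,  3.0],
                       [ 1.2,  0.9, -2.0]])
toy_true = np.array([0, 1, 2, 1])

# The index of the largest score in each row is the predicted class.
toy_pred = np.argmax(toy_logits, axis=1)

precision, recall, f1 = per_class_scores(toy_true, toy_pred, num_classes=3)
macro_precision = precision.mean()
macro_recall = recall.mean()
macro_f1 = f1.mean()
print(toy_pred, macro_precision, macro_recall, macro_f1)
```

With scikit-learn, passing the argmax labels instead of the raw scores, for example precision_score(y_test, np.argmax(y_pred, axis=1), average='macro'), should give a comparable macro-averaged figure. The average argument matters because the default, 'binary', only applies to two-class problems.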