
I trained an image classifier with Keras that reaches around 98% test accuracy. I know the overall accuracy is 98%, but I want to know the accuracy/error for each distinct class/label.

Does Keras have a built-in function for that, or would I have to test each class/label myself?

Update: Thanks @gionni. I didn't know the actual term was "confusion matrix", but that's exactly what I'm looking for. That said, is there a function to generate one? I have to use Keras 1.2.2, by the way.

Do you want something like a confusion matrix for each training step, or just the overall confusion matrix for the test set? I think sklearn implements functions for confusion matrices. - gionni
Overall confusion matrix. I want to know which label is confused with which other label, and how often. - Era
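A minimal sketch of the overall test-set confusion matrix, assuming one-hot labels (as used with `categorical_crossentropy`) and the class probabilities returned by `model.predict`. It uses plain NumPy rather than sklearn; `sklearn.metrics.confusion_matrix(y_true, y_pred)` builds the same layout from integer class labels (rows are true classes, columns are predicted classes). The function name `confusion_matrix_from_predictions` is hypothetical.

```python
import numpy as np


def confusion_matrix_from_predictions(y_true_onehot, y_prob):
    """Count (true class, predicted class) pairs.

    y_true_onehot: one-hot labels, shape (n_samples, n_classes)
    y_prob: class probabilities (e.g. model.predict output), same shape
    Returns an int array cm where cm[i, j] is the number of samples of
    true class i that were predicted as class j.
    """
    n_classes = y_true_onehot.shape[-1]
    true_ids = np.argmax(y_true_onehot, axis=-1)
    pred_ids = np.argmax(y_prob, axis=-1)
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # np.add.at accumulates repeated (true, pred) index pairs correctly,
    # unlike cm[true_ids, pred_ids] += 1, which counts each pair only once
    np.add.at(cm, (true_ids, pred_ids), 1)
    return cm


if __name__ == "__main__":
    # toy data standing in for one-hot test labels and model.predict output
    y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]])
    y_prob = np.array([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1],
                       [0.6, 0.3, 0.1], [0.1, 0.1, 0.8]])
    print(confusion_matrix_from_predictions(y_true, y_prob))
```

With a trained model this would be called as something like `confusion_matrix_from_predictions(y_test, model.predict(x_test))`, where `x_test` and `y_test` are hypothetical names for the test data. The off-diagonal entries `cm[i, j]` show how often class `i` is mistaken for class `j`.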

2 Answers


I had a similar issue, so I can share my code with you. The following function computes the accuracy for a single class:

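from keras import backend as K

def single_class_accuracy(interesting_class_id):
    def fn(y_true, y_pred):
        # y_true is one-hot encoded, so convert both tensors to class indices
        class_id_true = K.argmax(y_true, axis=-1)
        class_id_preds = K.argmax(y_pred, axis=-1)
        # 1 where the model predicts the class of interest, 0 elsewhere
        positive_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int32')
        # 1 where the true label is the class of interest, 0 elsewhere
        true_mask = K.cast(K.equal(class_id_true, interesting_class_id), 'int32')
        # a sample counts as correct when both masks agree (one-vs-rest accuracy)
        acc_mask = K.cast(K.equal(positive_mask, true_mask), 'float32')
        class_acc = K.mean(acc_mask)
        return class_acc

    # give each metric a distinct name so Keras reports them separately
    fn.__name__ = 'single_class_accuracy_{}'.format(interesting_class_id)
    return fn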

Now, if you want the accuracy for class 0, add it to the metrics when compiling the model:

model.compile(..., metrics=[..., single_class_accuracy(0)])

If you want the accuracy of every class, you could write:

model.compile(..., 
    metrics=[...] + [single_class_accuracy(i) for i in range(nb_of_classes)])
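To see what the metric above measures, here is a NumPy mirror of it (assuming `y_true` is converted to class indices with argmax), next to per-class recall: the fraction of samples of a class that are predicted as that class, which is closer to "accuracy per class". Both function names are hypothetical. Because the metric above also counts correctly rejected samples of other classes, a class that makes up 1 of 10 samples and is never predicted still scores 0.9, while its recall is 0.0. Note too that Keras computes compiled metrics batch by batch and averages the per-batch values, so the figure shown during training approximates the full-set value.

```python
import numpy as np


def one_vs_rest_accuracy(y_true_onehot, y_prob, class_id):
    """NumPy mirror of single_class_accuracy (argmax applied to y_true)."""
    true_ids = np.argmax(y_true_onehot, axis=-1)
    pred_ids = np.argmax(y_prob, axis=-1)
    # correct when "is class_id" agrees between truth and prediction,
    # so true negatives (other classes not predicted as class_id) count too
    return float(np.mean((pred_ids == class_id) == (true_ids == class_id)))


def class_recall(y_true_onehot, y_prob, class_id):
    """Fraction of samples whose true label is class_id predicted as class_id."""
    true_ids = np.argmax(y_true_onehot, axis=-1)
    pred_ids = np.argmax(y_prob, axis=-1)
    mask = true_ids == class_id
    if not mask.any():
        return float("nan")  # class absent from the labels
    return float(np.mean(pred_ids[mask] == class_id))


if __name__ == "__main__":
    # 9 samples of class 0, 1 sample of class 1; the model always predicts 0
    y_true = np.array([[1, 0]] * 9 + [[0, 1]])
    y_prob = np.tile([0.9, 0.1], (10, 1))
    print(one_vs_rest_accuracy(y_true, y_prob, 1), class_recall(y_true, y_prob, 1))
```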

There may be better options, but you can evaluate the model separately on the test samples of each class:

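import numpy as np

# find each distinct (one-hot encoded) label and how often it occurs
# (the axis argument of np.unique needs NumPy 1.13 or later)
distinct, counts = np.unique(trueLabels, axis=0, return_counts=True)

for dist, count in zip(distinct, counts):
    # boolean mask selecting the test samples whose label equals this class
    selector = (trueLabels == dist).all(axis=-1)

    selectedX = testData[selector]
    selectedY = trueLabels[selector]

    # evaluate returns the loss, followed by any metrics passed to compile()
    print('\n\nEvaluating for ' + str(count) + ' occurrences of class ' + str(dist))
    print(model.evaluate(selectedX, selectedY, verbose=0))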
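The loop above calls `model.evaluate` once per class. A sketch of an alternative, assuming the same one-hot `trueLabels` and `testData`, predicts once and groups the results with NumPy; the `per_class_accuracy` helper is hypothetical. It reports only accuracy (no loss), and its numbers should match the per-subset accuracy from `evaluate` when the model was compiled with `metrics=['accuracy']` and a categorical cross-entropy loss, since that metric also compares argmaxes.

```python
import numpy as np


def per_class_accuracy(y_true_onehot, y_prob):
    """Map each class present in y_true to (sample count, accuracy on that class).

    Same grouping as the per-class evaluate loop, but uses a single
    prediction pass (e.g. y_prob = model.predict(testData)).
    """
    true_ids = np.argmax(y_true_onehot, axis=-1)
    pred_ids = np.argmax(y_prob, axis=-1)
    result = {}
    for class_id in np.unique(true_ids):
        mask = true_ids == class_id  # test samples whose true label is class_id
        correct = pred_ids[mask] == class_id
        result[int(class_id)] = (int(mask.sum()), float(correct.mean()))
    return result


if __name__ == "__main__":
    # toy data standing in for trueLabels and model.predict(testData)
    y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]])
    y_prob = np.array([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1],
                       [0.6, 0.3, 0.1], [0.1, 0.1, 0.8]])
    for class_id, (n, acc) in per_class_accuracy(y_true, y_prob).items():
        print('class', class_id, 'samples', n, 'accuracy', acc)
```

With the names from the answer, this would be `per_class_accuracy(trueLabels, model.predict(testData))`.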