7
votes

I'm using cross-validation to evaluate the performance of a classifier with scikit-learn and I want to plot the Precision-Recall curve. I found an example on scikit-learn's website that plots the PR curve, but it doesn't use cross-validation for the evaluation.

How can I plot the Precision-Recall curve in scikit-learn when using cross-validation?

I did the following, but I'm not sure if it's the correct way to do it (pseudocode):

for each k-fold:

   precision, recall, _ =  precision_recall_curve(y_test, probs)
   mean_precision += precision
   mean_recall += recall

mean_precision /= num_folds
mean_recall /= num_folds

plt.plot(recall, precision)

What do you think?

Edit:

It doesn't work because the precision and recall arrays have a different size after each fold.

anyone?


2 Answers

8
votes

Instead of recording the precision and recall values after each fold, store the predictions on the test samples after each fold. Next, collect all the test (i.e. out-of-fold) predictions and compute precision and recall on the combined set.

 ## let train_samples[k] = training samples for the kth fold (list of lists)
 ## let test_samples[k] = test samples for the kth fold (list of lists)
 ## train() and predict() stand for your model's fit and predict routines

 predictions_fold = [None] * num_folds
 for k in range(num_folds):
      model = train(parameters, train_samples[k])
      predictions_fold[k] = predict(model, test_samples[k])

 # collect predictions
 predictions_combined = [p for preds in predictions_fold for p in preds]

 ## let predictions = rearranged predictions s.t. they are in the original order

 ## use predictions and labels to compute lists of TP, FP, FN
 ## use TP, FP, FN to compute precisions and recalls for one run of k-fold cross-validation

Under a single, complete run of k-fold cross-validation, the predictor makes one and only one prediction for each sample. Given n samples, you should have n test predictions.
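With scikit-learn, the pooling described above can be sketched with `cross_val_predict`, which returns one out-of-fold prediction per sample in the original order. This is only a minimal sketch: the synthetic dataset, the logistic regression model, and the 5-fold split are assumptions chosen for illustration, not part of the original question.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Illustrative data and model (assumptions, swap in your own X, y, clf).
X, y = make_classification(n_samples=300, weights=[0.8, 0.2], random_state=0)
clf = LogisticRegression(max_iter=1000)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# One out-of-fold probability per sample, already in the original order.
probs = cross_val_predict(clf, X, y, cv=cv, method="predict_proba")[:, 1]

# A single PR curve computed from all n pooled test predictions.
precision, recall, thresholds = precision_recall_curve(y, probs)
```

The curve can then be drawn with `plt.plot(recall, precision)` after importing `matplotlib.pyplot as plt`.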

(Note: These predictions are different from training predictions, because the predictor makes the prediction for each sample without having previously seen it.)

Unless you are using leave-one-out cross-validation, k-fold cross-validation generally requires a random partitioning of the data. Ideally, you would do repeated (and stratified) k-fold cross-validation. Combining precision-recall curves from different rounds, however, is not straightforward, since, unlike with ROC curves, you cannot use simple linear interpolation between precision-recall points (see Davis and Goadrich 2006).
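To see why linear interpolation misleads in PR space, here is a small sketch of interpolation done in terms of counts: between two points, TP increases one step at a time and FP grows at the local FP-per-TP rate, and precision and recall are recomputed at each step. The helper name `interpolate_pr` and the example counts are hypothetical, and the function assumes `tp_b > tp_a`.

```python
def interpolate_pr(tp_a, fp_a, tp_b, fp_b, n_pos):
    """Return (recall, precision) points between two PR points A and B.

    Interpolates in count space: for each extra true positive, add the
    average number of false positives gained per true positive from A to B.
    Assumes tp_b > tp_a. (Hypothetical helper, for illustration only.)
    """
    slope = (fp_b - fp_a) / (tp_b - tp_a)  # false positives per true positive
    points = []
    for x in range(tp_b - tp_a + 1):
        tp = tp_a + x
        fp = fp_a + slope * x
        points.append((tp / n_pos, tp / (tp + fp)))
    return points

# Example counts: 20 positives in total, A = (TP=5, FP=5), B = (TP=10, FP=30).
points = interpolate_pr(5, 5, 10, 30, 20)
recall_1, precision_1 = points[1]  # first intermediate point
```

At recall 0.3 this gives a precision of 0.375, whereas a straight line between A (recall 0.25, precision 0.5) and B (recall 0.5, precision 0.25) would suggest 0.45, an overly optimistic value.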

I personally calculated AUC-PR using the Davis-Goadrich method for interpolation in PR space (followed by numerical integration) and compared the classifiers using the AUC-PR estimates from repeated stratified 10-fold cross validation.
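A rough sketch of the repeated stratified 10-fold procedure follows. Note the substitution: it uses scikit-learn's `average_precision_score` as a stand-in summary of the PR curve, not the Davis-Goadrich interpolation with numerical integration described above. The dataset, model, and number of repeats are assumptions for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Illustrative imbalanced data and model (assumptions).
X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)
clf = LogisticRegression(max_iter=1000)

scores = []
for seed in range(5):  # 5 repeats, each with a different random partition
    cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
    probs = cross_val_predict(clf, X, y, cv=cv, method="predict_proba")[:, 1]
    # One PR summary per complete run of 10-fold cross-validation.
    scores.append(average_precision_score(y, probs))

mean_score, std_score = float(np.mean(scores)), float(np.std(scores))
```

Comparing classifiers then comes down to comparing the distributions of `scores` obtained for each model under the same partitions.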

For a nice plot, I showed a representative PR curve from one of the cross-validation rounds.

There are, of course, many other ways of assessing classifier performance, depending on the nature of your dataset.

For instance, if the proportion of (binary) labels in your dataset is not skewed (i.e. it is roughly 50-50), you could use the simpler ROC analysis with cross-validation:

Collect predictions from each fold and construct ROC curves (as before), collect all the TPR-FPR points (i.e. take the union of all TPR-FPR tuples), then plot the combined set of points with possible smoothing. Optionally, compute AUC-ROC using simple linear interpolation and the composite trapezoid method for numerical integration.
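One possible reading of this ROC procedure in scikit-learn is sketched below: build a ROC curve per fold, take the union of all (FPR, TPR) points for plotting, and compute each fold's AUC with `sklearn.metrics.auc`, which applies the trapezoidal rule. The data, model, and fold count are assumptions, and averaging the per-fold AUCs is one choice among several.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import StratifiedKFold

# Illustrative, roughly balanced data (assumption).
X, y = make_classification(n_samples=400, random_state=0)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

fpr_parts, tpr_parts, fold_aucs = [], [], []
for train_idx, test_idx in cv.split(X, y):
    clf = LogisticRegression(max_iter=1000).fit(X[train_idx], y[train_idx])
    scores = clf.predict_proba(X[test_idx])[:, 1]
    fpr, tpr, _ = roc_curve(y[test_idx], scores)
    fpr_parts.append(fpr)
    tpr_parts.append(tpr)
    fold_aucs.append(auc(fpr, tpr))  # trapezoidal rule on this fold's curve

# Union of all TPR-FPR points, sorted by FPR (then TPR) for plotting.
all_fpr = np.concatenate(fpr_parts)
all_tpr = np.concatenate(tpr_parts)
order = np.lexsort((all_tpr, all_fpr))
all_fpr, all_tpr = all_fpr[order], all_tpr[order]

mean_auc = float(np.mean(fold_aucs))
```

The combined points can be shown with `plt.scatter(all_fpr, all_tpr)`, optionally followed by a smoothing step of your choice.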

2
votes

One convenient way to plot a Precision-Recall curve for an sklearn classifier using cross-validation is the scikit-plot library. It plots the PR curves for all classes, so you get multiple neat-looking curves as well.

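from sklearn.linear_model import LogisticRegression
from scikitplot.classifiers import plot_precision_recall_curve
import matplotlib.pyplot as plt

# X, y: feature matrix and class labels, assumed to be already loaded
clf = LogisticRegression()
plot_precision_recall_curve(clf, X, y)
plt.show()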

The function automatically takes care of cross-validating the given dataset, concatenating all out-of-fold predictions, and calculating the PR curve for each class plus an averaged PR curve. It's a one-line function that takes care of it all for you.
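For readers who want to see those steps spelled out, here is a rough equivalent using only scikit-learn. It is a sketch of the same idea (pooled out-of-fold probabilities, one-vs-rest curves per class, and a micro-averaged curve), not scikit-plot's actual implementation; the iris dataset and the 5-fold split are assumptions.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.preprocessing import label_binarize

# Illustrative three-class data (assumption).
X, y = load_iris(return_X_y=True)
classes = [0, 1, 2]
clf = LogisticRegression(max_iter=1000)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# Out-of-fold class probabilities, one row per sample, columns in class order.
probs = cross_val_predict(clf, X, y, cv=cv, method="predict_proba")
y_bin = label_binarize(y, classes=classes)  # one-vs-rest indicator matrix

curves = {}
for i, c in enumerate(classes):
    precision, recall, _ = precision_recall_curve(y_bin[:, i], probs[:, i])
    curves[c] = (recall, precision)

# Micro-average: treat every (sample, class) pair as one binary decision.
precision, recall, _ = precision_recall_curve(y_bin.ravel(), probs.ravel())
curves["micro"] = (recall, precision)
```

Each entry of `curves` can be drawn with `plt.plot(recall, precision, label=str(name))` to get one line per class plus the averaged curve.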

[Image: Precision Recall Curves]

Disclaimer: Note that this uses the scikit-plot library, which I built.