3 votes

Is it possible to get a ROC curve for the training set and the test set separately, for each fold, in 5-fold cross-validation in caret?

library(caret)
train_control <- trainControl(method = "cv", number = 5, savePredictions = TRUE, classProbs = TRUE)
rfmodel <- train(Species ~ ., data = iris, trControl = train_control, method = "rf")

I can do the following, but I do not know if it returns the ROC for the training set of Fold1 or for its test set:

library(pROC) 
selectedIndices <- rfmodel$pred$Resample == "Fold1"
plot.roc(rfmodel$pred$obs[selectedIndices], rfmodel$pred$setosa[selectedIndices])

1 Answer

3 votes

It is true that the documentation is not at all clear regarding the contents of rfmodel$pred; I would bet that the predictions included are for the fold used as the test set, but I cannot point to any evidence in the docs. Nevertheless, and regardless of this, you are still missing some points in the way you are trying to get the ROC.

First, let's isolate rfmodel$pred in a separate dataframe for easier handling:

dd <- rfmodel$pred

nrow(dd)
# 450

Why 450 rows? It is because the model was fitted with 3 different parameter settings (in your case, just 3 different values of mtry):

rfmodel$results
# output:
  mtry Accuracy Kappa AccuracySD    KappaSD
1    2     0.96  0.94 0.04346135 0.06519202
2    3     0.96  0.94 0.04346135 0.06519202
3    4     0.96  0.94 0.04346135 0.06519202

and 150 rows × 3 settings = 450.
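You can verify this breakdown directly; a quick cross-tabulation of the predictions dataframe (just a sanity check) should show the held-out rows per fold for each mtry value:

table(dd$mtry, dd$Resample)
# each cell should be 30 (150 samples / 5 folds), for each of the 3 mtry values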

Let's have a closer look at the contents of rfmodel$pred:

head(dd)

# result:
    pred    obs setosa versicolor virginica rowIndex mtry Resample
1 setosa setosa  1.000      0.000         0        2    2    Fold1
2 setosa setosa  1.000      0.000         0        3    2    Fold1
3 setosa setosa  1.000      0.000         0        6    2    Fold1
4 setosa setosa  0.998      0.002         0       24    2    Fold1
5 setosa setosa  1.000      0.000         0       33    2    Fold1
6 setosa setosa  1.000      0.000         0       38    2    Fold1
  • Column obs contains the true values
  • The three columns setosa, versicolor, and virginica contain the respective probabilities calculated for each class, and they sum up to 1 for each row
  • Column pred contains the final prediction, i.e. the class with the maximum probability from the three columns mentioned above
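If you want to double-check the second point, a one-liner (just a sanity check on dd) confirms that the class probabilities behave as described:

# TRUE if every row's three class probabilities sum to 1 (up to floating-point error)
all(abs(rowSums(dd[, c("setosa", "versicolor", "virginica")]) - 1) < 1e-8)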

If this were the whole story, your way of plotting the ROC would be OK, i.e.:

selectedIndices <- rfmodel$pred$Resample == "Fold1"
plot.roc(rfmodel$pred$obs[selectedIndices], rfmodel$pred$setosa[selectedIndices])

But this is not the whole story (the mere existence of 450 rows instead of just 150 should have given a hint already): notice the existence of a column named mtry; indeed, rfmodel$pred includes the results for all runs of cross-validation (i.e. for all the parameter settings):

tail(dd)
# result:
         pred       obs setosa versicolor virginica rowIndex mtry Resample
445 virginica virginica      0      0.004     0.996      112    4    Fold5
446 virginica virginica      0      0.000     1.000      113    4    Fold5
447 virginica virginica      0      0.020     0.980      115    4    Fold5
448 virginica virginica      0      0.000     1.000      118    4    Fold5
449 virginica virginica      0      0.394     0.606      135    4    Fold5
450 virginica virginica      0      0.000     1.000      140    4    Fold5

This is the ultimate reason why your selectedIndices calculation is not correct; it should also include a specific choice of mtry, otherwise the ROC does not make any sense, since it "aggregates" more than one model:

selectedIndices <- rfmodel$pred$Resample == "Fold1" & rfmodel$pred$mtry == 2
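With mtry now fixed, you can draw one (test-set) ROC per fold; here is a minimal sketch reusing your own plot.roc call and overlaying the 5 folds via pROC's add argument (pROC will message that it uses only the first two factor levels, exactly as in your original call):

for (fold in paste0("Fold", 1:5)) {
  idx <- rfmodel$pred$Resample == fold & rfmodel$pred$mtry == 2
  plot.roc(rfmodel$pred$obs[idx], rfmodel$pred$setosa[idx],
           add = (fold != "Fold1"))  # first fold opens the plot, the rest overlay
}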

--

As I said in the beginning, I bet that the predictions in rfmodel$pred are for the fold used as a test set; indeed, if we compute the accuracies manually, they coincide with the ones reported in rfmodel$results shown above (0.96 for all 3 settings), which we know are for the fold used as the test set (the respective training accuracies are arguably close to 1.0):

for (i in 2:4) {  # mtry values in {2, 3, 4}
  # fraction of correct test-set predictions in each fold (30 samples per fold),
  # averaged over the 5 folds:
  acc <- (length(which(dd$pred == dd$obs & dd$mtry == i & dd$Resample == 'Fold1'))/30 +
          length(which(dd$pred == dd$obs & dd$mtry == i & dd$Resample == 'Fold2'))/30 +
          length(which(dd$pred == dd$obs & dd$mtry == i & dd$Resample == 'Fold3'))/30 +
          length(which(dd$pred == dd$obs & dd$mtry == i & dd$Resample == 'Fold4'))/30 +
          length(which(dd$pred == dd$obs & dd$mtry == i & dd$Resample == 'Fold5'))/30) / 5
  print(acc)
}

# result:
[1] 0.96
[1] 0.96
[1] 0.96
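For reference, the same computation collapses into one line; this pooled per-mtry accuracy coincides with the fold-averaged value above because every fold here contains exactly 30 rows:

sapply(2:4, function(m) mean(dd$pred[dd$mtry == m] == dd$obs[dd$mtry == m]))
# should reproduce: 0.96 0.96 0.96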