
I'm trying to figure out how to get confidence intervals from a caret::train linear model.

My first try was just to run predict with the usual lm confidence-interval arguments:

m <- caret::train(mpg ~ poly(hp,2), data=mtcars, method="lm")
predict(m, newdata=mtcars, interval="confidence", level=0.95)

But it looks like the object returned from caret::train doesn't have this implemented.

My second attempt was to extract the finalModel and predict on that:

m <- caret::train(mpg ~ poly(hp,2), data=mtcars, method="lm")
fm <- m$finalModel
predict(fm, newdata=mtcars, interval="confidence", level=0.95)

But I get the error

Error in eval(predvars, data, env) : object 'poly(hp, 2)1' not found

Digging deeper, it seems that the final model has an unusual representation of the formula and is searching for a 'poly(hp, 2)1' column in my newdata rather than evaluating the formula. The m$finalModel looks like this:

Call:
lm(formula = .outcome ~ ., data = dat)

Coefficients:
   (Intercept)  `poly(hp, 2)1`  `poly(hp, 2)2`  
         20.09          -26.05           13.15

I should add that I'm not simply using lm directly because I'm using caret to fit the model with cross-validation.

How can I get the confidence intervals from the linear model fit through caret::train?

What output do you get from formula(m$finalModel)? The output presented is not equivalent to the formula in your first box. - Oliver
@Oliver Sorry, I had the formula wrong in the question; the output is .outcome ~ `poly(hp, 2)1` + `poly(hp, 2)2` <environment: 0x0000000091b1a6e8> - Fish11
It seems you updated, at which point it is more in line. Which versions of R and caret are you using? Executing predict(fm, newdata=mtcars, interval="confidence", level=0.95) correctly gives me the confidence intervals. - Oliver
R version 3.6.0 (2019-04-26) and caret_6.0-84 - Fish11
It shouldn't be the problem, but try updating R to 3.6.1. Hopefully another kind soul can find a better answer. - Oliver

1 Answer


Disclaimer:

This is a horrible answer, or maybe the caret package just has a horrible implementation for this specific case. Either way, it seems worth opening an issue or feature request on caret's GitHub if one does not already exist (asking either for more flexible predict functions or for a fix to the naming used in object$finalModel).

The problem (which occurred in the second attempt) stems from how the caret package internally handles its many different fitting procedures, essentially restricting the predict function, apparently for cleaning and standardization purposes.

Problem:

The problem is two-fold.

  1. predict.train does not pass through the arguments needed for prediction/confidence intervals.
  2. The finalModel contained in the output of train(...) has an unusually formatted formula.

Both problems seem to stem from how train formats the model and how predict.train then uses it. Starting with the second problem, this is apparent from the output of

formula(m$finalModel)
# .outcome ~ `poly(hp, 2)1` + `poly(hp, 2)2`

Obviously some formatting is performed while running train: the expected output would be mpg ~ poly(hp, 2), but the RHS has been expanded into the individual model-matrix columns (wrapped in backticks) and the LHS has been renamed to .outcome. It would therefore be nice either to fix up the formula or to be able to use it as it is.
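The expanded names can be reproduced without caret. The sketch below assumes, based on the `lm(formula = .outcome ~ ., data = dat)` call printed in the question, that train expands the formula into a model matrix, stores the response as `.outcome` and fits lm on the result. It is an approximation of that step, not caret's actual code, and the names `X`, `dat`, `fm` and `res` are only illustrative.

```r
# Approximation of caret's formula handling for method = "lm" (assumption):
# expand the RHS into a design matrix, drop the intercept, refit .outcome ~ .
X <- model.matrix(mpg ~ poly(hp, 2), data = mtcars)
X <- X[, colnames(X) != "(Intercept)", drop = FALSE]
colnames(X)
# [1] "poly(hp, 2)1" "poly(hp, 2)2"

dat <- as.data.frame(X)          # keeps the non-syntactic column names
dat$.outcome <- mtcars$mpg
fm <- lm(.outcome ~ ., data = dat)
formula(fm)                      # RHS now refers to the expanded columns

# Predicting from the raw data fails: mtcars has no `poly(hp, 2)1` column
res <- tryCatch(predict(fm, newdata = mtcars, interval = "confidence"),
                error = function(e) conditionMessage(e))
res
```

The fitted coefficients are the same as for lm(mpg ~ poly(hp, 2)); only the variable names the model expects in newdata have changed.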

Looking into how the caret package uses this in the predict.train function reveals the following code for handling newdata:

caret:::predict.train
# output (excerpt)
# ... more code ...
if (!is.null(newdata)) {
    if (inherits(object, "train.formula")) {
        newdata <- as.data.frame(newdata)
        rn <- row.names(newdata)
        Terms <- delete.response(object$terms)
        m <- model.frame(Terms, newdata, na.action = na.action,
                         xlev = object$xlevels)
        if (!is.null(cl <- attr(Terms, "dataClasses")))
            .checkMFClasses(cl, m)
        keep <- match(row.names(m), rn)
        newdata <- model.matrix(Terms, m, contrasts = object$contrasts)
        xint <- match("(Intercept)", colnames(newdata), nomatch = 0)
        if (xint > 0)
            newdata <- newdata[, -xint, drop = FALSE]
    }
}
# ... more code ...
out <- predictionFunction(method = object$modelInfo,
                          modelFit = object$finalModel, newdata = newdata,
                          preProc = object$preProcess)

For less experienced R users: a model.matrix is constructed from scratch from object$terms, without using the output of formula(m$finalModel) (this is the part we can reuse!), and the prediction itself is then delegated to a function that works on m$finalModel. Looking into predictionFunction from the same package shows that, for our example, it simply calls m$modelInfo$predict(m$finalModel, newdata).
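Rebuilding from the stored terms matters for poly(): terms taken from a model frame carry a `predvars` attribute in which poly() is stored together with the coefficients of the training basis, so new rows are evaluated on the same polynomial instead of a freshly orthogonalised one. A small base-R illustration, with no caret involved; the three rows are picked only so that poly() sees enough distinct hp values.

```r
mf <- model.frame(mpg ~ poly(hp, 2), data = mtcars)
Terms <- delete.response(terms(mf))
attr(Terms, "predvars")          # poly(hp, 2, coefs = list(...))

newdata <- mtcars[c(1, 3, 5), ]  # three cars with distinct hp values

# Terms-based design matrix: poly() reuses the stored training coefficients
X_terms <- model.matrix(Terms, model.frame(Terms, newdata))

# Naive design matrix: poly() refits an orthogonal basis on these three rows only
X_naive <- model.matrix(~ poly(hp, 2), data = newdata)

# Rows of the training design matrix for the same cars
X_train <- model.matrix(mpg ~ poly(hp, 2), data = mtcars)[c(1, 3, 5), ]

all.equal(X_terms, X_train, check.attributes = FALSE)           # should be TRUE
isTRUE(all.equal(X_naive, X_train, check.attributes = FALSE))   # should be FALSE
```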

Finally, looking at m$modelInfo$predict reveals the code snippet below:

m$modelInfo$predict
# output
function(modelFit, newdata, submodels = NULL) {
    if (!is.data.frame(newdata))
        newdata <- as.data.frame(newdata)
    predict(modelFit, newdata)
}

Note that modelFit = m$finalModel and that newdata is the design matrix built above. Also note that the call to predict gives no way to specify interval = "confidence", which is the cause of the first problem.
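This is easy to see by calling the wrapper directly on an ordinary lm fit. The wrapper has no `...` argument, so interval = "confidence" cannot be forwarded and only bare point predictions come back. The sketch copies the wrapper body from the output above under the hypothetical name `caretStylePredict`; `fit` is a plain lm fit standing in for m$finalModel.

```r
# Copy of the wrapper shown above (hypothetical name; no `...` to forward interval)
caretStylePredict <- function(modelFit, newdata, submodels = NULL) {
    if (!is.data.frame(newdata))
        newdata <- as.data.frame(newdata)
    predict(modelFit, newdata)
}

fit <- lm(mpg ~ poly(hp, 2), data = mtcars)

p_wrapper <- caretStylePredict(fit, mtcars)     # named numeric vector, no bounds
p_direct  <- predict(fit, newdata = mtcars,     # matrix with fit, lwr, upr
                     interval = "confidence", level = 0.95)

# Passing interval to the wrapper is rejected as an unused argument
try(caretStylePredict(fit, mtcars, interval = "confidence"))
```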

Fixing the problem (sorta):

There are many ways to fix this. One is to use lm(...) instead of train(...). Another is to use the internals of the function to create a data object that fits the unusual model specification, so that predict(m$finalModel, newdata = newdata, interval = "confidence") works as expected.
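For the lm(...) route, the fitted model itself is unchanged: the coefficients printed for m$finalModel in the question (20.09, -26.05, 13.15) are those of an ordinary lm fit on all of mtcars, so the intervals should agree with the caretTrainNewdata output shown further down. A minimal sketch, assuming train() is still run separately when cross-validated performance estimates are wanted:

```r
# Plain lm() on the original formula (no caret involved)
fit <- lm(mpg ~ poly(hp, 2), data = mtcars)

# Compare with the coefficients printed for m$finalModel in the question
round(coef(fit), 2)

# Pointwise 95% confidence intervals for the mean response
ci <- predict(fit, newdata = mtcars, interval = "confidence", level = 0.95)
head(ci, 3)
```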

I chose the latter.

caretNewdata <- caretTrainNewdata(m, mtcars)
preds <- predict(m$finalModel, caretNewdata, interval = "confidence")
head(preds, 3)
#output
                         fit      lwr      upr
Mazda RX4           22.03708 20.74297 23.33119
Mazda RX4 Wag       22.03708 20.74297 23.33119
Datsun 710          24.21108 22.77257 25.64960

The function is provided below. For the nerdy: I basically extracted the model.matrix building process from predict.train, predictionFunction and m$modelInfo$predict. I will not promise that this function works for every caret model, but it is a place to start.

caretTrainNewdata function:

caretTrainNewdata <- function(object, newdata, na.action = na.omit) {
    # Make sure the packages behind the model are available. caret's own helper
    # requireNamespaceQuietStop() is internal (not exported), so use base R here.
    for (pkg in object$modelInfo$library) {
        if (!requireNamespace(pkg, quietly = TRUE))
            stop("package '", pkg, "' is required but not installed")
    }
    if (!is.null(newdata)) {
        if (inherits(object, "train.formula")) {
            # Rebuild the design matrix the way predict.train does. object$terms
            # carries "predvars", so poly(hp, 2) is evaluated with the training
            # coefficients and yields the `poly(hp, 2)1`, `poly(hp, 2)2` columns.
            newdata <- as.data.frame(newdata)
            Terms <- delete.response(object$terms)
            m <- model.frame(Terms, newdata, na.action = na.action,
                             xlev = object$xlevels)
            if (!is.null(cl <- attr(Terms, "dataClasses")))
                .checkMFClasses(cl, m)
            newdata <- model.matrix(Terms, m, contrasts = object$contrasts)
            # Drop the intercept column; lm() in finalModel adds its own
            xint <- match("(Intercept)", colnames(newdata), nomatch = 0)
            if (xint > 0)
                newdata <- newdata[, -xint, drop = FALSE]
        }
    } else if (object$control$method != "oob") {
        # No newdata: fall back to the stored training data, as predict.train does
        if (!is.null(object$trainingData)) {
            if (object$method == "pam") {
                newdata <- object$finalModel$xData
            } else {
                newdata <- object$trainingData
                newdata$.outcome <- NULL
                if (inherits(object, "train.formula") &&
                    any(vapply(newdata, is.factor, logical(1)))) {
                    newdata <- model.matrix(~ ., data = newdata)[, -1, drop = FALSE]
                    newdata <- as.data.frame(newdata)
                }
            }
        } else stop("please specify data via newdata")
    } else stop("please specify data via newdata")
    # Keep only the predictor columns the final model was fit on
    if ("xNames" %in% names(object$finalModel) &&
        is.null(object$preProcess$method$pca) &&
        is.null(object$preProcess$method$ica))
        newdata <- newdata[, colnames(newdata) %in% object$finalModel$xNames,
                           drop = FALSE]
    # Apply the same pre-processing that train() fitted, if any
    if (!is.null(object$preProcess))
        newdata <- predict(object$preProcess, newdata)
    # Mirror modelInfo$predict, which coerces newdata to a data.frame
    if (!is.data.frame(newdata) &&
        !is.null(object$modelInfo$predict) &&
        any(grepl("as.data.frame", as.character(body(object$modelInfo$predict)))))
        newdata <- as.data.frame(newdata)
    newdata
}
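For this example, the only part of caretTrainNewdata that matters is the terms-based rebuild of the design matrix. Without the caret-specific branches it shrinks to a few lines. The sketch below is a base-R approximation with a hypothetical helper name, `expandForFinalModel`, and it imitates caret's `.outcome ~ .` fit instead of calling train(), so it runs without caret installed.

```r
# Hypothetical helper: rebuild newdata in the column layout of an `.outcome ~ .`
# fit, given terms of the original formula (like those kept in object$terms)
expandForFinalModel <- function(Terms, newdata, na.action = na.omit) {
    Terms <- delete.response(Terms)
    mf <- model.frame(Terms, as.data.frame(newdata), na.action = na.action)
    X <- model.matrix(Terms, mf)
    # Drop the intercept column; lm() in the final model adds its own
    as.data.frame(X[, colnames(X) != "(Intercept)", drop = FALSE])
}

# Imitate caret: expand the formula on the training data, then fit .outcome ~ .
train_terms <- terms(model.frame(mpg ~ poly(hp, 2), data = mtcars))
dat <- expandForFinalModel(train_terms, mtcars)
dat$.outcome <- mtcars$mpg
fm <- lm(.outcome ~ ., data = dat)

# Confidence intervals from the "final model" using the rebuilt newdata
ci_final <- predict(fm, newdata = expandForFinalModel(train_terms, mtcars),
                    interval = "confidence", level = 0.95)

# Compare with a direct lm() fit on the original formula
ci_lm <- predict(lm(mpg ~ poly(hp, 2), data = mtcars), newdata = mtcars,
                 interval = "confidence", level = 0.95)
all.equal(ci_final, ci_lm)   # should be TRUE
```

Because the rebuilt columns are evaluated through the stored predvars, the same helper can be applied to rows that were not in the training data.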