Sometimes when I fit a model with caret, I am really only interested in seeing how it performs under the resampling method I have chosen (e.g. cross-validation).

When I am not interested in the "final model" built on the full training data, I would like to avoid fitting it. The point is simply to save a few precious minutes, many times over, during development.

Is there any way to skip fitting the final model when using caret? I have not seen any relevant arguments in caret::trainControl or caret::train.


There indeed doesn't seem to be an argument that directly achieves that. There are a couple of candidate solutions, though.

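  1. selectionFunction, an argument of trainControl, picks the final model from the candidate models based on their resampled performance (accuracy, RMSE, etc.); without parameter tuning there is only one candidate. Setting selectionFunction to something like function(x, ...) NA or function(x, ...) NULL fails, since train requires the selection function to return a single, non-missing index. However, something like function(x, ...) -1 does partially work: no warning or error is returned, and caret still attempts to fit a final model. The outcome seems to be model-dependent, presumably because the negative index drops a row from the performance table instead of selecting one, so bestTune ends up holding all the remaining candidates (or none, if there was only one).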

  2. Another argument of trainControl of interest is indexFinal:

    an optional vector of integers indicating which samples are used to fit the final model after resampling. If NULL, then entire data set is used.

    Setting it to NA appears to fail with most models, except kNN. Setting it to something like 1:10 fits a final model using only those ten observations, provided the model has few enough parameters. Hence, setting it to something like 1:100 should work in many cases and take little time.

  3. You can, of course, modify the train function itself. Below I only add an argument fitFinal, which is TRUE by default, and check it just before the final model would be fitted. If fitFinal == FALSE, the final fit is replaced with

    finalModel <- list(fit = NULL, preProc = NULL)
    finalTime <- 0
    

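    Everything else seems to run smoothly. To actually overwrite caret's train.default function with this version, run the following after defining myTrain. Setting the environment first lets myTrain find caret's internal, unexported helpers (make_resamples, createModel, etc.):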

    environment(myTrain) <- environment(caret:::train.default)
    assignInNamespace("train.default", myTrain, ns = "caret")
    

    The full modified function (caret's train.default with the fitFinal changes) is:

    myTrain <- function(x, y, method = "rf", preProcess = NULL, ..., weights = NULL,
                        fitFinal = TRUE,
                        metric = ifelse(is.factor(y), "Accuracy", "RMSE"),
                        maximize = ifelse(metric %in% c("RMSE", "logLoss", "MAE"), FALSE, TRUE),
                        trControl = trainControl(),
                        tuneGrid = NULL,
                        tuneLength = ifelse(trControl$method == "none", 1, 3))
    {
      startTime <- proc.time()
      rs_seed <- sample.int(.Machine$integer.max, 1L)
      if (is.null(colnames(x)))
        stop("Please use column names for `x`", call. = FALSE)
      if (is.character(y))
        y <- as.factor(y)
      if (!is.numeric(y) & !is.factor(y)) {
        stop("Please make sure `y` is a factor or numeric value.",
             call. = FALSE)
      }
      if (is.list(method)) {
        minNames <- c("library", "type", "parameters", "grid",
                      "fit", "predict", "prob")
        nameCheck <- minNames %in% names(method)
        if (!all(nameCheck))
          stop(paste("some required components are missing:",
                     paste(minNames[!nameCheck], collapse = ", ")),
               call. = FALSE)
        models <- method
        method <- "custom"
      }
      else {
        models <- getModelInfo(method, regex = FALSE)[[1]]
        if (length(models) == 0)
          stop(paste("Model", method, "is not in caret's built-in library"),
               call. = FALSE)
      }
      checkInstall(models$library)
      for (i in seq(along = models$library)) do.call("requireNamespaceQuietStop",
                                                     list(package = models$library[i]))
      if (any(names(models) == "check") && is.function(models$check)) {
        software_check <- models$check(models$library)
      }
      paramNames <- as.character(models$parameters$parameter)
      funcCall <- match.call(expand.dots = TRUE)
      modelType <- get_model_type(y)
      if (!(modelType %in% models$type))
        stop(paste("wrong model type for", tolower(modelType)),
             call. = FALSE)
      if (grepl("^svm", method) & grepl("String$", method)) {
        if (is.vector(x) && is.character(x)) {
          stop("'x' should be a character matrix with a single column for string kernel methods",
               call. = FALSE)
        }
        if (is.matrix(x) && is.numeric(x)) {
          stop("'x' should be a character matrix with a single column for string kernel methods",
               call. = FALSE)
        }
        if (is.data.frame(x)) {
          stop("'x' should be a character matrix with a single column for string kernel methods",
               call. = FALSE)
        }
      }
      if (modelType == "Regression" & length(unique(y)) == 2)
        warning(paste("You are trying to do regression and your outcome only has",
                      "two possible values. Are you trying to do classification?",
                      "If so, use a 2 level factor as your outcome column."))
      if (modelType != "Classification" & !is.null(trControl$sampling))
        stop("sampling methods are only implemented for classification problems",
             call. = FALSE)
      if (!is.null(trControl$sampling)) {
        trControl$sampling <- parse_sampling(trControl$sampling)
      }
      if (any(class(x) == "data.table"))
        x <- as.data.frame(x)
      check_dims(x = x, y = y)
      n <- if (class(y)[1] == "Surv")
        nrow(y)
      else length(y)
      parallel_check("RWeka", models)
      parallel_check("keras", models)
      if (!is.null(preProcess) && !(all(names(preProcess) %in%
                                        ppMethods)))
        stop(paste("pre-processing methods are limited to:",
                   paste(ppMethods, collapse = ", ")), call. = FALSE)
      if (modelType == "Classification") {
        classLevels <- levels(y)
        attributes(classLevels) <- list(ordered = is.ordered(y))
        xtab <- table(y)
        if (any(xtab == 0)) {
          xtab_msg <- paste("'", names(xtab)[xtab == 0], "'",
                            collapse = ", ", sep = "")
          stop(paste("One or more factor levels in the outcome has no data:",
                     xtab_msg), call. = FALSE)
        }
        if (trControl$classProbs && any(classLevels != make.names(classLevels))) {
          stop(paste("At least one of the class levels is not a valid R variable name;",
                     "This will cause errors when class probabilities are generated because",
                     "the variables names will be converted to ",
                     paste(make.names(classLevels), collapse = ", "),
                     ". Please use factor levels that can be used as valid R variable names",
                     " (see ?make.names for help)."), call. = FALSE)
        }
        if (metric %in% c("RMSE", "Rsquared"))
          stop(paste("Metric", metric, "not applicable for classification models"),
               call. = FALSE)
        if (!trControl$classProbs && metric == "ROC")
          stop(paste("Class probabilities are needed to score models using the",
                     "area under the ROC curve. Set `classProbs = TRUE`",
                     "in the trainControl() function."), call. = FALSE)
        if (trControl$classProbs) {
          if (!is.function(models$prob)) {
            warning("Class probabilities were requested for a model that does not implement them")
            trControl$classProbs <- FALSE
          }
        }
      }
      else {
        if (metric %in% c("Accuracy", "Kappa"))
          stop(paste("Metric", metric, "not applicable for regression models"),
               call. = FALSE)
        classLevels <- NA
        if (trControl$classProbs) {
          warning("cannot compute class probabilities for regression")
          trControl$classProbs <- FALSE
        }
      }
      if (trControl$method == "oob" & is.null(models$oob))
        stop("Out of bag estimates are not implemented for this model",
             call. = FALSE)
      trControl <- withr::with_seed(rs_seed, make_resamples(trControl,
                                                            outcome = y))
      if (is.logical(trControl$savePredictions)) {
        trControl$savePredictions <- if (trControl$savePredictions)
          "all"
        else "none"
      }
      else {
        if (!(trControl$savePredictions %in% c("all", "final",
                                               "none")))
          stop("`savePredictions` should be either logical or \"all\", \"final\" or \"none\"",
               call. = FALSE)
      }
      if (!is.null(preProcess)) {
        ppOpt <- list(options = preProcess)
        if (length(trControl$preProcOptions) > 0)
          ppOpt <- c(ppOpt, trControl$preProcOptions)
      }
      else ppOpt <- NULL
      if (is.null(tuneGrid)) {
        if (!is.null(ppOpt) && length(models$parameters$parameter) >
            1 && as.character(models$parameters$parameter) !=
            "parameter") {
          pp <- list(method = ppOpt$options)
          if ("ica" %in% pp$method)
            pp$n.comp <- ppOpt$ICAcomp
          if ("pca" %in% pp$method)
            pp$thresh <- ppOpt$thresh
          if ("knnImpute" %in% pp$method)
            pp$k <- ppOpt$k
          pp$x <- x
          ppObj <- do.call("preProcess", pp)
          tuneGrid <- models$grid(x = predict(ppObj, x), y = y,
                                  len = tuneLength, search = trControl$search)
          rm(ppObj, pp)
        }
        else {
          tuneGrid <- models$grid(x = x, y = y, len = tuneLength,
                                  search = trControl$search)
          if (trControl$search != "grid" && tuneLength < nrow(tuneGrid))
            tuneGrid <- tuneGrid[1:tuneLength, , drop = FALSE]
        }
      }
      if (grepl("adaptive", trControl$method) & nrow(tuneGrid) ==
          1) {
        stop(paste("For adaptive resampling, there needs to be more than one",
                   "tuning parameter for evaluation"), call. = FALSE)
      }
      dotNames <- hasDots(tuneGrid, models)
      if (dotNames)
        colnames(tuneGrid) <- gsub("^\\.", "", colnames(tuneGrid))
      tuneNames <- as.character(models$parameters$parameter)
      goodNames <- all.equal(sort(tuneNames), sort(names(tuneGrid)))
      if (!is.logical(goodNames) || !goodNames) {
        stop(paste("The tuning parameter grid should have columns",
                   paste(tuneNames, collapse = ", ", sep = "")), call. = FALSE)
      }
      if (trControl$method == "none" && nrow(tuneGrid) != 1)
        stop("Only one model should be specified in tuneGrid with no resampling",
             call. = FALSE)
      trControl$yLimits <- if (is.numeric(y))
        get_range(y)
      else NULL
      if (trControl$method != "none") {
        if (is.function(models$loop) && nrow(tuneGrid) > 1) {
          trainInfo <- models$loop(tuneGrid)
          if (!all(c("loop", "submodels") %in% names(trainInfo)))
            stop("The 'loop' function should produce a list with elements 'loop' and 'submodels'",
                 call. = FALSE)
          lengths <- unlist(lapply(trainInfo$submodels, nrow))
          if (all(lengths == 0))
            trainInfo$submodels <- NULL
        }
        else trainInfo <- list(loop = tuneGrid)
        num_rs <- if (trControl$method != "oob")
          length(trControl$index)
        else 1L
        if (trControl$method %in% c("boot632", "optimism_boot",
                                    "boot_all"))
          num_rs <- num_rs + 1L
        if (is.null(trControl$seeds) || all(is.na(trControl$seeds))) {
          seeds <- sample.int(n = 1000000L, size = num_rs *
                                nrow(trainInfo$loop) + 1L)
          seeds <- lapply(seq(from = 1L, to = length(seeds),
                              by = nrow(trainInfo$loop)), function(x) {
                                seeds[x:(x + nrow(trainInfo$loop) - 1L)]
                              })
          seeds[[num_rs + 1L]] <- seeds[[num_rs + 1L]][1L]
          trControl$seeds <- seeds
        }
        else {
          if (!(length(trControl$seeds) == 1 && is.na(trControl$seeds))) {
            numSeeds <- unlist(lapply(trControl$seeds, length))
            badSeed <- (length(trControl$seeds) < num_rs +
                          1L) || (any(numSeeds[-length(numSeeds)] < nrow(trainInfo$loop))) ||
              (numSeeds[length(numSeeds)] < 1L)
            if (badSeed)
              stop(paste("Bad seeds: the seed object should be a list of length",
                         num_rs + 1, "with", num_rs, "integer vectors of size",
                         nrow(trainInfo$loop), "and the last list element having at least a",
                         "single integer"), call. = FALSE)
            if (any(is.na(unlist(trControl$seeds))))
              stop("At least one seed is missing (NA)", call. = FALSE)
          }
        }
        if (trControl$method == "oob") {
          perfNames <- metric
        }
        else {
          testSummary <- evalSummaryFunction(y, wts = weights,
                                             ctrl = trControl, lev = classLevels, metric = metric,
                                             method = method)
          perfNames <- names(testSummary)
        }
        if (!(metric %in% perfNames)) {
          oldMetric <- metric
          metric <- perfNames[1]
          warning(paste("The metric \"", oldMetric, "\" was not in ",
                        "the result set. ", metric, " will be used instead.",
                        sep = ""))
        }
        if (trControl$method == "oob") {
          tmp <- oobTrainWorkflow(x = x, y = y, wts = weights,
                                  info = trainInfo, method = models, ppOpts = preProcess,
                                  ctrl = trControl, lev = classLevels, ...)
          performance <- tmp
          perfNames <- colnames(performance)
          perfNames <- perfNames[!(perfNames %in% as.character(models$parameters$parameter))]
          if (!(metric %in% perfNames)) {
            oldMetric <- metric
            metric <- perfNames[1]
            warning(paste("The metric \"", oldMetric, "\" was not in ",
                          "the result set. ", metric, " will be used instead.",
                          sep = ""))
          }
        }
        else {
          if (trControl$method == "LOOCV") {
            tmp <- looTrainWorkflow(x = x, y = y, wts = weights,
                                    info = trainInfo, method = models, ppOpts = preProcess,
                                    ctrl = trControl, lev = classLevels, ...)
            performance <- tmp$performance
          }
          else {
            if (!grepl("adapt", trControl$method)) {
              tmp <- nominalTrainWorkflow(x = x, y = y, wts = weights,
                                          info = trainInfo, method = models, ppOpts = preProcess,
                                          ctrl = trControl, lev = classLevels, ...)
              performance <- tmp$performance
              resampleResults <- tmp$resample
            }
            else {
              tmp <- adaptiveWorkflow(x = x, y = y, wts = weights,
                                      info = trainInfo, method = models, ppOpts = preProcess,
                                      ctrl = trControl, lev = classLevels, metric = metric,
                                      maximize = maximize, ...)
              performance <- tmp$performance
              resampleResults <- tmp$resample
            }
          }
        }
        trControl$indexExtra <- NULL
        if (!(trControl$method %in% c("LOOCV", "oob"))) {
          if (modelType == "Classification" && length(grep("^\\cell",
                                                           colnames(resampleResults))) > 0) {
            resampledCM <- resampleResults[, !(names(resampleResults) %in%
                                                 perfNames)]
            resampleResults <- resampleResults[, -grep("^\\cell",
                                                       colnames(resampleResults))]
          }
          else resampledCM <- NULL
        }
        else resampledCM <- NULL
        if (trControl$verboseIter) {
          cat("Aggregating results\n")
          flush.console()
        }
        perfCols <- names(performance)
        perfCols <- perfCols[!(perfCols %in% paramNames)]
        if (all(is.na(performance[, metric]))) {
          cat(paste("Something is wrong; all the", metric,
                    "metric values are missing:\n"))
          print(summary(performance[, perfCols[!grepl("SD$",
                                                      perfCols)], drop = FALSE]))
          stop("Stopping", call. = FALSE)
        }
        if (!is.null(models$sort))
          performance <- models$sort(performance)
        if (any(is.na(performance[, metric])))
          warning("missing values found in aggregated results")
        if (trControl$verboseIter && nrow(performance) > 1) {
          cat("Selecting tuning parameters\n")
          flush.console()
        }
        selectClass <- class(trControl$selectionFunction)[1]
        if (grepl("adapt", trControl$method)) {
          perf_check <- subset(performance, .B == max(performance$.B))
        }
        else perf_check <- performance
        if (selectClass == "function") {
          bestIter <- trControl$selectionFunction(x = perf_check,
                                                  metric = metric, maximize = maximize)
        }
        else {
          if (trControl$selectionFunction == "oneSE") {
            bestIter <- oneSE(perf_check, metric, length(trControl$index),
                              maximize)
          }
          else {
            bestIter <- do.call(trControl$selectionFunction,
                                list(x = perf_check, metric = metric, maximize = maximize))
          }
        }
        if (is.na(bestIter) || length(bestIter) != 1)
          stop("final tuning parameters could not be determined",
               call. = FALSE)
        if (grepl("adapt", trControl$method)) {
          best_perf <- perf_check[bestIter, as.character(models$parameters$parameter),
                                  drop = FALSE]
          performance$order <- 1:nrow(performance)
          bestIter <- merge(performance, best_perf)$order
          performance$order <- NULL
        }
        bestTune <- performance[bestIter, paramNames, drop = FALSE]
      }
      else {
        bestTune <- tuneGrid
        performance <- evalSummaryFunction(y, wts = weights,
                                           ctrl = trControl, lev = classLevels, metric = metric,
                                           method = method)
        perfNames <- names(performance)
        performance <- as.data.frame(t(performance))
        performance <- cbind(performance, tuneGrid)
        performance <- performance[-1, , drop = FALSE]
        tmp <- resampledCM <- NULL
      }
      if (!(trControl$method %in% c("LOOCV", "oob", "none"))) {
        byResample <- switch(trControl$returnResamp, none = NULL,
                             all = {
                               out <- resampleResults
                               colnames(out) <- gsub("^\\.", "", colnames(out))
                               out
                             }, final = {
                               out <- merge(bestTune, resampleResults)
                               out <- out[, !(names(out) %in% names(tuneGrid)),
                                          drop = FALSE]
                               out
                             })
      }
      else {
        byResample <- NULL
      }
      orderList <- list()
      for (i in seq(along = paramNames)) orderList[[i]] <- performance[,
                                                                       paramNames[i]]
      performance <- performance[do.call("order", orderList), ]
      if (trControl$verboseIter) {
        bestText <- paste(paste(names(bestTune), "=", format(bestTune,
                                                             digits = 3)), collapse = ", ")
        if (nrow(performance) == 1)
          bestText <- "final model"
        cat("Fitting", bestText, "on full training set\n")
        flush.console()
      }
      indexFinal <- if (is.null(trControl$indexFinal))
        seq(along = y)
      else trControl$indexFinal
      if (!(length(trControl$seeds) == 1 && is.na(trControl$seeds)))
        set.seed(trControl$seeds[[length(trControl$seeds)]][1])
      # fitFinal = FALSE skips the final fit and stores an empty placeholder
      # with the same fit/preProc structure that createModel() returns
      if (fitFinal) {
        finalTime <- system.time(finalModel <- createModel(x = subset_x(x, indexFinal),
                                                           y = y[indexFinal], wts = weights[indexFinal],
                                                           method = models, tuneValue = bestTune, obsLevels = classLevels,
                                                           pp = ppOpt, last = TRUE, classProbs = trControl$classProbs,
                                                           sampling = trControl$sampling, ...))
      } else {
        finalModel <- list(fit = NULL, preProc = NULL)
        finalTime <- 0
      }
      # trimming only makes sense when a final model was actually fitted
      if (fitFinal && trControl$trim && !is.null(models$trim)) {
        if (trControl$verboseIter)
          old_size <- object.size(finalModel$fit)
        finalModel$fit <- models$trim(finalModel$fit)
        if (trControl$verboseIter) {
          new_size <- object.size(finalModel$fit)
          reduction <- format(old_size - new_size, units = "Mb")
          if (reduction == "0 Mb")
            reduction <- "< 0 Mb"
          p_reduction <- (unclass(old_size) - unclass(new_size))/unclass(old_size) *
            100
          p_reduction <- if (p_reduction < 1)
            "< 1%"
          else paste0(round(p_reduction, 0), "%")
          cat("Final model footprint reduced by", reduction,
              "or", p_reduction, "\n")
        }
      }
      pp <- finalModel$preProc
      finalModel <- finalModel$fit
      if (method == "pls")
        finalModel$bestIter <- bestTune
      if (method == "glmnet")
        finalModel$lambdaOpt <- bestTune$lambda
      if (trControl$returnData) {
        outData <- if (!is.data.frame(x))
          try(as.data.frame(x), silent = TRUE)
        else x
        if (inherits(outData, "try-error")) {
          warning("The training data could not be converted to a data frame for saving")
          outData <- NULL
        }
        else {
          outData$.outcome <- y
          if (!is.null(weights))
            outData$.weights <- weights
        }
      }
      else outData <- NULL
      if (trControl$savePredictions == "final")
        tmp$predictions <- merge(bestTune, tmp$predictions)
      endTime <- proc.time()
      times <- list(everything = endTime - startTime, final = finalTime)
      out <- structure(list(method = method, modelInfo = models,
                            modelType = modelType, results = performance, pred = tmp$predictions,
                            bestTune = bestTune, call = funcCall, dots = list(...),
                            metric = metric, control = trControl, finalModel = finalModel,
                            preProcess = pp, trainingData = outData, resample = byResample,
                            resampledCM = resampledCM, perfNames = perfNames, maximize = maximize,
                            yLimits = trControl$yLimits, times = times, levels = classLevels),
                       class = "train")
      trControl$yLimits <- NULL
      # timing predictions requires a fitted final model
      if (fitFinal && trControl$timingSamps > 0) {
        pData <- x[sample(1:nrow(x), trControl$timingSamps, replace = TRUE),
                   , drop = FALSE]
        out$times$prediction <- system.time(predict(out, pData))
      }
      else out$times$prediction <- rep(NA, 3)
      out
    }
    

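    With the modified function in place, resampling runs as usual, but no final model is stored, so predict() cannot be used on the result: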

    data(iris)
    TrainData <- iris[,1:4]
    TrainClasses <- iris[,5]
    
    knnFit1 <- train(TrainData, TrainClasses,
                     method = "knn",
                     preProcess = c("center", "scale"),
                     tuneLength = 10,
                     trControl = trainControl(method = "cv"), fitFinal = FALSE)
    knnFit1$finalModel
    # NULL
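For comparison, option 2 needs no changes to caret at all. The following is a minimal sketch of it, assuming only that caret is installed; the 30-row random subset and the seed are arbitrary illustrative choices, and ctrl/knnFit2 are just example names. The final model is still fitted, but only on that small subset, so it is cheap and should not be used for prediction.

```r
library(caret)

data(iris)
TrainData <- iris[, 1:4]
TrainClasses <- iris[, 5]

set.seed(1)
# indexFinal only controls which rows are used for the final fit after
# resampling; the cross-validation folds themselves still use all rows.
# A random subset avoids picking rows from a single class (iris is sorted).
ctrl <- trainControl(method = "cv", number = 10,
                     indexFinal = sample(nrow(TrainData), 30))

knnFit2 <- train(TrainData, TrainClasses,
                 method = "knn",
                 preProcess = c("center", "scale"),
                 tuneLength = 10,
                 trControl = ctrl)

knnFit2$results      # resampled performance for each candidate k
knnFit2$times$final  # time spent on the (cheap) final fit
```

Because indexFinal only affects the last fit, knnFit2$results should contain the same kind of resampled estimates as a regular run; only knnFit2$finalModel is built on the reduced data.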