This question is for tidymodels users. If you are short on time, skip the text and jump straight to the bold question below.

I'm looking for the most efficient way to extract my parsnip model object from fitted resamples (tune::fit_resamples()).

When I want to train a model with cross-validation, I can go with either tune::tune_grid() or fit_resamples().

Let's say I know the best parameters for my algorithm, so I don't need any parameter tuning, which means I decide to go with fit_resamples(). If I had gone with tune_grid(), I would usually set up a workflow, since I evaluate different models after tune_grid() has run: I use tune::show_best() and tune::select_best() to explore and extract the best parameters for my model. Then I use tune::finalize_workflow() and workflows::pull_workflow_fit() to extract my model object. Further, when I want to see predictions, I use tune::last_fit() and tune::collect_predictions().

All these steps seem redundant when I go with fit_resamples(), since I basically have only one model with fixed parameters. So the steps above should not be necessary, yet it seems I have to go through them. Do I?

After fit_resamples() has run, I get a tibble with information about splits, .metrics, .notes, etc. So my question really comes down to:

  • What is the fastest way from the output tibble of fit_resamples() to my final parsnip model object?

1 Answer

The important thing to realize about fit_resamples() is that its purpose is to measure performance. The models that you train in fit_resamples() are not kept or used later.

Let's imagine that you know the parameters you want to use for an SVM model.

library(tidymodels)
#> ── Attaching packages ─────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom     0.7.0      ✓ recipes   0.1.13
#> ✓ dials     0.0.8      ✓ rsample   0.0.7 
#> ✓ dplyr     1.0.0      ✓ tibble    3.0.3 
#> ✓ ggplot2   3.3.2      ✓ tidyr     1.1.0 
#> ✓ infer     0.5.3      ✓ tune      0.1.1 
#> ✓ modeldata 0.0.2      ✓ workflows 0.1.2 
#> ✓ parsnip   0.1.2      ✓ yardstick 0.0.7 
#> ✓ purrr     0.3.4
#> ── Conflicts ────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()

## pretend this is your training data
data("hpc_data")

svm_spec <- svm_poly(degree = 1, cost = 1/4) %>%
  set_engine("kernlab") %>%
  set_mode("regression")

svm_wf <- workflow() %>%
  add_model(svm_spec) %>%
  add_formula(compounds ~ .)

hpc_folds <- vfold_cv(hpc_data)

svm_rs <- svm_wf %>%
  fit_resamples(
    resamples = hpc_folds
  )

svm_rs
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 x 4
#>    splits             id     .metrics         .notes          
#>    <list>             <chr>  <list>           <list>          
#>  1 <split [3.9K/434]> Fold01 <tibble [2 × 3]> <tibble [0 × 1]>
#>  2 <split [3.9K/433]> Fold02 <tibble [2 × 3]> <tibble [0 × 1]>
#>  3 <split [3.9K/433]> Fold03 <tibble [2 × 3]> <tibble [0 × 1]>
#>  4 <split [3.9K/433]> Fold04 <tibble [2 × 3]> <tibble [0 × 1]>
#>  5 <split [3.9K/433]> Fold05 <tibble [2 × 3]> <tibble [0 × 1]>
#>  6 <split [3.9K/433]> Fold06 <tibble [2 × 3]> <tibble [0 × 1]>
#>  7 <split [3.9K/433]> Fold07 <tibble [2 × 3]> <tibble [0 × 1]>
#>  8 <split [3.9K/433]> Fold08 <tibble [2 × 3]> <tibble [0 × 1]>
#>  9 <split [3.9K/433]> Fold09 <tibble [2 × 3]> <tibble [0 × 1]>
#> 10 <split [3.9K/433]> Fold10 <tibble [2 × 3]> <tibble [0 × 1]>

There are no fitted models in this output. Models were fitted to each of these resamples, but you don't want to use them for anything; they are thrown away because their only purpose is for computing the .metrics to estimate performance.
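What the resampling results do give you is the performance estimate. As a minimal sketch, reusing the same specification and data as above (the `set.seed()` call and the name `rs_metrics` are my additions for illustration), `collect_metrics()` from tune averages the per-fold `.metrics` into one row per metric:

```r
library(tidymodels)

## pretend this is your training data
data("hpc_data")

svm_spec <- svm_poly(degree = 1, cost = 1/4) %>%
  set_engine("kernlab") %>%
  set_mode("regression")

svm_wf <- workflow() %>%
  add_model(svm_spec) %>%
  add_formula(compounds ~ .)

## seed added here only so the folds are reproducible
set.seed(123)
hpc_folds <- vfold_cv(hpc_data)

svm_rs <- fit_resamples(svm_wf, resamples = hpc_folds)

## one row per metric, averaged across the folds
rs_metrics <- collect_metrics(svm_rs)
rs_metrics
```

For a regression model the default metrics should be RMSE and R squared. This is the only thing the resampling step is meant to hand back.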

If you want a model to use to predict on new data, you need to go back to your whole training set and fit your model once again, with the entire training set.

svm_fit <- svm_wf %>%
  fit(hpc_data)

svm_fit
#> ══ Workflow [trained] ═══════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: svm_poly()
#> 
#> ── Preprocessor ─────────────────────────────────────────────────────
#> compounds ~ .
#> 
#> ── Model ────────────────────────────────────────────────────────────
#> Support Vector Machine object of class "ksvm" 
#> 
#> SV type: eps-svr  (regression) 
#>  parameter : epsilon = 0.1  cost C = 0.25 
#> 
#> Polynomial kernel function. 
#>  Hyperparameters : degree =  1  scale =  1  offset =  1 
#> 
#> Number of Support Vectors : 2827 
#> 
#> Objective Function Value : -284.7255 
#> Training error : 0.835421

Created on 2020-07-17 by the reprex package (v0.3.0)

This final fitted workflow is the object you can pass to pull_workflow_fit() to get the underlying parsnip model, for variable importance or similar.
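A short sketch of that last step, under a couple of assumptions: it uses `extract_fit_parsnip()`, which is the name newer workflows releases (0.2.3 and later) use for this; with the older version shown in the reprex above, `pull_workflow_fit()` plays the same role. The object `new_rows` is a hypothetical stand-in for new data, here just the first rows of the training set.

```r
library(tidymodels)

data("hpc_data")

svm_wf <- workflow() %>%
  add_model(
    svm_poly(degree = 1, cost = 1/4) %>%
      set_engine("kernlab") %>%
      set_mode("regression")
  ) %>%
  add_formula(compounds ~ .)

## fit once on the entire training set
svm_fit <- fit(svm_wf, data = hpc_data)

## the parsnip model object inside the fitted workflow
## (assumes workflows >= 0.2.3; older versions: pull_workflow_fit(svm_fit))
svm_parsnip <- extract_fit_parsnip(svm_fit)
class(svm_parsnip)

## predicting works directly on the fitted workflow;
## new_rows is a hypothetical stand-in for new data
new_rows <- head(hpc_data)
preds <- predict(svm_fit, new_data = new_rows)
preds
```

So the shortest route is not from the fit_resamples() tibble at all: fit the workflow once on the training data and extract the parsnip object from that.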