# load packages
library(tidyverse)
library(tidymodels)
library(knitr)
library(schrute)
Cross validation
STA 210 - Spring 2022
Welcome
Topics
- Cross validation for model evaluation
- Cross validation for model comparison
Computational setup
Data & goal
office_episodes <- read_csv(here::here("slides", "data/office_episodes.csv"))
office_episodes
# A tibble: 186 × 14
season episode episode_name imdb_rating total_votes air_date lines_jim
<dbl> <dbl> <chr> <dbl> <dbl> <date> <dbl>
1 1 1 Pilot 7.6 3706 2005-03-24 0.157
2 1 2 Diversity Day 8.3 3566 2005-03-29 0.123
3 1 3 Health Care 7.9 2983 2005-04-05 0.172
4 1 4 The Alliance 8.1 2886 2005-04-12 0.202
5 1 5 Basketball 8.4 3179 2005-04-19 0.0913
6 1 6 Hot Girl 7.8 2852 2005-04-26 0.159
7 2 1 The Dundies 8.7 3213 2005-09-20 0.125
8 2 2 Sexual Harassment 8.2 2736 2005-09-27 0.0565
9 2 3 Office Olympics 8.4 2742 2005-10-04 0.196
10 2 4 The Fire 8.4 2713 2005-10-11 0.160
# … with 176 more rows, and 7 more variables: lines_pam <dbl>,
# lines_michael <dbl>, lines_dwight <dbl>, halloween <chr>, valentine <chr>,
# christmas <chr>, michael <chr>
Modeling prep
Split data into training and testing
set.seed(123)
office_split <- initial_split(office_episodes)
office_train <- training(office_split)
office_test  <- testing(office_split)
Specify model
office_spec <- linear_reg() %>%
  set_engine("lm")
office_spec
Linear Regression Model Specification (regression)
Computational engine: lm
Model 1
From yesterday’s lab
- Create a recipe that uses the new variables we generated
- Denote episode_name as an ID variable and don't use air_date as a predictor
- Create dummy variables for all nominal predictors
- Remove all zero variance predictors
Create recipe
office_rec1 <- recipe(imdb_rating ~ ., data = office_train) %>%
  update_role(episode_name, new_role = "id") %>%
  step_rm(air_date) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors())
office_rec1
Recipe
Inputs:
role #variables
id 1
outcome 1
predictor 12
Operations:
Delete terms air_date
Dummy variables from all_nominal_predictors()
Zero variance filter on all_predictors()
Create workflow
office_wflow1 <- workflow() %>%
  add_model(office_spec) %>%
  add_recipe(office_rec1)
office_wflow1
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()
── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps
• step_rm()
• step_dummy()
• step_zv()
── Model ───────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)
Computational engine: lm
Fit model to training data
Actually, not so fast!
Cross validation
Spending our data
- We have already established the idea of data spending, where the test set is reserved for obtaining an unbiased estimate of performance.
- However, we usually need to understand the effectiveness of the model before using the test set.
- Typically we can’t decide on which final model to take to the test set without making model assessments.
- Remedy: Resampling to make model assessments on training data in a way that can generalize to new data.
Resampling for model assessment
Resampling is only conducted on the training set. The test set is not involved. For each iteration of resampling, the data are partitioned into two subsamples:
- The model is fit with the analysis set.
- The model is evaluated with the assessment set.
Resampling for model assessment
Source: Kuhn and Silge. Tidy modeling with R.
Analysis and assessment sets
- Analysis set is analogous to training set.
- Assessment set is analogous to test set.
- The terms analysis and assessment avoid confusion with the initial split of the data.
- These data sets are mutually exclusive.
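As a minimal sketch of this vocabulary (not from the slides; it uses the built-in mtcars data and the rsample package directly), each resample carries its own analysis and assessment sets:

```r
library(rsample)

set.seed(1)
# split a built-in data frame into 4 folds and grab the first resample
resample <- vfold_cv(mtcars, v = 4)$splits[[1]]

# the analysis set (used to fit) and the assessment set (used to
# evaluate) are mutually exclusive and together cover all 32 rows
nrow(analysis(resample))    # 24
nrow(assessment(resample))  # 8
```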
Cross validation
More specifically, v-fold cross validation, a commonly used resampling technique:
- Randomly split your training data into v partitions
- Use 1 partition for assessment, and the remaining v-1 partitions for analysis
- Repeat v times, updating which partition is used for assessment each time
Let’s give an example where v = 3
Cross validation, step 1
Randomly split your training data into 3 partitions:
Split data
set.seed(345)
folds <- vfold_cv(office_train, v = 3)
folds
# 3-fold cross-validation
# A tibble: 3 × 2
splits id
<list> <chr>
1 <split [92/47]> Fold1
2 <split [93/46]> Fold2
3 <split [93/46]> Fold3
Cross validation, steps 2 and 3
- Use 1 partition for assessment, and the remaining v-1 partitions for analysis
- Repeat v times, updating which partition is used for assessment each time
Fit resamples
set.seed(456)
office_fit_rs1 <- office_wflow1 %>%
  fit_resamples(folds)
office_fit_rs1
# Resampling results
# 3-fold cross-validation
# A tibble: 3 × 4
splits id .metrics .notes
<list> <chr> <list> <list>
1 <split [92/47]> Fold1 <tibble [2 × 4]> <tibble [0 × 1]>
2 <split [93/46]> Fold2 <tibble [2 × 4]> <tibble [0 × 1]>
3 <split [93/46]> Fold3 <tibble [2 × 4]> <tibble [0 × 1]>
Cross validation, now what?
- We’ve fit a bunch of models
- Now it's time to collect metrics (e.g., R-squared, RMSE) for each model and use them to evaluate model fit and how it varies across folds
Collect CV metrics
collect_metrics(office_fit_rs1)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 rmse standard 0.351 3 0.0111 Preprocessor1_Model1
2 rsq standard 0.546 3 0.0378 Preprocessor1_Model1
Deeper look into CV metrics
cv_metrics1 <- collect_metrics(office_fit_rs1, summarize = FALSE)
cv_metrics1
# A tibble: 6 × 5
id .metric .estimator .estimate .config
<chr> <chr> <chr> <dbl> <chr>
1 Fold1 rmse standard 0.356 Preprocessor1_Model1
2 Fold1 rsq standard 0.520 Preprocessor1_Model1
3 Fold2 rmse standard 0.367 Preprocessor1_Model1
4 Fold2 rsq standard 0.498 Preprocessor1_Model1
5 Fold3 rmse standard 0.330 Preprocessor1_Model1
6 Fold3 rsq standard 0.621 Preprocessor1_Model1
Better tabulation of CV metrics
cv_metrics1 %>%
  mutate(.estimate = round(.estimate, 3)) %>%
  pivot_wider(id_cols = id, names_from = .metric, values_from = .estimate) %>%
  kable(col.names = c("Fold", "RMSE", "R-squared"))
| Fold  | RMSE  | R-squared |
|-------|-------|-----------|
| Fold1 | 0.356 | 0.520     |
| Fold2 | 0.367 | 0.498     |
| Fold3 | 0.330 | 0.621     |
How does RMSE compare to y?
Cross validation RMSE stats:
cv_metrics1 %>%
  filter(.metric == "rmse") %>%
  summarise(
    min = min(.estimate),
    max = max(.estimate),
    mean = mean(.estimate),
    sd = sd(.estimate)
  )
# A tibble: 1 × 4
min max mean sd
<dbl> <dbl> <dbl> <dbl>
1 0.330 0.367 0.351 0.0192
IMDB score stats:
office_episodes %>%
  summarise(
    min = min(imdb_rating),
    max = max(imdb_rating),
    mean = mean(imdb_rating),
    sd = sd(imdb_rating)
  )
# A tibble: 1 × 4
min max mean sd
<dbl> <dbl> <dbl> <dbl>
1 6.7 9.7 8.25 0.535
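To put the two summaries side by side (a quick back-of-the-envelope comparison, with values transcribed from the output above rather than computed from the objects):

```r
# values transcribed from the summaries above
mean_cv_rmse <- 0.351  # mean RMSE across the 3 folds
sd_rating    <- 0.535  # sd of imdb_rating

# RMSE is roughly 2/3 of the outcome's standard deviation, so the
# model predicts better than the mean alone, but not by a huge margin
mean_cv_rmse / sd_rating  # ~0.66
```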
Cross validation jargon
- Referred to as v-fold or k-fold cross validation
- Also commonly abbreviated as CV
Cross validation, for reals
To illustrate how CV works, we used v = 3:
- Analysis sets are 2/3 of the training set
- Each assessment set is a distinct 1/3
- The final resampling estimate of performance averages each of the 3 replicates

This was useful for illustrative purposes, but v = 3 is a poor choice in practice. Values of v are most often 5 or 10; we generally prefer 10-fold cross-validation as a default.
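A sketch of that preferred default (again on built-in data, since vfold_cv only needs a data frame; the call is the same as before with only v changed):

```r
library(rsample)

set.seed(345)
# 10-fold CV: each assessment set is a distinct ~1/10 of the rows,
# each analysis set the remaining ~9/10
folds10 <- vfold_cv(mtcars, v = 10)
nrow(folds10)  # 10 resamples, one per fold
```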
Application exercise
Recap
- Cross validation for model evaluation
- Cross validation for model comparison