
A good future step would be to provide functions that automate the plotting code below given the output of compare_learners() and cv_super_learner().
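As a rough illustration of what such a helper could look like, here is a minimal sketch of one building block: extracting the per-fold holdout MSEs for each learner from a verbose super_learner() fit. The function name per_fold_mse() and its interface are hypothetical and not part of nadir; the same computation is done by hand in the jitters chunk further below.

# hypothetical sketch, not part of nadir: per-fold holdout MSEs for each
# learner, in a long format ready to hand to ggplot2
per_fold_mse <- function(sl_model, outcome) {
  sl_model$holdout_predictions |>
    # squared error of each learner's holdout prediction against the outcome
    dplyr::mutate(dplyr::across(
      -dplyr::all_of(c('.sl_fold', outcome)),
      ~ (. - .data[[outcome]])^2)) |>
    dplyr::select(-dplyr::all_of(outcome)) |>
    tidyr::pivot_longer(-.sl_fold, names_to = 'learner', values_to = 'squared_error') |>
    # average the squared errors within each fold to get per-fold MSEs
    dplyr::group_by(learner, .sl_fold) |>
    dplyr::summarize(mse = mean(squared_error), .groups = 'drop')
}

A plotting helper could then call something like this on the super learner fit (and an analogous summary on cv_super_learner() output) and assemble the ggplot2 figures shown below.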

data("Boston", package = "MASS")

# construct our super learner with verbose = TRUE
sl_model <- super_learner(
  data = Boston,
  formulas = medv ~ .,
  learners = list(
    lm = lnr_lm,
    rf = lnr_rf,
    ranger = lnr_ranger,
    xgboost = lnr_xgboost,
    glmnet = lnr_glmnet,
    earth = lnr_earth
    ),
  verbose = TRUE)
  
compare_learners(sl_model)
#> The default in nadir::compare_learners is to use CV-MSE for comparing learners.
#> Other metrics can be set using the loss_metric argument to compare_learners.
#> # A tibble: 1 × 6
#>      lm    rf ranger xgboost glmnet earth
#>   <dbl> <dbl>  <dbl>   <dbl>  <dbl> <dbl>
#> 1  25.3  10.8   11.8    8.83   26.7  14.8

# load the packages used by the plotting code below
pacman::p_load('dplyr', 'ggplot2', 'tidyr', 'magrittr')

# observed outcome values from the held-out folds
truth <- sl_model$holdout_predictions$medv

# variance and standard deviation of the fold-specific holdout MSEs, per learner
holdout_var <- sl_model$holdout_predictions |>
  dplyr::group_by(.sl_fold) |> 
  dplyr::summarize(across(everything(), ~ mean((. - medv)^2))) |> 
  dplyr::summarize(across(everything(), var)) |> 
  dplyr::select(-medv, -.sl_fold) |> 
  t() |> 
  as.data.frame() |> 
  tibble::rownames_to_column('learner') |> 
  dplyr::rename(var = V1) |>
  dplyr::mutate(sd = sqrt(var))


# per-fold holdout MSE for each candidate learner, used as jittered points in the plot
jitters <- sl_model$holdout_predictions |> 
  dplyr::mutate(dplyr::across(-.sl_fold, ~ (. - medv)^2)) |> 
  dplyr::select(-medv) %>%
  tidyr::pivot_longer(cols = 2:ncol(.), names_to = 'learner', values_to = 'squared_error') |>
  dplyr::group_by(learner, .sl_fold) |> 
  dplyr::summarize(mse = mean(squared_error)) |> 
  dplyr::ungroup() |> 
  dplyr::rename(fold = .sl_fold)
#> `summarise()` has grouped output by 'learner'. You can override using the
#> `.groups` argument.

# mean CV-MSE per learner joined with its across-fold variability, for plotting
learner_comparison_df <- sl_model |> 
  compare_learners() |> 
  t() |> 
  as.data.frame() |>
  tibble::rownames_to_column(var = 'learner') |> 
  dplyr::mutate(learner = factor(learner)) |>
  dplyr::rename(mse = V1) |>
  dplyr::left_join(holdout_var) |> 
  dplyr::mutate(
    upper_ci = mse + sd,
    lower_ci = mse - sd) |> 
  dplyr::mutate(learner = forcats::fct_reorder(learner, mse))
#> The default in nadir::compare_learners is to use CV-MSE for comparing learners.
#> Other metrics can be set using the loss_metric argument to compare_learners.
#> Joining with `by = join_by(learner)`

# keep the learner factor levels consistent between the jitter points and the bars
jitters$learner <- factor(jitters$learner, levels = levels(learner_comparison_df$learner))

learner_comparison_df |> 
  ggplot2::ggplot(ggplot2::aes(y = learner, x = mse, fill = learner)) + 
  ggplot2::geom_col(alpha = 0.5) + 
  ggplot2::geom_jitter(data = jitters, mapping = ggplot2::aes(x = mse), height = .15, shape = 'o') + 
  ggplot2::geom_pointrange(mapping = ggplot2::aes(xmax = upper_ci, xmin = lower_ci),
                           alpha = 0.5) + 
  ggplot2::theme_bw() + 
  ggplot2::ggtitle("Comparison of Candidate Learners") + 
  ggplot2::labs(caption = "Error bars show ±1 standard deviation across the CV-estimated MSEs for each learner\nEach open circle represents the hold-out MSE of one fold of the data") + 
  ggplot2::theme(plot.caption.position = 'plot')



# wrap the super learner specification in a closure so that
# cv_super_learner() can refit it on each training split
sl_closure <- function(data) {
  nadir::super_learner(
    data = data,
    formulas = medv ~ .,
    learners = list(
      lm = lnr_lm,
      rf = lnr_rf,
      ranger = lnr_ranger,
      xgboost = lnr_xgboost,
      glmnet = lnr_glmnet,
      earth = lnr_earth
    )
  )
}

# cross-validate the super learner itself to estimate its out-of-sample MSE
cv_results <- cv_super_learner(data = Boston, sl_closure, 
                 y_variable = 'medv',
                 n_folds = 5)
#> The default is to report CV-MSE if no other loss_metric is specified.

# per-split holdout MSE for the cross-validated super learner
cv_jitters <- cv_results$cv_trained_learners |> 
  dplyr::select(split, predictions, medv) |> 
  tidyr::unnest(cols = c('predictions', 'medv')) |> 
  dplyr::group_by(split) |> 
  dplyr::summarize(mse = mean((medv - predictions)^2)) |>
  dplyr::bind_cols(learner = 'super_learner')


# mean, variance, and ±1 sd interval of the super learner's per-split MSEs
cv_var <- cv_results$cv_trained_learners |> 
  dplyr::select(split, predictions, medv) |> 
  tidyr::unnest(cols = c(predictions, medv)) |> 
  dplyr::mutate(squared_error = (medv - predictions)^2) |> 
  dplyr::group_by(split) |> 
  dplyr::summarize(mse = mean(squared_error)) |> 
  dplyr::summarize(
    var = var(mse),
    mse = mean(mse),
    sd = sqrt(var),
    upper_ci = mse + sd,
    lower_ci = mse - sd) |> 
  dplyr::bind_cols(learner = 'super_learner')

# combine per-fold MSEs of the candidate learners with the super learner's
new_jitters <- dplyr::bind_rows(jitters, cv_jitters)

learner_comparison_df |> 
  dplyr::bind_rows(cv_var) |> 
  dplyr::mutate(learner = forcats::fct_reorder(learner, mse)) |> 
  ggplot2::ggplot(ggplot2::aes(y = learner, x = mse, fill = learner)) + 
  ggplot2::geom_col(alpha = 0.5) + 
  ggplot2::geom_jitter(data = new_jitters, mapping = ggplot2::aes(x = mse), height = .15, shape = 'o') + 
  ggplot2::geom_pointrange(mapping = ggplot2::aes(xmax = upper_ci, xmin = lower_ci),
                           alpha = 0.5) + 
  ggplot2::theme_bw() + 
  ggplot2::scale_fill_brewer(palette = 'Set2') + 
  ggplot2::ggtitle("Comparison of Candidate Learners against Super Learner") + 
  ggplot2::labs(caption = "Error bars show ±1 standard deviation across the CV-estimated MSEs for each learner\nEach open circle represents the hold-out MSE of one fold of the data") + 
  ggplot2::theme(plot.caption.position = 'plot')