Using cross-validation to estimate the risk on unseen data, as the super learner algorithm does, assumes that the training and validation splits are independent samples from the underlying data distribution.
If the data come from clusters or are otherwise dependent, we need to make sure that cross-validation does not split up clusters by putting some observations from a cluster into a training split and others from the same cluster into the validation split.
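To make the concern concrete, here is a minimal base R sketch (not the implementation used by any package discussed here) that contrasts row-level fold assignment with cluster-level fold assignment. All object names (toy_cluster, row_fold, cluster_fold, n_split) are hypothetical and exist only for this illustration.

```r
# Toy data: 100 rows belonging to up to 25 clusters (illustrative only).
set.seed(1)
toy_cluster <- sample.int(25, 100, replace = TRUE)
n_folds <- 5

# Naive row-level folds: rows are dealt to folds ignoring cluster membership.
row_fold <- sample(rep_len(seq_len(n_folds), length(toy_cluster)))

# Cluster-level folds: each cluster gets one fold label, and every row
# inherits the label of its cluster.
clusters <- unique(toy_cluster)
fold_of_cluster <- sample(rep_len(seq_len(n_folds), length(clusters)))
names(fold_of_cluster) <- clusters
cluster_fold <- unname(fold_of_cluster[as.character(toy_cluster)])

# Count the clusters whose rows are spread over more than one fold.
n_split <- function(fold) {
  sum(tapply(fold, toy_cluster, function(f) length(unique(f))) > 1)
}
n_split(row_fold)      # usually positive: clusters straddle training and validation
n_split(cluster_fold)  # zero by construction
```

With cluster-level assignment, holding out any one fold removes whole clusters from the training data, which is the property the rest of this section relies on.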
Read about this in “Practical considerations for specifying a super learner” by Rachel Phillips et al. (2023, Int J Epidemiology): https://academic.oup.com/ije/article/52/4/1276/7076266
“When the data are not i.i.d., clustered observations must be assigned as a group to the same validation and training sets. This ensures the validation data are completely independent of the training data, and the loss function is evaluated at the cluster level. V-fold CV with clustered data is specified with existing SL software by supplying a cluster identifier to the ‘id’ argument.”
The above quote describes the syntax of the SuperLearner package, so we try to achieve something similar. We just need to specify a cross-validation schema using origami as follows.
## origami v1.0.7: Generalized Framework for Cross-Validation
# generate synthetic clustered data
n_clusters <- 25
n_participants <- 100
cluster_ids <- sample.int(n_clusters, n_participants, replace = TRUE)
age <- sample.int(100, n_participants, replace = TRUE)
drug <- sample(x = c(0, 1), n_participants, replace = TRUE)
cluster_mean_outcomes <- rnorm(mean = 25, sd = 5, n = n_clusters)
participant_outcomes <-
  cluster_mean_outcomes[cluster_ids] +
  drug * 15 +
  age
df <- data.frame(
  cluster_id = cluster_ids,
  age = age,
  drug = drug,
  outcome = participant_outcomes)
set.seed(1234)
sl_model <- super_learner(
  data = df,
  formulas = outcome ~ age + drug,
  learners = list(lnr_rf, lnr_lm, lnr_earth, lnr_glmnet),
  cluster_ids = df$cluster_id,
  verbose = TRUE
)
Behind the scenes, if cluster_ids or strata_ids are passed to super_learner(), the cv_schema argument is set to
cv_schema <- function(data, n_folds) {
  cv_origami_schema(data, n_folds = n_folds,
    fold_fun = origami::folds_vfold,
    cluster_ids = cluster_ids,
    strata_ids = strata_ids
  )
}
Let’s see explicitly how cv_origami_schema is working:
df_splits <- cv_origami_schema(
  data = df, n_folds = 5, fold_fun = origami::folds_vfold, cluster_ids = df$cluster_id)
unisort <- \(x) sort(unique(x))
unisort(df_splits$training_data[[1]]$cluster_id)
## [1] 2 3 4 5 7 8 9 12 13 14 15 16 17 18 19 20 21 23 24 25
unisort(df_splits$validation_data[[1]]$cluster_id)
## [1] 1 6 10 11 22
unisort(df_splits$training_data[[2]]$cluster_id)
## [1] 1 2 4 5 6 7 8 9 10 11 14 15 16 17 20 21 22 23 24 25
unisort(df_splits$validation_data[[2]]$cluster_id)
## [1] 3 12 13 18 19
So we can see that cv_origami_schema assigns all of each cluster to either the training or validation data in each split.
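The spot check above looks at two splits by eye. As a hedged sketch, the same check can be automated for every split. The helper shared_clusters() below is hypothetical (it is not part of any package) and assumes only that the splits object is a list with training_data and validation_data elements, each a list of data frames, as indexed above. It is demonstrated on hand-built toy splits so that it runs on its own.

```r
# Hypothetical helper: for each split, list the cluster ids found in both
# the training and the validation data (an empty result means no leakage).
shared_clusters <- function(splits, id_col) {
  Map(
    function(train, valid) intersect(train[[id_col]], valid[[id_col]]),
    splits$training_data,
    splits$validation_data
  )
}

# Hand-built toy splits with the same list-of-data-frames shape.
toy <- data.frame(cluster_id = c(1, 1, 2, 2, 3, 3), y = 1:6)
good_splits <- list(
  training_data   = list(toy[toy$cluster_id != 1, ], toy[toy$cluster_id != 2, ]),
  validation_data = list(toy[toy$cluster_id == 1, ], toy[toy$cluster_id == 2, ])
)
bad_splits <- list(
  training_data   = list(toy[c(1, 3, 5), ]),
  validation_data = list(toy[c(2, 4, 6), ])
)

lengths(shared_clusters(good_splits, "cluster_id"))  # all zero: clusters kept whole
lengths(shared_clusters(bad_splits, "cluster_id"))   # non-zero: clusters are split
```

Applied to the object created earlier, the call would be shared_clusters(df_splits, "cluster_id"), and every element should be empty if clusters are kept intact.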
In fact, we can get equivalent results if we run the following code, which shows explicitly how to pass a user-chosen origami folds_* function:
set.seed(1234)
sl_model2 <- super_learner(
  data = df,
  formulas = outcome ~ age + drug,
  learners = list(lnr_rf, lnr_lm, lnr_earth, lnr_glmnet),
  cv_schema = function(data, n_folds) {
    cv_origami_schema(data = data, n_folds = n_folds, fold_fun = origami::folds_vfold,
      cluster_ids = data$cluster_id)
  },
  verbose = TRUE
)
sl_model$learner_weights
## rf lm earth glmnet
## 0.0000000 0.6891014 0.3108986 0.0000000
sl_model2$learner_weights
## rf lm earth glmnet
## 0.0000000 0.6891014 0.3108986 0.0000000
Strata
If, on the other hand, you want to ensure that strata are always represented in the training sets for every cross-validation split of the data, you can use the strata_ids argument to super_learner().
To see how it works, we can look at:
df$strata_id <- rep(1:20, each = 5)
df_splits <- cv_origami_schema(
  data = df, n_folds = 5, fold_fun = origami::folds_vfold, strata_ids = df$strata_id)
unisort <- \(x) sort(unique(x))
unisort(df_splits$training_data[[1]]$strata_id)
## [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
unisort(df_splits$validation_data[[1]]$strata_id)
## [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
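The output above shows every stratum appearing on both sides of the first split. One way such a result can arise is sketched below in base R: fold labels are dealt out separately within each stratum, so each stratum is spread across the folds. This is an illustration under that assumption, not a description of how origami implements stratification, and the names toy_strata, strat_fold, and covers_all are hypothetical.

```r
set.seed(1)
toy_strata <- rep(1:20, each = 5)  # 20 strata of 5 rows each
n_folds <- 5

# Within each stratum, shuffle the fold labels 1..n_folds across its rows.
strat_fold <- ave(
  seq_along(toy_strata), toy_strata,
  FUN = function(idx) sample(rep_len(seq_len(n_folds), length(idx)))
)

# For each fold v, the training rows (fold != v) should cover all strata.
covers_all <- vapply(
  seq_len(n_folds),
  function(v) setequal(toy_strata[strat_fold != v], toy_strata),
  logical(1)
)
covers_all
```

Because every stratum here has exactly as many rows as there are folds, each stratum contributes one row to each validation set and the remaining four to the matching training set.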