Skip to contents

There can be good reasons to weight the observations passed into nadir::super_learner(). One example is when the outcome is a population prevalence or incidence rate: the observations should then be weighted to reflect the fact that each outcome may represent a different amount of underlying exposure-time (often called person-time in epidemiology).

We demonstrate here how to use weighted observations in super_learner(), and we check that a weighted super_learner() achieves a lower loss on the higher-weight observations than an unweighted super_learner() does.

library(nadir)
set.seed(1234)

# Generate synthetic county-level data.
#
# The columns are simulated in the order written below. That order fixes the
# sequence of random draws, so reordering columns changes the simulated data.
n_counties <- 500

df <- tibble::tibble(
  county_id = seq_len(n_counties),
  # Counties range from roughly a hundred to millions of people; we assume we
  # have observed 1 year of time for all people in the county, so person-years
  # equals population. The log-normal draw has median 1e5, and the `+ 10`
  # keeps every county's person-years (later used as a weight) positive.
  county_person_years =
    round(rlnorm(n = n_counties, meanlog = log(1e5), sdlog = 3)) + 10,
  # Mean age is drawn uniformly from 20 to 60, independently of county size.
  county_mean_age = sample(x = 20:60, n_counties, replace = TRUE),
  # Average SBP: uniform on 90 to 150, shifted by 2 units per standard
  # deviation that the county's mean age lies above the sample mean age.
  county_avg_sbp = sample(90:150, n_counties, replace = TRUE) +
    2 * (county_mean_age - mean(county_mean_age)) / sd(county_mean_age),
  # The incidence rate follows a different age/SBP trend in each of three
  # county-size brackets. Only the middle bracket carries an extra N(0, 10)
  # noise term; every county then gets a shared N(100, 50) term on top.
  county_incidence_rate =
    dplyr::case_when(
      # mid-sized counties: [1e5, 2e5) person-years
      county_person_years >= 1e5 & county_person_years < 2e5 ~
        50 + 1.5 * county_mean_age + 1.3 * county_mean_age^(4 / 3) +
          sqrt(county_avg_sbp) + rnorm(n = n_counties, mean = 0, sd = 10),
      # small counties: fewer than 1e5 person-years
      county_person_years < 1e5 ~
        25 + 1.9 * county_mean_age + 1.6 * county_mean_age^(3 / 2) +
          1.5 * sqrt(county_avg_sbp),
      # large counties: at least 2e5 person-years
      county_person_years >= 2e5 ~
        35 + 1.7 * county_mean_age + 1.6 * county_mean_age^(1 / 2) +
          1.4 * sqrt(county_avg_sbp)
    ) + rnorm(n = n_counties, mean = 100, sd = 50)
)
head(df)
#> # A tibble: 6 × 5
#>   county_id county_person_years county_mean_age county_avg_sbp
#>       <int>               <dbl>           <int>          <dbl>
#> 1         1                2685              30           88.2
#> 2         2              229867              41          100. 
#> 3         3             2587630              21          122. 
#> 4         4                  98              50          150. 
#> 5         5              362336              53          100. 
#> 6         6              456396              43          105. 
#> # ℹ 1 more variable: county_incidence_rate <dbl>

Let’s take a look at the simulated data. Notice that the trend we should pick up on differs depending on whether we focus on the large-population counties or the smaller-population counties.

library(ggplot2)

# Scatter of county incidence rate against county mean age. Point size is
# mapped to person-years through a pseudo-log transform of the size scale.
ggplot(
  data = df,
  mapping = aes(
    x = county_mean_age,
    y = county_incidence_rate,
    size = county_person_years
  )
) +
  geom_point(alpha = 0.25) +
  scale_size_continuous(
    transform = scales::pseudo_log_trans(sigma = 1e5),
    breaks = c(1e2, 1e3, 1e5, 1e6),
    labels = c("100", "1000", "100k", "1m")
  )


# Same scatter, split into one panel per county-size bracket.
# NOTE(review): cut() closes intervals on the right by default, giving
# (0, 1e5], (1e5, 2e5], (2e5, Inf], whereas the simulation brackets are closed
# on the left; a county at exactly 1e5 or 2e5 would land in the lower panel.
ggplot(
  data = df,
  mapping = aes(
    x = county_mean_age,
    y = county_incidence_rate,
    size = county_person_years
  )
) +
  geom_point(alpha = 0.25) +
  facet_grid(cut(county_person_years, c(0, 1e5, 2e5, Inf)) ~ .) +
  scale_size_continuous(
    transform = scales::pseudo_log_trans(sigma = 1e5),
    breaks = c(1e2, 1e3, 1e5, 1e6),
    labels = c("100", "1000", "100k", "1m")
  )

Now let’s fit two super learners: one unweighted, and one in which each observation is weighted by its person-time.

# Fit 1: unweighted super learner. Every county counts equally, regardless of
# how many person-years it represents.
sl_model_no_weights <- super_learner(
  data = df,
  formula = county_incidence_rate ~ county_mean_age + county_avg_sbp,
  learners = list(lnr_lm, lnr_earth, lnr_rf, lnr_xgboost, lnr_glmnet),
  verbose = TRUE
)

# Fit 2: same data, formula and candidate learners, but each county is
# weighted by its person-years. (A cruder alternative would be 0/1 indicator
# weights such as ifelse(df$county_person_years > 2e5, 1, 0).)
sl_model_with_weights <- super_learner(
  data = df,
  formula = county_incidence_rate ~ county_mean_age + county_avg_sbp,
  learners = list(lnr_lm, lnr_earth, lnr_rf, lnr_xgboost, lnr_glmnet),
  weights = df[["county_person_years"]],
  verbose = TRUE
)

# Squared prediction error of each fit on every county. These are in-sample
# errors: the models are evaluated on the same data they were fit to.
squared_errors_with_no_weights <-
  (sl_model_no_weights$sl_predictor(df) - df[["county_incidence_rate"]])^2
squared_errors_with_weights <-
  (sl_model_with_weights$sl_predictor(df) - df[["county_incidence_rate"]])^2

# Restrict attention to the larger counties (more than 200,000 person-years),
# which receive the largest weights in the weighted fit.
high_weight_observations <- which(df[["county_person_years"]] > 2e5)
length(high_weight_observations)
#> [1] 205

# Mean squared error on those counties: unweighted fit first, then weighted.
mean(squared_errors_with_no_weights[high_weight_observations])
#> [1] 63721.07
mean(squared_errors_with_weights[high_weight_observations])
#> [1] 2408.222

# We'll generate some newdata to predict on: a one-row template copied from
# the first county, with SBP held at its sample mean so that only mean age
# varies across predictions.
newdata <- df[1, ]
newdata$county_avg_sbp <- mean(df$county_avg_sbp)

# Ages to predict at; this is the range county_mean_age was sampled from.
age_grid <- 20:60

# Predict from `model` at each age in `ages`, one row at a time, using
# `template` as the row to modify.
#
# vapply() is used instead of sapply() so the result is always a numeric
# vector with exactly one prediction per age; if the predictor ever returns
# something else, this fails immediately rather than silently producing a list
# or matrix.
predict_over_ages <- function(model, ages, template) {
  vapply(
    ages,
    function(age) {
      template$county_mean_age <- age
      model$sl_predictor(template)
    },
    FUN.VALUE = numeric(1)
  )
}

no_weights_predictions <-
  predict_over_ages(sl_model_no_weights, age_grid, newdata)

with_weights_predictions <-
  predict_over_ages(sl_model_with_weights, age_grid, newdata)

# Overlay each super learner's predicted incidence rate across ages 20 to 60
# on the county-level scatter.
#
# `size` is mapped inside geom_point() rather than in the global aes(), so the
# prediction lines never inherit it. Their data frames have no
# county_person_years column, and masking the inherited mapping by passing
# `size = NULL` as a geom_line() parameter appears to be what triggered
# ggplot2's "Using `size` aesthetic for lines was deprecated" warning.
ggplot(df, aes(x = county_mean_age, y = county_incidence_rate)) +
  geom_point(aes(size = county_person_years), alpha = 0.25) +
  scale_size_continuous(
    transform = scales::pseudo_log_trans(sigma = 1e5),
    breaks = c(100, 1000, 1e5, 1e6),
    labels = c("100", "1000", "100k", "1m")
  ) +
  geom_line(
    data = data.frame(
      county_mean_age = 20:60,
      prediction = no_weights_predictions
    ),
    mapping = aes(y = prediction, color = "non-weighted"),
    linewidth = 1
  ) +
  geom_line(
    data = data.frame(
      county_mean_age = 20:60,
      prediction = with_weights_predictions
    ),
    mapping = aes(y = prediction, color = "weighted"),
    linewidth = 1
  ) +
  scale_color_manual(
    values = c("weighted" = "#c54a41", "non-weighted" = "#2193b3")
  ) +
  theme_bw() +
  labs(
    color = "super_learner() specification",
    x = "County Avg. Age",
    y = "County Incidence Rate per 100k Person-Years"
  ) +
  ggtitle("Observation weights change the influence of observations on super_learner() predictions")