Supplement to Applying Machine Learning in Predicting Medication Treatment Outcomes for Opioid Use Disorder

Author

Raymond R. Balise, Kyle Grealis, Guertson Jean-Baptiste, and the CTN-0094 Team

Published

September 2, 2025

What is this?

This is a companion to the paper “Understanding the Utility of Machine Learning for Predicting Medication Assisted Treatment Outcomes for Opioid Use Disorder,” which is under review for publication. That paper summarizes the results of models that attempt to predict treatment failure among people receiving treatment for Opioid Use Disorder (OUD). Here you can learn about the machine learning (ML) methods we used and how to apply them. We begin by describing the workflow and methods, then explore the differences between ML and traditional statistics, and finally explain the code. Our goal is to provide a friendly introduction to the ML topics and to make it easy to start learning the code.

Introduction

Before we discuss the models themselves, let’s consider the project workflow, which starts with the participant data and ends with a set of interpretable ML models.

You can download the complete analysis workflow for the project from GitHub by using this hyperlink, which will be live after the paper is accepted. The project’s GitHub page lists the software you need to install. If you have not worked with GitHub before, you can study the code on this page as you read and learn about GitHub later. Everyone’s favorite GitHub resource for the R world is Happy Git and GitHub for the useR.

In the files saved to GitHub, there is an R file called do_everything.R that performs the entire modeling workflow. A copy of it appears below.

DON’T PANIC if you don’t speak R; it’s included as a reference for those familiar with the language. If you’re a novice, just note the following:

  • Lines that begin with #' are technical explanations.
  • Lines that begin with # are human-readable comments.

We’ll discuss the details below, but the key point is this:

There is one code file that can fully reproduce the results in the paper.

Because the project is complex, and some parts take hours to run, the do_everything.R script tells R to execute other code files (with names ending in .R) that contain the modeling instructions for the different machine learning algorithms.

By splitting the code into many smaller .R files, we were able to test each chunk independently without having to rerun the entire workflow.

#' do_everything.R
#'
#' Main script for running all modeling analyses
#' This script coordinates the entire modeling pipeline from data preparation
#' through model fitting and evaluation.

# ---- Helper Functions ----
#' Wrapper function for \code{source} to include processing time
#' @param script_path Machine learning modeling script
source_wrapper <- function(script_path) {
  message(glue::glue('Starting {script_path}...'))
  start_time <- Sys.time()
  source(script_path)
  process_time <- hms::as_hms(round(Sys.time() - start_time, 2))
  message(glue::glue('{script_path} ran in {process_time}\n\n'))
}

#' Function to process modeling scripts
#' @param model_script R script for specific machine learning model.
#' @param run_parallel Boolean value to use \code{doParallel} package
run_models <- function(model_script, run_parallel = FALSE) {
  message(glue::glue(
    "\nTime now: {hms::round_hms(hms::as_hms(Sys.time()), digits=0)}..."
  ))
  if (run_parallel) {
    doParallel::registerDoParallel(cores = parallel::detectCores() - 2)
  }
  source_wrapper(model_script)
  if (run_parallel) doParallel::stopImplicitCluster()
}

# ---- Initialize Settings ----
options(digits = 8)
set.seed(305)

# ---- Initialize Timing ----
beginning <- hms::round_hms(hms::as_hms(Sys.time()), digits=0)
message(sprintf('Start time: %s', beginning))

# ---- Data Preparation Phase ----
# These scripts handle library loading, data importing, and initial preprocessing
# To use latent class variables set the data use subset and subset_not_used
# in preprocess_recipe.R
setup_scripts <- c(
  "libraries.R",
  "code_with_notes.R",
  "load.R",
  "preprocess_recipe.R"
)
setup_scripts |>
  purrr::walk(source_wrapper)

# ---- Recipe Processing ----
# Prepare and bake the recipe for model training
prepped_recipe <- prep(a_recipe, training = a_train)
baked_recipe <- bake(prepped_recipe, new_data = a_train)
print(glue::glue(
  "Dimensions after modified recipe: {dim(baked_recipe)[1]} rows x {dim(baked_recipe)[2]} columns\n\n"
))

# ---- Model Fitting Phase ----
# Neural Network must be run without parallel processing on Mac systems
run_models("nnet.R", run_parallel = FALSE)

# Define and run models that can utilize parallel processing
models <- c(
  "logistic.R",
  "logistic_via_lasso.R",
  "lasso.R",
  "knn.R",
  "mars.R",
  "cart.R",
  "rf.R",
  "xgboost.R",
  "bart.R",
  "svm.R"
)
models |> purrr::walk(run_models, run_parallel = TRUE)

# ---- Results Processing ----
# Summarize results across all models
run_models("summarize.R", run_parallel = TRUE)

# Display total modeling time
message(glue::glue(
  'Total modeling time: {hms::round_hms(hms::as_hms(Sys.time() - beginning), 60)}'
))

# ---- Results Output ----
# Display primary performance metrics
best_ROC
best_sens_spec

# Set up results directory and save outputs
# Change iteration_name if using non-default recipe or want to date results
# iteration_name <- "default_recipe"
iteration_name <- "April2025"

# Create directory for results if it doesn't exist
if (!dir.exists(glue::glue("{here::here()}/data/{iteration_name}/"))) {
  dir.create(glue::glue("{here::here()}/data/{iteration_name}/"))
}

# Save performance metrics
save(
  best_ROC,
  file = glue::glue("data/{iteration_name}/best_ROC.RData")
)
save(
  best_sens_spec,
  file = glue::glue("data/{iteration_name}/best_sens_spec.RData")
)

# ---- Variable Importance Calculations ----
# WARNING: THIS SECTION IS COMPUTATIONALLY EXPENSIVE!
# Please close unnecessary programs and be advised that this will take 
# multiple hours to complete. Consider processing overnight or when 
# computer resources aren't immediately needed.
if (!require('DALEX')) install.packages('DALEX')
DALEX::install_dependencies()
source_wrapper("vip_calcs.R")

# ---- Final Timing ----
message(glue::glue(
  'Total processing time: {hms::round_hms(hms::as_hms(Sys.time() - beginning), 60)}'
))
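The source_wrapper() helper above reports elapsed time through the hms package. As a minimal sketch, assuming only base R, the same wrapper can be written with difftime(), which returns an elapsed duration with explicit units. The function name time_source() and the temporary demo script are hypothetical and exist only for illustration.

```r
# Hypothetical base-R variant of source_wrapper(): source a script and report
# how long it took, using difftime() for the elapsed duration.
time_source <- function(script_path) {
  message("Starting ", script_path, "...")
  start_time <- Sys.time()
  source(script_path)  # like source_wrapper(), evaluates in the global environment
  elapsed <- difftime(Sys.time(), start_time, units = "secs")
  message(script_path, " ran in ", format(round(elapsed, 2)))
  invisible(elapsed)
}

# Usage: time a throwaway script written to a temporary file
demo_script <- tempfile(fileext = ".R")
writeLines("demo_total <- sum(1:10)", demo_script)
elapsed <- time_source(demo_script)
```

Because each script is sourced into the global environment, objects it creates (here demo_total) remain available to the scripts that run after it, which is what lets do_everything.R chain the setup and modeling files.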

The do_everything.R code takes several hours to run on a modern workstation (e.g., a 2023 M2 Ultra Mac Studio with 192 GB of RAM), so you probably want to run the workflow overnight.

The sources of the participants and predictor variables are described in general terms in the paper. Below, before we introduce the modeling concepts and procedures, we provide more details so you can familiarize yourself with the participants and explore the data.

The individual variables are documented in the public.ctn0094data R package described here. If you run the steps to load the R software libraries (libraries.R) and load the data (load.R) described in the workflow below, you will have access to the analysis dataset.

The Workflow

When you view the do_everything.R file, you will notice that it is designed as a wrapper: it defines project-specific helper functions and then executes several other R programs.

Data preprocessing begins by loading the required R packages with the libraries.R file:

# For installing keras to run NNet:
# reticulate::py_install('keras')
# install.packages('keras')

packages <- c(
  "conflicted",
  "CTNote",
  "doParallel",
  "ggthemes",                   # for theme_few()
  "janitor",                    # for tabyl()
  "patchwork",
  "public.ctn0094data",
  "public.ctn0094extra",
  "tidymodels",
  "tidyverse",

  # models
  "AppliedPredictiveModeling",  # for NNet
  "brulee",                     # for NNet
  "dbarts",                     # for BART
  "earth",                      # for MARS
  "glmnet",                     # for LASSO
  "keras",                      # for NNet
  "kernlab",                    # for SVM
  "kknn",                       # for KNN
  "rpart",                      # for CART
  "rpart.plot",                 # for CART
  "randomForest",
  "themis",                     # for upsampling
  "torch",                      # for NNet
  "vip",                        # for model specific variable importance
  "xgboost"                     # for xgboost
)

# Loop through the list of packages and load each one quietly
for (pkg in packages) {
  suppressPackageStartupMessages(library(pkg, character.only = TRUE))
}

tidymodels_prefer()

options(dplyr.summarise.inform = FALSE)
conflicted::conflict_prefer("vi", "vip", quiet = TRUE)
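The loop in libraries.R only loads packages, so it assumes they are already installed. A minimal sketch, assuming that installing from your default CRAN mirror is acceptable, checks for missing packages first, installs them, and then loads everything quietly. The helper name load_packages() is hypothetical.

```r
# Hypothetical helper: install any missing packages, then attach them quietly.
load_packages <- function(packages) {
  # requireNamespace() returns FALSE (rather than erroring) when a package is absent
  is_installed <- vapply(packages, requireNamespace, logical(1), quietly = TRUE)
  missing <- packages[!is_installed]
  if (length(missing) > 0) {
    utils::install.packages(missing)
  }
  for (pkg in packages) {
    suppressPackageStartupMessages(library(pkg, character.only = TRUE))
  }
  invisible(missing)  # report which packages had to be installed
}

# Usage with base packages that ship with every R installation
newly_installed <- load_packages(c("stats", "utils", "tools"))
```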

The modeling dataset, which is called analysis, is created with a substantial amount of tidyverse code that relies heavily on the dplyr package. If you are not familiar with the tidyverse dialect of R, follow the advice found here: https://www.tidyverse.org/learn/. Many details of the data preparation are discussed below.

The code itself is in the load.R file, which creates the analysis dataset. We left in some of the code for alternative models we tried and for conducting a sensitivity analysis:

# run libraries.R first

subjects <-
  public.ctn0094data::randomization |>
  filter(which == 1) |>
  inner_join(everybody, by = "who") |>
  rename(rand_date = when) |>
  select(project, who, treatment, rand_date) |>
  mutate(
    medication =
      case_when(
        treatment == "Inpatient BUP" ~ "Buprenorphine",
        treatment == "Inpatient NR-NTX" ~ "Naltrexone",
        treatment == "Methadone" ~ "Methadone",
        treatment == "Outpatient BUP" ~ "Buprenorphine",
        treatment == "Outpatient BUP + EMM" ~ "Buprenorphine",
        treatment == "Outpatient BUP + SMM" ~ "Buprenorphine"
      ),
    medication = factor(medication),
    in_out = case_when(
      treatment == "Inpatient BUP" ~ "Inpatient",
      treatment == "Inpatient NR-NTX" ~ "Inpatient",
      treatment == "Methadone" ~ "Outpatient",
      treatment == "Outpatient BUP" ~ "Outpatient",
      treatment == "Outpatient BUP + EMM" ~ "Outpatient",
      treatment == "Outpatient BUP + SMM" ~ "Outpatient"
    ),
    in_out = factor(in_out)
  )

# asi -----
# I should use the indicator of who was injecting drugs: `used_iv`.
analysis <-
  left_join(subjects, public.ctn0094data::asi, by = "who")
rm(subjects)

# demographics -----
# Use all demographic features
analysis <-
  left_join(analysis, public.ctn0094data::demographics, by = "who") |>
  mutate(
    # Recode "Refused/Missing" to NA first; replacing "/" first would turn it
    # into "Refused or Missing", which this check would not match
    education = if_else(
      education == "Refused/Missing", NA_character_, as.character(education)
    ),
    education =
      str_replace_all(education, stringr::fixed("/"), " or ") |>
        factor()
  ) |>
  mutate(
    job = if_else(
      job == "Refused/Missing", NA_character_, as.character(job)
    ),
    job = factor(job)
  ) |>
  mutate(
    marital = if_else(
      marital == "Refused/Missing", NA_character_, as.character(marital)
    ),
    marital = factor(marital)
  )

# fagerstrom -----
# Use all smoking features
analysis <-
  left_join(analysis, public.ctn0094data::fagerstrom, by = "who") |>
  mutate(
    per_day =
      str_replace_all(per_day, stringr::fixed("-"), " TO "),
    per_day = if_else(per_day == "", "0", per_day),
    per_day = fct(
      per_day,
      levels = c("0", "10 OR LESS", "11 TO 20", "21 TO 30", "31 OR MORE")
    ),
    per_day = ordered(per_day)
  )

# pain -----
# select pain closest to randomization
pain <-
  public.ctn0094data::pain |>
  group_by(who) |>
  mutate(absolute = abs(when)) |>
  arrange(absolute) |>
  filter(row_number() == 1) |>
  filter(absolute <= 28) |>
  select(who, pain)

analysis <-
  left_join(analysis, pain, by = "who")
rm(pain)

# psychiatric -----
# Include psych conditions, brain damage, epilepsy, and drug use diagnoses

# self report or MD report is called yes, any no is next, otherwise unknown
either <- function(x1, x2) {
  answer <-
    case_when(
      x1 == "Yes" | x2 == "Yes" ~ "Yes",
      x1 == "No" & x2 == "No" ~ "No",
      TRUE ~ "Unknown"
    )
  factor(answer)
}

psychiatric <-
  public.ctn0094data::psychiatric |>
  mutate(any_schiz = either(has_schizophrenia, schizophrenia)) |>
  mutate(any_dep = either(has_major_dep, depression)) |>
  mutate(any_anx = either(has_anx_pan, anxiety)) |>
  select(
    c(
      who,
      any_schiz, any_dep, any_anx, has_bipolar, has_brain_damage, has_epilepsy,
      # has_opiates_dx, useless for modeling
      has_alcol_dx, has_amphetamines_dx, has_cannabis_dx,
      has_cocaine_dx, has_sedatives_dx
    )
  )

analysis <-
  left_join(analysis, psychiatric, by = "who")
rm(either, psychiatric)

# qol ----
# homelessness
analysis <-
  left_join(analysis, public.ctn0094data::qol, by = "who")

# rbs ----
# rbs use info
rbs <-
  public.ctn0094data::rbs |>
  mutate(
    days =
      if_else(
        is.na(days) & did_use == "No", 0, days
      )
  ) |>
  pivot_wider(names_from = "what", values_from = c("did_use", days))

analysis <-
  left_join(analysis, rbs, by = "who")
rm(rbs)

# rbs_iv -----
# amount of IV drug use and needle sharing
rbs_iv <-
  public.ctn0094data::rbs_iv |>
  rename(
    days_iv_use = days,
    max_iv_use = max
  ) |>
  select(
    who, days_iv_use, max_iv_use, amount, shared, cocaine_inject_days,
    heroin_inject_days, speedball_inject_days, opioid_inject_days,
    speed_inject_days
  )

analysis <-
  left_join(analysis, rbs_iv, by = "who")
rm(rbs_iv)

# sex -----
# total partners
sex <-
  public.ctn0094data::sex |>
  rename(sex_partners = t_p) |>
  select(who, sex_partners)

analysis <-
  left_join(analysis, sex, by = "who")
rm(sex)

# site -----
analysis <-
  # site_masked from code_with_notes.R
  left_join(analysis, site_masked, by = "who")

# tlfb -----
# how many days of use and what drugs
tlfb <-
  public.ctn0094data::tlfb |>
  group_by(who) |>
  filter(when > -29 & when < 0)

tlfb_days_of_use <-
  tlfb |>
  select(who, when) |>
  distinct() |>
  summarise(tlfb_days_of_use_n = n())

tlfb_what_used <-
  tlfb |>
  select(who, what) |>
  distinct() |>
  summarise(tlfb_what_used_n = n())

analysis <-
  left_join(analysis, tlfb_days_of_use, by = "who") |>
  left_join(tlfb_what_used, by = "who")
rm(tlfb, tlfb_days_of_use, tlfb_what_used)

# withdrawal -----
# select withdrawal closest to randomization
withdrawal <-
  public.ctn0094data::withdrawal |>
  group_by(who) |>
  mutate(absolute = abs(when)) |>
  arrange(absolute) |>
  filter(row_number() == 1) |>
  filter(absolute <= 28) |>
  select(who, withdrawal)

analysis <-
  left_join(analysis, withdrawal, by = "who")
rm(withdrawal)

# detox -----
# number of days between the first and last detox records
detox <-
  public.ctn0094data::detox |>
  arrange(who, when) |>
  group_by(who) |>
  summarise(detox_days = last(when) - first(when)) |>
  ungroup()

analysis <-
  left_join(analysis, detox, by = "who") |>
  mutate(
    # participants with no detox record get tiny random jitter instead of NA
    detox_days = 
      if_else(
        is.na(detox_days), 
        runif(nrow(analysis), min = -0.001, max = 0.001), 
        detox_days
      )
  )
# from sensitivity analysis - do not delete
# mutate(
#  detox_days = 
#    if_else(is.na(detox_days), rpois(nrow(analysis), lambda = 5), detox_days))

# latent class data from Dr. Pan
latent_class <- 
  read_rds("data/public_polysubstane_group.rds")

almost_analysis <-
  analysis |>
  mutate(
    x = case_when(
      project == "27" ~ "CTN-0027",
      project == "30" ~ "CTN-0030",
      project == "51" ~ "CTN-0051"
    )
  ) |>
  mutate(`CTN Trial Number` = factor(x)) |>
  select(-x) |>
  inner_join(latent_class, by = "who")

# these are Ray's variables, with or without latent class
analysis <- 
  almost_analysis |>
  rename(trial = `CTN Trial Number`) |>
  select(
    who,
    # basics
    trial, medication, in_out,
    # asi
    used_iv,
    # demographics
    age, race, is_hispanic, job, is_living_stable, education, marital, is_male,
    # fagerstrom
    is_smoker, per_day, ftnd,
    # pain
    pain,
    # psychiatric
    any_schiz, any_dep, any_anx, has_bipolar, has_brain_damage, has_epilepsy,
    has_alcol_dx, has_amphetamines_dx, has_cannabis_dx,
    has_cocaine_dx, has_sedatives_dx,
    # qol
    is_homeless,
    # rbs
    did_use_cocaine, did_use_heroin, did_use_speedball,
    did_use_opioid, did_use_speed, days_cocaine, days_heroin,
    days_speedball, days_opioid, days_speed,
    # rbs_iv
    days_iv_use, shared,
    # not in Kyle's variables; medication is also extra (he had who and treatment)
    # max_iv_use,  amount, cocaine_inject_days,
    # heroin_inject_days, speedball_inject_days, opioid_inject_days,
    # speed_inject_days,
    # sex
    # sex_partners, # see if DALEX will work if this is dropped
    # tlfb
    tlfb_days_of_use_n, tlfb_what_used_n,
    # withdrawal
    withdrawal,
    # detox
    detox_days,
    # latent class
    # group, # don't use latent class groups
    # site
    site_masked
  )

outcome <- 
  public.ctn0094extra::derived_weeklyOpioidPattern |>
  mutate(use_pattern_uds = paste0(Phase_1, Phase_2)) |>
  rowwise() |>
  mutate(
    udsPattern = recode_missing_visits(
      use_pattern = use_pattern_uds
    )
  ) |>
  mutate(
    udsPattern = recode_missing_visits(
      use_pattern = udsPattern,
      missing_is = "*"
    )
  ) |>
  mutate(
    udsPatternTrimmed = str_sub(udsPattern, start = 3L)
  ) |>
  rowwise() |>
  mutate(
    lee2018_rel = detect_in_window(
      use_pattern = udsPatternTrimmed,
      window_width = 4L,
      threshold = 4L
    )
  ) |>
  unnest(cols = "lee2018_rel", names_sep = "_") |>
  mutate(lee2018_rel_time = lee2018_rel_time + 2) |>
  select(who, starts_with("lee2018_rel")) |>
  rename(
    RsT_ctnFiftyOne_2018 = lee2018_rel_time,
    RsE_ctnFiftyOne_2018 = lee2018_rel_event
  ) |>
  transmute(
    who,
    did_relapse = factor(RsE_ctnFiftyOne_2018, levels = c("0", "1"))
  )

analysis <- 
  inner_join(analysis, outcome, by = "who") |>
  select(-who)

write_rds(analysis, "data/analysis.rds")

final_analysis <- 
  analysis
write_rds(final_analysis, "data/final_analysis.rds")

####

# set to subset or subset_not_used
# uses latent class variables instead of history

# use this version to use latent class instead of my preprocessing
analysis2 <- 
  almost_analysis |>
  select(
    -contains("did_use"), -contains("tlfb"), -contains("_iv"),
    -c(
      days_cocaine, days_heroin, days_speedball, days_opioid, days_speed,
      days_iv_use, cocaine_inject_days, heroin_inject_days,
      speedball_inject_days, opioid_inject_days, speed_inject_days, 
      tlfb_days_of_use_n
    )
  )
write_rds(analysis2, "data/analysis2.rds")

alternative_analysis <- 
  analysis2 |>
  rename(trial = `CTN Trial Number`) |>
  select(
    # who, # removed because DALEXtra::explain_tidymodels uses who as predictor
    # basics
    trial, treatment, in_out,
    # asi
    # used_iv,
    # demographics
    age, race, is_hispanic, job, is_living_stable, education, marital, is_male,
    # fagerstrom
    is_smoker, per_day, ftnd,
    # pain
    pain,
    # psychiatric
    any_schiz, any_dep, any_anx, has_bipolar, has_brain_damage, has_epilepsy,
    has_alcol_dx, has_amphetamines_dx, has_cannabis_dx,
    has_cocaine_dx, has_sedatives_dx,
    # qol
    is_homeless,
    # rbs
    # did_use_cocaine, did_use_heroin, did_use_speedball,
    # did_use_opioid, did_use_speed, days_cocaine, days_heroin,
    # days_speedball, days_opioid, days_speed,
    # rbs_iv
    # days_iv_use, shared,
    # sex
    sex_partners,
    # tlfb
    # tlfb_days_of_use_n, tlfb_what_used_n,
    # withdrawal
    withdrawal,
    # detox
    detox_days,
    # latent class
    group
  )

check_exclusion <- 
  analysis |>
  select(
    contains("days"), contains("did_use"), contains("tlfb"), contains("_iv")
  )

suppressMessages(conflict_prefer("col_factor", "readr"))
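The outcome block in load.R calls CTNote::detect_in_window() with window_width = 4L and threshold = 4L on the weekly urine drug screen pattern, after dropping the first two weeks and then adding 2 back to the event time. Our reading, which is an assumption, is that relapse is flagged at the first four-week window in which all four weeks are positive. The base-R function below is a hypothetical sketch of that rule, not CTNote's code: it assumes each character is one week, "+" marks a positive screen (missing visits already recoded), the returned time is the week on which the first qualifying window starts, and a participant without such a window is censored at the last week.

```r
# Hypothetical sketch of a "threshold positives in a window of width weeks" rule.
# Assumes each character of use_pattern is one week and "+" means positive.
first_window_event <- function(use_pattern, window_width = 4L, threshold = 4L) {
  weeks <- strsplit(use_pattern, "")[[1]]
  positive <- weeks == "+"
  n_weeks <- length(positive)
  if (n_weeks >= window_width) {
    for (start in seq_len(n_weeks - window_width + 1L)) {
      window <- positive[start:(start + window_width - 1L)]
      if (sum(window) >= threshold) {
        return(list(time = start, event = 1L))
      }
    }
  }
  # Assumption: no qualifying window means censoring at the last observed week
  list(time = n_weeks, event = 0L)
}

# Drop the first two weeks, as load.R does with str_sub(start = 3L),
# then shift the time by 2 so it refers to the original week numbering.
pattern <- "++--++++--"
trimmed <- substr(pattern, 3L, nchar(pattern))
result <- first_window_event(trimmed)
result$time <- result$time + 2L
```

Here the first two positive weeks are discarded by the trim, so the first qualifying run starts at week 5 of the original pattern.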

The modeling and the model-specific data preprocessing are performed using the tidymodels ecosystem of R packages. The best places to learn about tidymodels are https://www.tidymodels.org/ and https://www.tmwr.org/. The concepts, pieces, and relevant details we use are explained below. The model-specific preprocessing is in the preprocess_recipe.R file:

set.seed(305)

a_split <- initial_split(analysis, strata = "did_relapse")

a_train <- training(a_split)
save(a_train, file = "data/a_train.RData")

a_test <- testing(a_split)
save(a_test, file = "data/a_test.RData")

a_fold <- vfold_cv(a_train, v = 5)

a_recipe <-
  recipe(formula = did_relapse ~ ., data = a_train) |>
  # no update_role because DALEXtra::explain_tidymodels uses who as a predictor
  # update_role(who, new_role = "id variable") |>
  step_nzv(all_predictors()) |>
  # Needed for KNN see https://github.com/tidymodels/recipes/issues/926
  step_string2factor(all_nominal_predictors()) |>
  step_impute_knn(all_predictors()) |>
  step_dummy(all_nominal_predictors()) |>
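  # Note: step_dummy() has already turned every nominal predictor into numeric
  # indicator columns, so step_other() has no nominal predictors left to pool
  step_other(all_nominal_predictors()) |>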
  step_corr(all_numeric_predictors()) |>
  step_normalize(all_numeric_predictors())
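In preprocess_recipe.R, initial_split() with strata = "did_relapse" holds out a testing set (one quarter of the rows by default, as we understand rsample) while keeping the share of relapses similar in both parts, and vfold_cv(v = 5) then divides the training set into five folds for cross-validation. The sketch below is a simplified base-R illustration of those two ideas on simulated data; it is not rsample's implementation, and the data and object names are hypothetical.

```r
# Simplified, hypothetical illustration of a stratified 75/25 split followed by
# 5-fold cross-validation assignment on simulated data (not rsample's code).
set.seed(305)
toy <- data.frame(
  did_relapse = factor(sample(c("0", "1"), 200, replace = TRUE, prob = c(0.4, 0.6))),
  age = round(runif(200, 18, 65))
)

# Stratify: sample 75% of the rows within each outcome level separately
train_rows <- unlist(lapply(
  split(seq_len(nrow(toy)), toy$did_relapse),
  function(rows) rows[sample.int(length(rows), size = round(0.75 * length(rows)))]
))
toy_train <- toy[train_rows, ]
toy_test <- toy[-train_rows, ]

# Assign each training row to one of 5 folds of (nearly) equal size
fold_id <- sample(rep(1:5, length.out = nrow(toy_train)))
table(fold_id)

# Relapse proportions should be close in the training and testing sets
prop.table(table(toy_train$did_relapse))
prop.table(table(toy_test$did_relapse))
```

Each fold serves once as the assessment set while the other four are used for fitting, which is how the cross-validation estimates in summarize.R are produced without touching the testing set.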

We’ve organized the code for each model’s workflow in its own file. For example, the CART model is in the cart.R file. We explain that code later, after first introducing key ML concepts.

After completing the modeling, we summarize the results using an R script called summarize.R:

# load libraries ----------------------------------------------------------
# library(conflicted)
# suppressPackageStartupMessages(library(tidymodels))
# conflicted::conflict_prefer("filter", "dplyr")


# import comparison objects -----------------------------------------------
models <- 
  list(
    "knn", "logistic", "logistic_via_lasso", "lasso",
    "mars", "cart", "rf", "xgb", "bart", "svm", "nnet"
  )
for (model in models) {
  # metrics data
  metrics_data_file <- glue::glue("data/{model}_metrics.RData")
  data <- rio::import(metrics_data_file, trust = TRUE)
  assign(glue::glue("{model}_metrics"), data)

  # resample data
  resample_data_file <- glue::glue("data/{model}_resample_best.RData")
  data <- rio::import(resample_data_file, trust = TRUE)
  assign(glue::glue("{model}_resample_best"), data)

  # metrics TEST data
  metrics_test_data_file <- glue::glue("data/{model}_metrics_test.RData")
  data <- rio::import(metrics_test_data_file, trust = TRUE)
  assign(glue::glue("{model}_metrics_test"), data)
}

# do the comparison -------------------------------------------------------

best_full_train <-
  bind_rows(
    knn_metrics |>
      filter(.metric == "roc_auc") |>
      mutate(Model = "KNN") |>
      select(Model, .estimate),
    logistic_via_lasso_metrics |>
      filter(.metric == "roc_auc") |>
      mutate(Model = "Logistic LASSO") |>
      select(Model, .estimate),
    logistic_metrics |>
      filter(.metric == "roc_auc") |>
      mutate(Model = "Logistic") |>
      select(Model, .estimate),
    lasso_metrics |>
      filter(.metric == "roc_auc") |>
      mutate(Model = "LASSO") |>
      select(Model, .estimate),
    mars_metrics |>
      filter(.metric == "roc_auc") |>
      mutate(Model = "MARS") |>
      select(Model, .estimate),
    cart_metrics |>
      filter(.metric == "roc_auc") |>
      mutate(Model = "CART") |>
      select(Model, .estimate),
    rf_metrics |>
      filter(.metric == "roc_auc") |>
      mutate(Model = "Random Forest") |>
      select(Model, .estimate),
    bart_metrics |>
      filter(.metric == "roc_auc") |>
      mutate(Model = "BART") |>
      select(Model, .estimate),
    xgb_metrics |>
      filter(.metric == "roc_auc") |>
      mutate(Model = "Boosted Trees") |>
      select(Model, .estimate),
    svm_metrics |>
      filter(.metric == "roc_auc") |>
      mutate(Model = "Support Vector") |>
      select(Model, .estimate),
    nnet_metrics |>
      filter(.metric == "roc_auc") |>
      mutate(Model = "Neural Net") |>
      select(Model, .estimate)
  ) |>
  arrange(desc(.estimate)) |>
  rename(`Full Training Dataset` = .estimate)

best_cv <-
  bind_rows(
    knn_resample_best |>
      slice_head(n = 1) |>
      mutate(Model = "KNN") |>
      select(Model, mean),
    logistic_via_lasso_resample_best |>
      slice_head(n = 1) |>
      mutate(Model = "Logistic LASSO") |>
      select(Model, mean),
    logistic_resample_best |>
      slice_head(n = 1) |>
      mutate(Model = "Logistic") |>
      select(Model, mean),
    lasso_resample_best |>
      slice_head(n = 1) |>
      mutate(Model = "LASSO") |>
      select(Model, mean),
    mars_resample_best |>
      slice_head(n = 1) |>
      mutate(Model = "MARS") |>
      select(Model, mean),
    cart_resample_best |>
      slice_head(n = 1) |>
      mutate(Model = "CART") |>
      select(Model, mean),
    rf_resample_best |>
      slice_head(n = 1) |>
      mutate(Model = "Random Forest") |>
      select(Model, mean),
    bart_resample_best |>
      slice_head(n = 1) |>
      mutate(Model = "BART") |>
      select(Model, mean),
    xgb_resample_best |>
      slice_head(n = 1) |>
      mutate(Model = "Boosted Trees") |>
      select(Model, mean),
    svm_resample_best |>
      slice_head(n = 1) |>
      mutate(Model = "Support Vector") |>
      select(Model, mean),
    nnet_resample_best |>
      slice_head(n = 1) |>
      mutate(Model = "Neural Net") |>
      select(Model, mean)
  ) |>
  arrange(desc(mean)) |>
  rename(`Cross Validation` = mean) |>
  mutate(`Cross Validation` = round(`Cross Validation`, 4))

testing_results <-
  bind_rows(
    knn_metrics_test |> transmute(Model = "KNN", `Testing Dataset` = .estimate),
    logistic_via_lasso_metrics_test |> 
      transmute(Model = "Logistic LASSO", `Testing Dataset` = .estimate),
    logistic_metrics_test |> 
      transmute(Model = "Logistic", `Testing Dataset` = .estimate),
    lasso_metrics_test |> 
      transmute(Model = "LASSO", `Testing Dataset` = .estimate),
    mars_metrics_test |> 
      transmute(Model = "MARS", `Testing Dataset` = .estimate),
    cart_metrics_test |> 
      transmute(Model = "CART", `Testing Dataset` = .estimate),
    rf_metrics_test |> 
      transmute(Model = "Random Forest", `Testing Dataset` = .estimate),
    bart_metrics_test |> 
      transmute(Model = "BART", `Testing Dataset` = .estimate),
    xgb_metrics_test |> 
      transmute(Model = "Boosted Trees", `Testing Dataset` = .estimate),
    svm_metrics_test |> 
      transmute(Model = "Support Vector", `Testing Dataset` = .estimate),
    nnet_metrics_test |> 
      transmute(Model = "Neural Net", `Testing Dataset` = .estimate)
  )

best_ROC <- 
  inner_join(best_cv, best_full_train, by = join_by(Model)) |>
  left_join(
    testing_results,
    by = join_by(Model)
  ) |>
  mutate(across(3:4, ~ round(.x, digits = 3))) |>
  # backticks refer to the column; desc("Cross Validation") sorts by a constant
  arrange(desc(`Cross Validation`))


# ----------------------- calculate accuracy on train & test ------------------
# Function to calculate sensitivity, specificity, accuracy, and F1
calculate_model_metrics <- function(model_fit, data) {
  augmented_data <- augment(model_fit, new_data = data)

  sens <- 
    augmented_data |>
    sens(truth = did_relapse, estimate = .pred_class) |>
    pull(.estimate)

  spec <-
    augmented_data |>
    spec(truth = did_relapse, estimate = .pred_class) |>
    pull(.estimate)

  accuracy <-
    augmented_data |>
    accuracy(truth = did_relapse, estimate = .pred_class) |>
    pull(.estimate)

  f1 <- 
    augmented_data |>
    f_meas(truth = did_relapse, estimate = .pred_class) |>
    pull(.estimate)

  return(c(sens, spec, accuracy, f1))
}

# Initialize empty lists to store results
train_results <- list()
test_results <- list()

# Initialize an empty tibble to store results
results_table <-
  tibble(
    Model = character(),
    `Train Sensitivity` = numeric(),
    `Train Specificity` = numeric(),
    `Train Accuracy` = numeric(),
    `Train F1` = numeric(),
    `Test Sensitivity` = numeric(),
    `Test Specificity` = numeric(),
    `Test Accuracy` = numeric(),
    `Test F1` = numeric()
  )

# Loop through models
for (model in models) {
  # Load model fit
  metrics_data_file <- glue::glue("data/{model}_final_fit.RData")
  model_fit <- rio::import(metrics_data_file, trust = TRUE)

  # Calculate metrics for train and test sets
  train_metrics <- calculate_model_metrics(model_fit, a_train)
  test_metrics <- calculate_model_metrics(model_fit, a_test)

  # Store results
  results_table <-
    results_table |>
    add_row(
      Model = model,
      `Train Sensitivity` = train_metrics[1],
      `Train Specificity` = train_metrics[2],
      `Train Accuracy` = train_metrics[3],
      `Train F1` = train_metrics[4],
      `Test Sensitivity` = test_metrics[1],
      `Test Specificity` = test_metrics[2],
      `Test Accuracy` = test_metrics[3],
      `Test F1` = test_metrics[4]
    )
}

best_sens_spec <-
  results_table |>
  arrange(desc(`Train F1`))
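calculate_model_metrics() gets its four numbers from yardstick. As a hedged illustration of what they measure, the base-R sketch below computes sensitivity, specificity, accuracy, and F1 from a two-by-two confusion matrix using made-up predictions. One detail matters: yardstick treats the first factor level as the event by default (event_level = "first"), and did_relapse has the levels "0" then "1", so unless an event level is set elsewhere, the reported sensitivity refers to correctly predicting no relapse. The sketch makes the positive class an explicit argument; class_metrics() is a hypothetical name.

```r
# Hypothetical illustration: classification metrics from a confusion matrix,
# with the positive ("event") class given explicitly.
class_metrics <- function(truth, estimate, positive) {
  tp <- sum(estimate == positive & truth == positive)
  tn <- sum(estimate != positive & truth != positive)
  fp <- sum(estimate == positive & truth != positive)
  fn <- sum(estimate != positive & truth == positive)
  precision <- tp / (tp + fp)
  sensitivity <- tp / (tp + fn)  # also called recall
  c(
    sensitivity = sensitivity,
    specificity = tn / (tn + fp),
    accuracy = (tp + tn) / length(truth),
    f1 = 2 * precision * sensitivity / (precision + sensitivity)
  )
}

truth    <- factor(c("1", "1", "1", "1", "0", "0", "0", "1"), levels = c("0", "1"))
estimate <- factor(c("1", "1", "0", "1", "0", "1", "0", "1"), levels = c("0", "1"))

class_metrics(truth, estimate, positive = "1")  # relapse treated as the event
class_metrics(truth, estimate, positive = "0")  # first level treated as the event
```

Swapping the positive class swaps sensitivity and specificity (0.8 and 0.67 in this toy example) while accuracy stays at 0.75; F1 changes (0.8 versus 0.67) because it ignores true negatives.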

Finally, we apply interpretable ML tools from the DALEX and DALEXtra R packages to estimate how much each predictor contributes to each final model. For every predictor, the values in the training data are randomly shuffled (20 times, B = 20), the model is re-applied, and the resulting loss of performance relative to the unshuffled data is recorded (type = "difference"). Because this requires many rounds of prediction, these steps take several hours to complete. The code can be found in vip_calcs.R:

load(glue::glue("data/a_train.RData"))

library(DALEXtra)
library(tidymodels)

do_vip <- function(thingy) {
  call_info <- match.call()

  load(glue::glue("{here::here()}/data/a_train.RData"))
  load(glue::glue("{here::here()}/data/{as.character(call_info[[2]])}.RData"))

  print(glue::glue("Starting {as.character(call_info[[2]])}"))
  
  set.seed(305)
  
  the_explainer <-
    explain_tidymodels(
      thingy,
      data = a_train |> select(-did_relapse),
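      # as.numeric() on a factor returns level codes (1/2); convert to a 0/1
      # outcome where 1 means relapse ("1") and 0 means no relapse ("0")
      y = as.numeric(a_train$did_relapse == "1"),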
      label = as.character(call_info[[2]]),
      verbose = FALSE
    )
  doParallel::registerDoParallel(cores = parallel::detectCores() - 2)
  start_time <- Sys.time()

  the_vip_results <-
    the_explainer |>
    model_parts(N = NULL, B = 20, type = "difference")

  print(Sys.time() - start_time)
  doParallel::stopImplicitCluster()
  the_vip_results
}

load(glue::glue("{here::here()}/data/bart_final_fit.RData"))
vip_bart <- do_vip(bart_final_fit)
save(vip_bart, file = "data/vip_bart.RData")

load(glue::glue("{here::here()}/data/cart_final_fit.RData"))
vip_cart <- do_vip(cart_final_fit)
save(vip_cart, file = "data/vip_cart.RData")

load(glue::glue("{here::here()}/data/lasso_final_fit.RData"))
vip_lasso <- do_vip(lasso_final_fit)
save(vip_lasso, file = "data/vip_lasso.RData")

load(glue::glue("{here::here()}/data/logistic_final_fit.RData"))
vip_logistic <- do_vip(logistic_final_fit)
save(vip_logistic, file = "data/vip_logistic.RData")

load(glue::glue("{here::here()}/data/logistic_via_lasso_final_fit.RData"))
vip_logistic_via_lasso <- do_vip(logistic_via_lasso_final_fit)
save(vip_logistic_via_lasso, file = "data/vip_logistic_via_lasso.RData")

load(glue::glue("{here::here()}/data/mars_final_fit.RData"))
vip_mars <- do_vip(mars_final_fit)
save(vip_mars, file = "data/vip_mars.RData")

load(glue::glue("{here::here()}/data/nnet_final_fit.RData"))
vip_nnet <- do_vip(nnet_final_fit)
save(vip_nnet, file = "data/vip_nnet.RData")

load(glue::glue("{here::here()}/data/rf_final_fit.RData"))
vip_rf <- do_vip(rf_final_fit)
save(vip_rf, file = "data/vip_rf.RData")

load(glue::glue("{here::here()}/data/svm_final_fit.RData"))
vip_svm <- do_vip(svm_final_fit)
save(vip_svm, file = "data/vip_svm.RData")

load(glue::glue("{here::here()}/data/xgb_final_fit.RData"))
vip_xgb <- do_vip(xgb_final_fit)
save(vip_xgb, file = "data/vip_xgb.RData")
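model_parts() with type = "difference" reports, for each predictor, how much the model's loss increases when that predictor's values are randomly shuffled, averaged over B permutations. To make the mechanics concrete, here is a minimal base-R sketch using a logistic regression on simulated data, with 1 - AUC as the loss; to our understanding that is the DALEX default for classification models, but treat it as an assumption. The simulated data, the auc() helper, and permutation_importance() are hypothetical and are not the DALEX implementation.

```r
# Rank-based (Mann-Whitney) AUC for a 0/1 outcome and numeric scores
auc <- function(observed, predicted) {
  r <- rank(predicted)
  n1 <- sum(observed == 1)
  n0 <- sum(observed == 0)
  (sum(r[observed == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# Simulated data: one strong, one weak, and one irrelevant predictor
set.seed(305)
n <- 500
sim <- data.frame(x_strong = rnorm(n), x_weak = rnorm(n), x_noise = rnorm(n))
sim$y <- rbinom(n, 1, plogis(1.5 * sim$x_strong + 0.5 * sim$x_weak))
fit <- glm(y ~ x_strong + x_weak + x_noise, data = sim, family = binomial)

# Hypothetical permutation importance with loss = 1 - AUC
permutation_importance <- function(fit, data, outcome, B = 20) {
  predictors <- setdiff(names(data), outcome)
  loss <- function(d) 1 - auc(d[[outcome]], predict(fit, newdata = d, type = "response"))
  baseline <- loss(data)
  sapply(predictors, function(v) {
    permuted_losses <- replicate(B, {
      shuffled <- data
      shuffled[[v]] <- sample(shuffled[[v]])  # break the link between v and y
      loss(shuffled)
    })
    mean(permuted_losses) - baseline  # "difference" relative to the unshuffled loss
  })
}

importance <- permutation_importance(fit, sim, outcome = "y", B = 20)
sort(importance, decreasing = TRUE)
```

Each permutation requires a full round of predictions for every predictor, which is why the real calculation over a few dozen predictors and ten fitted workflows takes hours.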

Participant Details

Subjects

Of the 2,492 people who were randomized, 2,478 were included in the analysis. The remaining 14 people were excluded because they had no self-reported drug use information before and after randomization. The analysis dataset is created by the code in load.R. Most of the steps that create analysis are simple transformations and joins that merge many database tables. However, because some participants in CTN-0030 were randomized twice, the randomization table requires a filtering step: for all analyses, only the treatment from each participant's first randomization (which == 1) was used.

library(dplyr)

# One row per participant in the analysis set. Some CTN-0030 participants
# were randomized twice, so keep only the first randomization (which == 1).
# `everybody` (created elsewhere in the workflow) holds the retained participants.
subjects <-
  public.ctn0094data::randomization |>
  filter(which == 1) |>
  inner_join(everybody, by = "who") |>
  rename(rand_date = when) |>
  select(project, who, treatment, rand_date) |>
  mutate(
    # Collapse the six study-specific treatment arms into the study drug
    medication = case_when(
      treatment == "Inpatient BUP" ~ "Buprenorphine",
      treatment == "Inpatient NR-NTX" ~ "Naltrexone",
      treatment == "Methadone" ~ "Methadone",
      treatment == "Outpatient BUP" ~ "Buprenorphine",
      treatment == "Outpatient BUP + EMM" ~ "Buprenorphine",
      treatment == "Outpatient BUP + SMM" ~ "Buprenorphine"
    ),
    medication = factor(medication),
    # Separate indicator for the care setting (inpatient vs. outpatient)
    in_out = case_when(
      treatment == "Inpatient BUP" ~ "Inpatient",
      treatment == "Inpatient NR-NTX" ~ "Inpatient",
      treatment == "Methadone" ~ "Outpatient",
      treatment == "Outpatient BUP" ~ "Outpatient",
      treatment == "Outpatient BUP + EMM" ~ "Outpatient",
      treatment == "Outpatient BUP + SMM" ~ "Outpatient"
    ),
    in_out = factor(in_out)
  )

Features/Variables

Note

The terms “feature” and “variable” both refer to a predictor of treatment outcome, and the two terms can usually be used interchangeably. A yes/no variable, like “used cocaine”, is a single feature in the model. Things get complicated when a predictor is converted into multiple features. For example, education can be converted to a series of yes/no indicators like “graduated grade school”, “attended high school but did not graduate”, and “graduated high school”. In this case, “education” is a feature and so is “high school graduate”. Similar ambiguities arise when an algorithm splits a continuous variable like age into subgroups: both the numeric age and the categorical “age > 60” can be features. The subtle difference between using a variable as a predictor and using groups or levels of the variable as predictors should be clear from context.
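The feature/variable distinction can be made concrete with a minimal base R sketch. It uses toy data, not the study data. The sketch shows how a single `education` factor becomes several indicator columns under R's default treatment coding. It also shows how a numeric `age` can yield a derived yes/no feature. The column names and the 60-year cutoff are illustrative assumptions, and the project's own preprocessing may encode these differently.

```r
# Toy data: one categorical predictor and one numeric predictor
toy <- data.frame(
  education = factor(
    c("Less than HS", "HS or GED", "More than HS", "HS or GED"),
    levels = c("Less than HS", "HS or GED", "More than HS")
  ),
  age = c(24, 61, 45, 70)
)

# model.matrix() expands the factor into indicator (dummy) columns;
# the first level ("Less than HS") is the reference and gets no column
features <- model.matrix(~ education + age, data = toy)
print(features)

# A derived yes/no feature created by splitting a continuous variable
toy$age_over_60 <- as.integer(toy$age > 60)
```

One variable (`education`) has turned into two model features here, plus the intercept.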

The output below contains descriptive statistics for the analysis dataset, which includes categorical variables (known in R as factors) and numeric variables. In R, factors are categorical variables that have a fixed and known set of possible values. The data are described using the R package skimr. Its skim() function generates summary statistics and presents them in separate tables for factor and numeric variables. In both tables, each variable is listed in a column titled skim_variable, with its summary statistics in the adjacent columns:

  • n_missing displays the total number of missing values.
  • complete_rate shows the proportion of a variable’s values that are non-missing; a complete_rate of 1.000 corresponds to 100% complete records and therefore an n_missing of 0.
  • ordered (factor.ordered) indicates whether the factor’s levels have a natural order (e.g., small, medium, and large are ordered categories).
  • n_unique (factor.n_unique) gives the number of distinct levels observed for the factor.
  • top_counts (factor.top_counts) shows abbreviated level names with the number of records at each level, most common first.
skim_variable n_missing complete_rate factor.ordered factor.n_unique factor.top_counts
trial 0 1.000 FALSE 3 CTN: 1262, CTN: 646, CTN: 570
medication 0 1.000 FALSE 3 Bup: 1668, Met: 527, Nal: 283
in_out 0 1.000 FALSE 2 Out: 1908, Inp: 570
used_iv 1102 0.555 FALSE 2 No: 861, Yes: 515
race 0 1.000 FALSE 4 Whi: 1910, Oth: 333, Bla: 222, Ref: 13
is_hispanic 0 1.000 FALSE 2 No: 2144, Yes: 334
job 1103 0.555 FALSE 5 Ful: 742, Une: 285, Par: 262, Oth: 47
is_living_stable 1103 0.555 FALSE 2 Yes: 1329, No: 46
education 1102 0.555 FALSE 3 Mor: 573, HS : 537, Les: 266
marital 1105 0.554 FALSE 3 Nev: 802, Sep: 307, Mar: 264
is_male 0 1.000 FALSE 2 Yes: 1646, No: 832
is_smoker 5 0.998 FALSE 2 Yes: 2096, No: 377
per_day 5 0.998 TRUE 5 11 : 1000, 10 : 662, 0: 377, 21 : 344
ftnd 384 0.845 FALSE 11 5: 341, 4: 317, 6: 279, 3: 275
pain 116 0.953 FALSE 4 Ver: 1499, No : 533, Sev: 319, Mis: 11
any_schiz 0 1.000 FALSE 3 No: 1303, Unk: 1077, Yes: 98
any_dep 0 1.000 FALSE 3 Yes: 1123, Unk: 805, No: 550
any_anx 0 1.000 FALSE 3 Yes: 1214, Unk: 779, No: 485
has_bipolar 2 0.999 FALSE 2 No: 2213, Yes: 263
has_brain_damage 3 0.999 FALSE 2 No: 2175, Yes: 300
has_epilepsy 0 1.000 FALSE 2 No: 2367, Yes: 111
has_alcol_dx 818 0.670 FALSE 2 No: 1251, Yes: 409
has_amphetamines_dx 818 0.670 FALSE 2 No: 1434, Yes: 226
has_cannabis_dx 818 0.670 FALSE 2 No: 1274, Yes: 386
has_cocaine_dx 819 0.669 FALSE 2 No: 1128, Yes: 531
has_sedatives_dx 818 0.670 FALSE 2 No: 1354, Yes: 306
is_homeless 1908 0.230 FALSE 2 No: 427, Yes: 143
did_use_cocaine 2 0.999 FALSE 2 Yes: 1271, No: 1205
did_use_heroin 2 0.999 FALSE 2 Yes: 1712, No: 764
did_use_speedball 2 0.999 FALSE 2 No: 2174, Yes: 302
did_use_opioid 2 0.999 FALSE 2 Yes: 1479, No: 997
did_use_speed 2 0.999 FALSE 2 No: 1899, Yes: 577
shared 1 1.000 FALSE 2 No: 2200, Yes: 277
withdrawal 8 0.997 FALSE 4 1: 1229, 2: 998, 3: 200, 0: 43
site_masked 0 1.000 FALSE 32 11: 120, 23: 110, 17: 109, 24: 105
did_relapse 0 1.000 FALSE 2 1: 1791, 0: 687
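The factor summaries above can be reproduced by hand. The sketch below rebuilds a vector with the same counts reported for used_iv: 861 No, 515 Yes, and 1,102 missing. It then computes the same quantities with base R. It only illustrates what the columns mean, not how skimr computes them internally. In practice, calling `skimr::skim()` on the analysis data frame produces the tables above.

```r
# Rebuild a vector matching the counts reported for used_iv
used_iv <- factor(c(rep("No", 861), rep("Yes", 515), rep(NA, 1102)))

n_missing     <- sum(is.na(used_iv))             # number of missing values
complete_rate <- 1 - n_missing / length(used_iv) # proportion non-missing
n_unique      <- nlevels(droplevels(used_iv))    # distinct observed levels
top_counts    <- sort(table(used_iv), decreasing = TRUE)  # NAs are not counted

round(complete_rate, 3)
top_counts
```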

The table for numeric data is shown below. It provides summary statistics such as the mean, the standard deviation (sd), a small histogram of the distribution (hist), and the quartiles (Q0, Q1, Q2, Q3, Q4), labeled p0, p25, p50, p75, and p100 respectively. Each quartile value marks a boundary between quarters of the data: Q0 is the smallest value in the data and Q4 is the largest. Values between Q0 and Q1 are the lowest 25% of the data, values between Q1 and Q2 are the next 25%, and so on.


skim_variable n_missing complete_rate numeric.mean numeric.sd numeric.p0 numeric.p25 numeric.p50 numeric.p75 numeric.p100 numeric.hist
age 0 1.000 35.085 10.69 18.000 26 33 43.000 77 ▇▆▅▁▁
days_cocaine 2 0.999 2.659 6.11 0.000 0 0 2.000 30 ▇▁▁▁▁
days_heroin 2 0.999 16.414 13.91 0.000 0 23 30.000 30 ▆▁▁▁▇
days_speedball 2 0.999 0.956 4.14 0.000 0 0 0.000 30 ▇▁▁▁▁
days_opioid 2 0.999 12.186 13.37 0.000 0 4 30.000 30 ▇▁▁▁▅
days_speed 2 0.999 0.877 3.42 0.000 0 0 0.000 30 ▇▁▁▁▁
days_iv_use 1 1.000 13.258 14.04 0.000 0 4 30.000 30 ▇▁▁▁▆
tlfb_days_of_use_n 6 0.998 24.805 5.21 1.000 24 28 28.000 28 ▁▁▁▁▇
tlfb_what_used_n 6 0.998 2.777 1.36 1.000 2 3 4.000 10 ▇▇▂▁▁
detox_days 0 1.000 1.783 4.11 -0.001 0 0 0.001 40 ▇▁▁▁▁
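The p0 to p100 columns correspond to what `stats::quantile()` returns for those probabilities. This sketch uses ten made-up ages and R's default interpolation (type 7). Whether skimr uses exactly this setting is an assumption, not something confirmed here.

```r
# Ten hypothetical ages (not study data), already sorted
ages <- c(18, 22, 26, 30, 33, 38, 43, 50, 61, 77)

# p0, p25, p50, p75, p100: minimum, three quartiles, maximum
q <- quantile(ages, probs = c(0, 0.25, 0.5, 0.75, 1))
q  # 18, 27, 35.5, 48.25, 77
```

With ten values, the 25th percentile falls a quarter of the way between the 3rd and 4th sorted values (26 and 30), giving 27.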

Full documentation for the dataset tables and descriptions of their variables can be found here.

The model feature details are shown in Table 1. For approximately 1,500 participants, schizophrenia, depression and anxiety were assessed using both a medical/psychiatric interview and the ASI-Lite questionnaire.1 The two sources showed little to weak agreement (\(\phi\) = -0.25, -0.34, and -0.36, respectively). Therefore, we created composite indicators that scored a participant as affected if they were positive on either measure and negative otherwise. A priori, it was unclear how to handle treatment regimen and recent drug use history in the modeling process. Treatments were initially grouped using study-specific treatment arms (with six levels). However, such groupings are not helpful for readers who wish to generalize our results. Therefore, we created two indicator variables: one representing the study drug and the other representing inpatient vs. outpatient care.
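The two steps described above can be sketched in base R using invented counts rather than study data. The first step is the phi coefficient computed from a two-by-two table of the two sources; it equals the Pearson correlation of the two 0/1 vectors. The second step is a composite indicator that is positive when either source is positive. How the study handled “Unknown” values is not modeled in this sketch.

```r
# Hypothetical 0/1 ratings of one condition from two sources (1 = positive)
interview <- c(rep(1, 10), rep(1, 40), rep(0, 30), rep(0, 20))
asi_lite  <- c(rep(1, 10), rep(0, 40), rep(1, 30), rep(0, 20))

# Two-by-two cell counts
n11 <- sum(interview == 1 & asi_lite == 1)  # both positive
n10 <- sum(interview == 1 & asi_lite == 0)  # interview only
n01 <- sum(interview == 0 & asi_lite == 1)  # ASI-Lite only
n00 <- sum(interview == 0 & asi_lite == 0)  # both negative

# Phi coefficient; a negative value means the sources tend to disagree
phi <- (n11 * n00 - n10 * n01) /
  sqrt((n11 + n10) * (n01 + n00) * (n11 + n01) * (n10 + n00))

# Composite: affected if positive on either measure, negative otherwise
affected <- as.integer(interview == 1 | asi_lite == 1)

c(phi = phi, n_affected = sum(affected))
```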

Recent drug use was also processed in two ways. The first approach is reflected in the variables presented in Table 1 and Table 2 and uses many variables. These include indicators of comorbid substance use diagnoses, the total number of distinct drugs used in the past 28 days, the presence or absence of specific drugs, the amount of use of specific drugs, and the number of days on which at least one drug was used. The other approach was to assign each participant to a polysubstance drug use profile based on a previously published latent class analysis.2 The latter method is somewhat problematic because the drug use profiles were previously built using all the data. Although the groupings were constructed without reference to the outcome, the latent classes were not re-estimated within each resample. This violates the premise that feature engineering and preprocessing should be encapsulated inside resamples (like the 10-fold cross-validation described below).3 The details of resampling are explained in the next few sections. Currently, the statistical theory and software tools needed to properly tune and “average” latent class analysis results across repeated samples are underdeveloped, including tools for adequately accounting for the uncertainty in the estimates. Consequently, we opted for a strategy utilizing many variables rather than reducing variables through latent class analysis.
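The principle of keeping preprocessing inside resamples can be illustrated with a small base R sketch. A hypothetical numeric predictor is centered and scaled within 10-fold cross-validation. The mean and standard deviation are estimated only from each fold's analysis rows and then applied to that fold's held-out assessment rows. The data, seed, and variable names are illustrative assumptions. In tidymodels, this bookkeeping is usually delegated to `rsample::vfold_cv()` and a recipe inside a workflow rather than written by hand.

```r
set.seed(94)

# Hypothetical numeric predictor (not study data)
n <- 100
x <- rnorm(n, mean = 35, sd = 10)

# Randomly assign each row to one of k = 10 equally sized folds
k <- 10
fold_id <- sample(rep(seq_len(k), length.out = n))

scaled_assessment <- vector("list", k)
for (fold in seq_len(k)) {
  analysis   <- x[fold_id != fold]  # rows used for fitting in this resample
  assessment <- x[fold_id == fold]  # held-out rows for this resample

  # Preprocessing statistics are learned from the analysis rows only...
  center <- mean(analysis)
  spread <- sd(analysis)

  # ...and then applied, unchanged, to the held-out rows
  scaled_assessment[[fold]] <- (assessment - center) / spread
}

# A leaky alternative would compute mean(x) and sd(x) once from ALL rows,
# letting each assessment fold influence its own preprocessing.
```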

Feature Details
Demographics
Age Numeric
Ethnicity (is Hispanic) Yes, No, Unknown
Race Black, White, Other
Unemployed Yes, No, Unknown
Stable Housing Yes, No, Unknown
Education Missing, Less than HS, HS or GED, More than HS
Marital Status Unknown, Never married, Married or Partnered, Separated/Divorced/Widowed
Sex (is Male) Yes, No, Unknown
Drug Use
Smoking History Yes, No, Unknown
Fagerstrom Test for Nicotine Dependence Numeric
IV Drug Use History Yes, No, Unknown
History & Physical
Pain Closest to Enrollment None, Very mild to moderate, Severe
Schizophrenia Yes, No, Unknown
Depression Yes, No, Unknown
Anxiety Yes, No, Unknown
Bipolar Yes, No, Unknown
Neurological Damage Yes, No, Unknown
Epilepsy Yes, No, Unknown
Comorbid Drug Use Diagnoses
Alcohol Yes, No, Unknown
Amphetamines Yes, No, Unknown
Cannabis Yes, No, Unknown
Cocaine Yes, No, Unknown
Treatment Details
Study Site Clinic Number
Clinic Type Inpatient, Outpatient
Medication Inpatient BUP, Inpatient NR-NTX, Methadone, Outpatient BUP, Outpatient BUP + Enhanced Medical Management, Outpatient BUP + Standard Medical Management
Drugs Used in Past 28 Days
Number of Distinct Substances Numeric
Number of Days with Any Use Numeric

Table 2: Descriptive statistics.

CTN-0027 (N=1262) CTN-0030 (N=646) CTN-0051 (N=570) Overall (N=2478)

Medication
Buprenorphine 735 (58.2%) 646 (100%) 287 (50.4%) 1668 (67.3%)
Methadone 527 (41.8%) 0 (0%) 0 (0%) 527 (21.3%)
Naltrexone 0 (0%) 0 (0%) 283 (49.6%) 283 (11.4%)
Type
Inpatient 0 (0%) 0 (0%) 570 (100%) 570 (23.0%)
Outpatient 1262 (100%) 646 (100%) 0 (0%) 1908 (77.0%)
Used IV Drugs
No 58 (4.6%) 624 (96.6%) 179 (31.4%) 861 (34.7%)
Yes 102 (8.1%) 22 (3.4%) 391 (68.6%) 515 (20.8%)
Missing 1102 (87.3%) 0 (0%) 0 (0%) 1102 (44.5%)
Age
Mean (SD) 36.9 (11.1) 32.6 (10.2) 33.9 (9.63) 35.1 (10.7)
Median [Min, Max] 35.0 [18.0, 67.0] 30.0 [18.0, 77.0] 31.0 [19.0, 67.0] 33.0 [18.0, 77.0]
Race
Black 128 (10.1%) 21 (3.3%) 73 (12.8%) 222 (9.0%)
Other 228 (18.1%) 35 (5.4%) 70 (12.3%) 333 (13.4%)
Refused/missing 6 (0.5%) 1 (0.2%) 6 (1.1%) 13 (0.5%)
White 900 (71.3%) 589 (91.2%) 421 (73.9%) 1910 (77.1%)
Hispanic
No 1057 (83.8%) 616 (95.4%) 471 (82.6%) 2144 (86.5%)
Yes 205 (16.2%) 30 (4.6%) 99 (17.4%) 334 (13.5%)
Employment
Full Time 67 (5.3%) 407 (63.0%) 268 (47.0%) 742 (29.9%)
Other 7 (0.6%) 16 (2.5%) 24 (4.2%) 47 (1.9%)
Part Time 33 (2.6%) 113 (17.5%) 116 (20.4%) 262 (10.6%)
Student 1 (0.1%) 28 (4.3%) 10 (1.8%) 39 (1.6%)
Unemployed 52 (4.1%) 82 (12.7%) 151 (26.5%) 285 (11.5%)
Missing 1102 (87.3%) 0 (0%) 1 (0.2%) 1103 (44.5%)
Stable Housing
No 8 (0.6%) 2 (0.3%) 36 (6.3%) 46 (1.9%)
Yes 151 (12.0%) 644 (99.7%) 534 (93.7%) 1329 (53.6%)
Missing 1103 (87.4%) 0 (0%) 0 (0%) 1103 (44.5%)
Education
HS or GED 58 (4.6%) 251 (38.9%) 228 (40.0%) 537 (21.7%)
Less than HS 44 (3.5%) 99 (15.3%) 123 (21.6%) 266 (10.7%)
More than HS 58 (4.6%) 296 (45.8%) 219 (38.4%) 573 (23.1%)
Missing 1102 (87.3%) 0 (0%) 0 (0%) 1102 (44.5%)
Marital Status
Married or Partnered 25 (2.0%) 184 (28.5%) 55 (9.6%) 264 (10.7%)
Never married 86 (6.8%) 324 (50.2%) 392 (68.8%) 802 (32.4%)
Separated/Divorced/Widowed 48 (3.8%) 136 (21.1%) 123 (21.6%) 307 (12.4%)
Missing 1103 (87.4%) 2 (0.3%) 0 (0%) 1105 (44.6%)
Is Male
No 406 (32.2%) 257 (39.8%) 169 (29.6%) 832 (33.6%)
Yes 856 (67.8%) 389 (60.2%) 401 (70.4%) 1646 (66.4%)
Is Smoker
No 135 (10.7%) 161 (24.9%) 81 (14.2%) 377 (15.2%)
Yes 1122 (88.9%) 485 (75.1%) 489 (85.8%) 2096 (84.6%)
Missing 5 (0.4%) 0 (0%) 0 (0%) 5 (0.2%)
Cigarettes Per Day
0 135 (10.7%) 161 (24.9%) 81 (14.2%) 377 (15.2%)
10 OR LESS 352 (27.9%) 116 (18.0%) 194 (34.0%) 662 (26.7%)
11 TO 20 572 (45.3%) 227 (35.1%) 201 (35.3%) 1000 (40.4%)
21 TO 30 151 (12.0%) 112 (17.3%) 81 (14.2%) 344 (13.9%)
31 OR MORE 47 (3.7%) 30 (4.6%) 13 (2.3%) 90 (3.6%)
Missing 5 (0.4%) 0 (0%) 0 (0%) 5 (0.2%)
Fagerstrom Test for Nicotine Dependence
0 61 (4.8%) 40 (6.2%) 35 (6.1%) 136 (5.5%)
1 76 (6.0%) 44 (6.8%) 31 (5.4%) 151 (6.1%)
2 101 (8.0%) 50 (7.7%) 39 (6.8%) 190 (7.7%)
3 154 (12.2%) 55 (8.5%) 66 (11.6%) 275 (11.1%)
4 191 (15.1%) 49 (7.6%) 77 (13.5%) 317 (12.8%)
5 187 (14.8%) 67 (10.4%) 87 (15.3%) 341 (13.8%)
6 151 (12.0%) 71 (11.0%) 57 (10.0%) 279 (11.3%)
7 111 (8.8%) 64 (9.9%) 46 (8.1%) 221 (8.9%)
8 63 (5.0%) 34 (5.3%) 36 (6.3%) 133 (5.4%)
9 19 (1.5%) 8 (1.2%) 10 (1.8%) 37 (1.5%)
10 8 (0.6%) 3 (0.5%) 3 (0.5%) 14 (0.6%)
Missing 140 (11.1%) 161 (24.9%) 83 (14.6%) 384 (15.5%)
Pain at Enrollment
Missing 11 (0.9%) 0 (0%) 0 (0%) 11 (0.4%)
No Pain 178 (14.1%) 120 (18.6%) 235 (41.2%) 533 (21.5%)
Severe Pain 189 (15.0%) 104 (16.1%) 26 (4.6%) 319 (12.9%)
Very mild to Moderate Pain 769 (60.9%) 421 (65.2%) 309 (54.2%) 1499 (60.5%)
Missing 115 (9.1%) 1 (0.2%) 0 (0%) 116 (4.7%)
Schizophrenia
No 141 (11.2%) 627 (97.1%) 535 (93.9%) 1303 (52.6%)
Unknown 1076 (85.3%) 1 (0.2%) 0 (0%) 1077 (43.5%)
Yes 45 (3.6%) 18 (2.8%) 35 (6.1%) 98 (4.0%)
Depression
No 60 (4.8%) 304 (47.1%) 186 (32.6%) 550 (22.2%)
Unknown 802 (63.6%) 1 (0.2%) 2 (0.4%) 805 (32.5%)
Yes 400 (31.7%) 341 (52.8%) 382 (67.0%) 1123 (45.3%)
Anxiety
No 58 (4.6%) 289 (44.7%) 138 (24.2%) 485 (19.6%)
Unknown 779 (61.7%) 0 (0%) 0 (0%) 779 (31.4%)
Yes 425 (33.7%) 357 (55.3%) 432 (75.8%) 1214 (49.0%)
Bipolar
No 1114 (88.3%) 608 (94.1%) 491 (86.1%) 2213 (89.3%)
Yes 147 (11.6%) 37 (5.7%) 79 (13.9%) 263 (10.6%)
Missing 1 (0.1%) 1 (0.2%) 0 (0%) 2 (0.1%)
Brain Damage
No 1144 (90.6%) 530 (82.0%) 501 (87.9%) 2175 (87.8%)
Yes 115 (9.1%) 116 (18.0%) 69 (12.1%) 300 (12.1%)
Missing 3 (0.2%) 0 (0%) 0 (0%) 3 (0.1%)
Epilepsy
No 1222 (96.8%) 624 (96.6%) 521 (91.4%) 2367 (95.5%)
Yes 40 (3.2%) 22 (3.4%) 49 (8.6%) 111 (4.5%)
Alcohol Diagnosis
No 840 (66.6%) 0 (0%) 411 (72.1%) 1251 (50.5%)
Yes 250 (19.8%) 0 (0%) 159 (27.9%) 409 (16.5%)
Missing 172 (13.6%) 646 (100%) 0 (0%) 818 (33.0%)
Amphetamines Diagnosis
No 970 (76.9%) 0 (0%) 464 (81.4%) 1434 (57.9%)
Yes 120 (9.5%) 0 (0%) 106 (18.6%) 226 (9.1%)
Missing 172 (13.6%) 646 (100%) 0 (0%) 818 (33.0%)
Cannabis Diagnosis
No 867 (68.7%) 0 (0%) 407 (71.4%) 1274 (51.4%)
Yes 223 (17.7%) 0 (0%) 163 (28.6%) 386 (15.6%)
Missing 172 (13.6%) 646 (100%) 0 (0%) 818 (33.0%)
Cocaine Diagnosis
No 734 (58.2%) 0 (0%) 394 (69.1%) 1128 (45.5%)
Yes 356 (28.2%) 0 (0%) 175 (30.7%) 531 (21.4%)
Missing 172 (13.6%) 646 (100%) 1 (0.2%) 819 (33.1%)
Sedatives Diagnosis
No 937 (74.2%) 0 (0%) 417 (73.2%) 1354 (54.6%)
Yes 153 (12.1%) 0 (0%) 153 (26.8%) 306 (12.3%)
Missing 172 (13.6%) 646 (100%) 0 (0%) 818 (33.0%)
Is Homeless
No 0 (0%) 0 (0%) 427 (74.9%) 427 (17.2%)
Yes 0 (0%) 0 (0%) 143 (25.1%) 143 (5.8%)
Missing 1262 (100%) 646 (100%) 0 (0%) 1908 (77.0%)
Used Cocaine
No 725 (57.4%) 165 (25.5%) 315 (55.3%) 1205 (48.6%)
Yes 537 (42.6%) 481 (74.5%) 253 (44.4%) 1271 (51.3%)
Missing 0 (0%) 0 (0%) 2 (0.4%) 2 (0.1%)
Used Heroin
No 188 (14.9%) 504 (78.0%) 72 (12.6%) 764 (30.8%)
Yes 1074 (85.1%) 142 (22.0%) 496 (87.0%) 1712 (69.1%)
Missing 0 (0%) 0 (0%) 2 (0.4%) 2 (0.1%)
Used Speedball
No 1093 (86.6%) 622 (96.3%) 459 (80.5%) 2174 (87.7%)
Yes 169 (13.4%) 24 (3.7%) 109 (19.1%) 302 (12.2%)
Missing 0 (0%) 0 (0%) 2 (0.4%) 2 (0.1%)
Used Opioid
No 708 (56.1%) 1 (0.2%) 288 (50.5%) 997 (40.2%)
Yes 554 (43.9%) 645 (99.8%) 280 (49.1%) 1479 (59.7%)
Missing 0 (0%) 0 (0%) 2 (0.4%) 2 (0.1%)
Used Stimulants
No 1067 (84.5%) 396 (61.3%) 436 (76.5%) 1899 (76.6%)
Yes 195 (15.5%) 250 (38.7%) 132 (23.2%) 577 (23.3%)
Missing 0 (0%) 0 (0%) 2 (0.4%) 2 (0.1%)
Days Using Cocaine
Mean (SD) 3.00 (6.40) 0.450 (1.81) 4.40 (7.69) 2.66 (6.11)
Median [Min, Max] 0 [0, 30.0] 0 [0, 20.0] 0 [0, 30.0] 0 [0, 30.0]
Missing 0 (0%) 0 (0%) 2 (0.4%) 2 (0.1%)
Days Using Heroin
Mean (SD) 22.4 (11.5) 0.127 (0.565) 21.7 (11.8) 16.4 (13.9)
Median [Min, Max] 30.0 [0, 30.0] 0 [0, 6.00] 30.0 [0, 30.0] 23.0 [0, 30.0]
Missing 0 (0%) 0 (0%) 2 (0.4%) 2 (0.1%)
Days Using Speedball
Mean (SD) 0.994 (4.26) 0.00155 (0.0393) 1.96 (5.70) 0.956 (4.14)
Median [Min, Max] 0 [0, 30.0] 0 [0, 1.00] 0 [0, 30.0] 0 [0, 30.0]
Missing 0 (0%) 0 (0%) 2 (0.4%) 2 (0.1%)
Days Using Opioids
Mean (SD) 6.16 (10.5) 28.5 (3.37) 7.06 (10.7) 12.2 (13.4)
Median [Min, Max] 0 [0, 30.0] 30.0 [0, 30.0] 0 [0, 30.0] 4.00 [0, 30.0]
Missing 0 (0%) 0 (0%) 2 (0.4%) 2 (0.1%)
Days Using Stimulants
Mean (SD) 0.672 (2.63) 0.203 (1.61) 2.10 (5.54) 0.877 (3.42)
Median [Min, Max] 0 [0, 30.0] 0 [0, 30.0] 0 [0, 30.0] 0 [0, 30.0]
Missing 0 (0%) 0 (0%) 2 (0.4%) 2 (0.1%)
Days Using IV Drugs
Mean (SD) 18.3 (13.6) 0.325 (2.67) 16.8 (13.6) 13.3 (14.0)
Median [Min, Max] 27.0 [0, 30.0] 0 [0, 30.0] 14.0 [0, 30.0] 4.00 [0, 30.0]
Missing 0 (0%) 0 (0%) 1 (0.2%) 1 (0.0%)
Shared Needles
No 1088 (86.2%) 645 (99.8%) 467 (81.9%) 2200 (88.8%)
Yes 174 (13.8%) 1 (0.2%) 102 (17.9%) 277 (11.2%)
Missing 0 (0%) 0 (0%) 1 (0.2%) 1 (0.0%)
TLFB Days of Drug Use
Mean (SD) 25.6 (4.56) 27.1 (2.16) 20.3 (6.27) 24.8 (5.21)
Median [Min, Max] 28.0 [1.00, 28.0] 28.0 [1.00, 28.0] 22.0 [2.00, 28.0] 28.0 [1.00, 28.0]
Missing 2 (0.2%) 1 (0.2%) 3 (0.5%) 6 (0.2%)
TLFB Number Drugs Used
Mean (SD) 2.61 (1.27) 2.74 (1.35) 3.20 (1.46) 2.78 (1.36)
Median [Min, Max] 2.00 [1.00, 7.00] 3.00 [1.00, 10.0] 3.00 [1.00, 10.0] 3.00 [1.00, 10.0]
Missing 2 (0.2%) 1 (0.2%) 3 (0.5%) 6 (0.2%)
Withdrawal Severity
0 17 (1.3%) 7 (1.1%) 19 (3.3%) 43 (1.7%)
1 649 (51.4%) 343 (53.1%) 237 (41.6%) 1229 (49.6%)
2 565 (44.8%) 288 (44.6%) 145 (25.4%) 998 (40.3%)
3 24 (1.9%) 7 (1.1%) 169 (29.6%) 200 (8.1%)
Missing 7 (0.6%) 1 (0.2%) 0 (0%) 8 (0.3%)
Days in Detox (missing values set to approximately 0)
Mean (SD) -0.00000915 (0.000570) -0.0000233 (0.000563) 7.75 (5.22) 1.78 (4.11)
Median [Min, Max] -0.0000371 [-0.000999, 0.000999] -0.0000248 [-0.000995, 0.000996] 6.00 [-0.000524, 40.0] 0.000264 [-0.000999, 40.0]
site_masked
270001 60 (4.8%) 0 (0%) 0 (0%) 60 (2.4%)
270002 62 (4.9%) 0 (0%) 0 (0%) 62 (2.5%)
270003 99 (7.8%) 0 (0%) 0 (0%) 99 (4.0%)
270004 77 (6.1%) 0 (0%) 0 (0%) 77 (3.1%)
270005 61 (4.8%) 0 (0%) 0 (0%) 61 (2.5%)
270006 80 (6.3%) 0 (0%) 0 (0%) 80 (3.2%)
270007 59 (4.7%) 0 (0%) 0 (0%) 59 (2.4%)
270008 77 (6.1%) 0 (0%) 0 (0%) 77 (3.1%)
270009 45 (3.6%) 0 (0%) 0 (0%) 45 (1.8%)
270010 64 (5.1%) 0 (0%) 0 (0%) 64 (2.6%)
270011 62 (4.9%) 0 (0%) 0 (0%) 62 (2.5%)
270012 68 (5.4%) 0 (0%) 0 (0%) 68 (2.7%)
270013 69 (5.5%) 0 (0%) 0 (0%) 69 (2.8%)
270014 64 (5.1%) 0 (0%) 0 (0%) 64 (2.6%)
270015 56 (4.4%) 0 (0%) 0 (0%) 56 (2.3%)
270016 67 (5.3%) 0 (0%) 0 (0%) 67 (2.7%)
270017 71 (5.6%) 0 (0%) 0 (0%) 71 (2.9%)
270018 60 (4.8%) 0 (0%) 0 (0%) 60 (2.4%)
270019 61 (4.8%) 0 (0%) 0 (0%) 61 (2.5%)
300001 0 (0%) 110 (17.0%) 0 (0%) 110 (4.4%)
300002 0 (0%) 56 (8.7%) 0 (0%) 56 (2.3%)
300003 0 (0%) 61 (9.4%) 0 (0%) 61 (2.5%)
300004 0 (0%) 85 (13.2%) 0 (0%) 85 (3.4%)
300005 0 (0%) 58 (9.0%) 0 (0%) 58 (2.3%)
300006 0 (0%) 68 (10.5%) 0 (0%) 68 (2.7%)
300007 0 (0%) 77 (11.9%) 0 (0%) 77 (3.1%)
300008 0 (0%) 56 (8.7%) 0 (0%) 56 (2.3%)
300009 0 (0%) 75 (11.6%) 0 (0%) 75 (3.0%)
510001 0 (0%) 0 (0%) 91 (16.0%) 91 (3.7%)
510002 0 (0%) 0 (0%) 91 (16.0%) 91 (3.7%)
510003 0 (0%) 0 (0%) 67 (11.8%) 67 (2.7%)
510004 0 (0%) 0 (0%) 66 (11.6%) 66 (2.7%)
510005 0 (0%) 0 (0%) 79 (13.9%) 79 (3.2%)
510006 0 (0%) 0 (0%) 81 (14.2%) 81 (3.3%)
510007 0 (0%) 0 (0%) 95 (16.7%) 95 (3.8%)
Did Relapse
0 357 (28.3%) 111 (17.2%) 219 (38.4%) 687 (27.7%)
1 905 (71.7%) 535 (82.8%) 351 (61.6%) 1791 (72.3%)

The Endpoint to Predict

As mentioned in the paper, although many methods have been proposed to evaluate treatment success for OUD, we opted to use the Lee et al. definition, which assesses weeks to relapse. It uses urine drug screening starting at day 21 post-randomization and regards four consecutive weeks with positive and/or missing urine drug screens as “positive” (i.e., a relapse).4 The code to calculate the definition of relapse used in the paper, along with dozens of other definitions, is available through CTNote, an open-source software library for the R language.5,6 A summary of the code used here can be found in the load.R file shown above.
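The rule can be made concrete with a hypothetical base R sketch. It flags relapse from a vector of weekly urine screen results coded "+" (positive), "-" (negative), or NA (missing). It assumes that day 21 corresponds to the start of week 4. It also assumes that relapse means at least four consecutive positive-or-missing weeks from then on. The function name `flag_relapse()` is invented for illustration. This is a simplification, not CTNote's implementation or API; consult CTNote for the exact published algorithm.

```r
# Hypothetical helper: TRUE if, from start_week onward, there is a run of at
# least run_length consecutive weeks that are positive ("+") or missing (NA)
flag_relapse <- function(weekly_uds, start_week = 4, run_length = 4) {
  if (length(weekly_uds) < start_week) return(FALSE)
  window <- weekly_uds[start_week:length(weekly_uds)]  # ignore weeks 1 to 3

  # A week counts against the participant if positive or missing
  bad_week <- is.na(window) | window == "+"

  # Run-length encoding finds stretches of consecutive bad weeks
  runs <- rle(bad_week)
  any(runs$values & runs$lengths >= run_length)
}

flag_relapse(c("+", "+", "+", "-", "-", "+", NA, "+", "+", "-"))  # TRUE
flag_relapse(c("+", "+", "+", "+", "-", "+", "+", "+", "-"))      # FALSE
```

In the second example, the positive screens in weeks 1 to 3 are ignored. After week 3, the longest bad stretch is only three weeks, so the rule does not flag a relapse.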

Machine Learning Concepts

The Big Picture

The predictive modeling typically done with ML methods has different goals from the traditional (Neyman-Pearson) p-value-based hypothesis testing familiar to most clinical investigators. A p-value-based approach is useful for experiments where the goal is to test for differences between groups. Traditional hypothesis testing begins with the idea that an experiment was conducted at great expense. Its goal is to make a statement about how unlikely the observed effect would be to appear by chance alone. However, the broad application of p-value-based statistics, especially when the goal is replicable predictive modeling, has been fraught with issues.7 Of particular concern is that every effect (even a microscopic difference in the mean response to two treatments) becomes statistically significant with a large enough sample. Many investigators also mistakenly believe that p-values measure the size of a treatment effect, which they do not. Together, these problems suggest that investigators should evaluate model performance with other metrics instead of relying solely on p-values.
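The claim that any nonzero effect eventually becomes significant can be checked with simple arithmetic. The sketch below uses a hypothetical two-sample z-test with a known standard deviation of 1. For increasing sample sizes, it computes the two-sided p-value that would result if the observed difference in means were exactly 0.01 standard deviations. The effect size and the test are illustrative assumptions.

```r
# Two-sided p-value for a fixed observed difference in means, assuming a
# two-sample z-test with known SD in both groups
p_for_n <- function(n_per_group, diff = 0.01, sd = 1) {
  se <- sd * sqrt(2 / n_per_group)  # standard error of the difference
  z  <- diff / se
  2 * pnorm(-abs(z))
}

n_values <- c(100, 10000, 1000000)
p <- p_for_n(n_values)
signif(p, 3)  # the difference never changes, yet p shrinks toward 0
```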

The desire to shift the focus away from p-value-based model assessments has led statisticians, data scientists and other ML aficionados to demand reports on the precision of estimates, typically in the form of confidence (or credible) intervals. Interestingly, the traditionally taught confidence interval formulas, which rely on theoretical assumptions about the data, are largely being supplanted by ML methods. One example is the bootstrap, championed by experts who move fluidly between theoretical statistics and machine learning algorithms.8
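Here is a minimal illustration of the bootstrap. The sketch rebuilds the overall relapse indicator from the counts in Table 2, where 1,791 of 2,478 participants relapsed. It then computes a 95% percentile bootstrap interval for the relapse proportion, alongside the textbook Wald formula. The seed and the number of resamples (2,000) are arbitrary choices for the example. The percentile method is only one of several bootstrap interval types.

```r
set.seed(2478)

# Relapse indicator rebuilt from the reported counts (1 = relapsed)
did_relapse <- c(rep(1, 1791), rep(0, 687))
estimate <- mean(did_relapse)

# Percentile bootstrap: resample participants with replacement many times
boot_means <- replicate(2000, mean(sample(did_relapse, replace = TRUE)))
boot_ci <- quantile(boot_means, c(0.025, 0.975))

# Traditional formula-based (Wald) interval for comparison
se <- sqrt(estimate * (1 - estimate) / length(did_relapse))
wald_ci <- estimate + c(-1, 1) * qnorm(0.975) * se

rbind(bootstrap = unname(boot_ci), wald = wald_ci)
```

With a sample this large, the two intervals nearly coincide. The bootstrap's appeal is that the same recipe works for statistics that have no convenient formula, such as a model's ROC AUC.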

In contrast to the traditional experimental approach, which overemphasizes p-values, ML begins with the premise that some data were collected, ideally a lot of data. The goal is usually to make predictions about the outcome for a new observation. There is a lot of common ground between traditional statistics and ML. The machinery that statisticians created long before the rise of ML can be (and is) used to calculate estimates with valid confidence intervals and also to make predictions. In fact, many ML projects begin with traditional statistical techniques, such as linear or logistic regression. The next steps differ: ML practitioners typically de-emphasize the p-values produced by these initial models and quickly move on to trying many other algorithms.

There is a plethora of statistical and ML methods for a multitude of outcomes. While ML methods can be used to predict continuous outcomes (morphine milligram equivalents consumed in a day), count data (the number of used needles brought to a harm reduction center), or time until an event (days until overdose after discharge), we will focus on predicting a binary outcome: treatment success or failure in Medications for Opioid Use Disorder (MOUD) programs.

The Role of Training and Testing Datasets

The ML process begins by setting aside a small part of the data so that the modeling focuses on predicting future samples. This held-out portion is typically 10% to 25% of the collected data. It is eventually used to evaluate, or test, the performance of a predictive model trained on the remaining data. In the context of predicting treatment success or failure, there is a yes/no outcome variable for each person in the test set. This outcome variable holds the truth about whether a person is still using drugs after an intervention.
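Here is a minimal base R sketch of holding out a testing set. It assumes a hypothetical toy data frame with a yes/no `did_relapse` column and a 25% test fraction. Sampling within each outcome group (stratification) keeps the relapse rate similar in both pieces. In the tidymodels ecosystem, `rsample::initial_split()` with its `prop` and `strata` arguments does this job; the hand-written version only exposes the mechanics.

```r
set.seed(94)

# Hypothetical toy dataset with an imbalanced yes/no outcome
toy <- data.frame(
  id = 1:200,
  did_relapse = factor(rep(c(1, 0), times = c(145, 55)))
)

test_prop <- 0.25

# Stratified sampling: draw about 25% of rows within each outcome group
rows_by_outcome <- split(seq_len(nrow(toy)), toy$did_relapse)
test_idx <- unlist(lapply(rows_by_outcome, function(rows) {
  rows[sample.int(length(rows), size = round(length(rows) * test_prop))]
}))

training <- toy[-test_idx, ]  # used to build the model
testing  <- toy[test_idx, ]   # set aside until the final evaluation

prop.table(table(training$did_relapse))
prop.table(table(testing$did_relapse))
```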

An algorithm, such as one of those described below, is selected and applied to the training dataset. This algorithm produces an optimal set of rules (or a formula) for predicting success or failure for the people in the training dataset. The quality of that trained model can be assessed using the small sample that was set aside, labeled the test or testing data. The model built from the training data generates predictions for the people in the testing data, and those predictions are then checked against the true values.
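The train-then-test sequence can be sketched in a few lines of base R. This example uses simulated data and an ordinary logistic regression as the stand-in algorithm; that choice is an assumption for illustration, since the project compares many algorithms. The model is fit on the training rows only and scored on the held-out rows with a 0.5 probability cutoff.

```r
set.seed(2025)

# Simulated data: one numeric predictor related to a yes/no outcome
n <- 400
x <- rnorm(n)
y <- rbinom(n, size = 1, prob = plogis(-0.5 + 1.5 * x))
dat <- data.frame(x = x, y = y)

# Hold out 25% of rows as the testing data
test_rows <- sample.int(n, size = n / 4)
train <- dat[-test_rows, ]
test  <- dat[test_rows, ]

# Train: estimate the formula on the training rows only
fit <- glm(y ~ x, family = binomial, data = train)

# Test: predict for people the model has never seen, then compare with truth
test_prob  <- predict(fit, newdata = test, type = "response")
test_class <- as.integer(test_prob >= 0.5)
accuracy   <- mean(test_class == test$y)
accuracy
```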

It is important to fully consider the implications of the training and testing split. Using the training data to construct the model and then assessing it against the testing data mirrors the process that would occur if the treatments were repeated on a new sample from the same population. This is the gold standard for evaluating a model’s performance in that population. However, as in all clinical investigations, a key question is “are these study participants like my study participants?” This question is addressed by studying “Table 1” of a paper. The optimal model describing harm reduction policies in rural Zimbabwe may not be particularly useful for policy makers in New York City. In the same way, the value of the test data lies in its ability to describe what would happen in another data set from a similar cohort. Here, as outlined in the paper, we build predictive models using data from three of the largest clinical MOUD trials, which were harmonized as part of the NIDA-sponsored project CTN-0094.9

Choosing a Metric of Success

Someone evaluating a yes/no question, such as whether a person stopped using a drug, can compute a large number of evaluation metrics.6 Some, such as accuracy, are relatively obvious. Everyone wants to know what percentage of the participants in the testing dataset were correctly classified as either “using” or “no longer using” a drug after an intervention designed to produce cessation. However, when most participants are (un)successful in treatment, a weakness of accuracy becomes apparent. If 95% of participants fail treatment, a model that ignores all of the predictors and simply predicts that everyone will fail will still be 95% accurate. So, instead of focusing only on accuracy, insightful researchers use additional metrics to evaluate the quality of a predictive model. To understand these metrics, study the information in the “confusion matrix”, shown in Figure 1. A confusion matrix is a two-by-two table displaying the counts of the true outcomes versus the predicted outcomes.
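The 95% example is easy to reproduce. Below, 100 hypothetical participants (95 still using, 5 stopped) are “predicted” by a model that calls everyone a treatment failure. The confusion matrix is laid out with the truth in rows and the prediction in columns, matching the orientation of the confusion matrix in Figure 1.

```r
# 100 hypothetical participants: 95 keep using (fail), 5 stop (succeed)
truth <- factor(c(rep("Using", 95), rep("Stopped", 5)),
                levels = c("Using", "Stopped"))

# A "model" that ignores every predictor and predicts failure for everyone
prediction <- factor(rep("Using", 100), levels = c("Using", "Stopped"))

# Confusion matrix: true outcome (rows) versus predicted outcome (columns)
confusion <- table(Truth = truth, Prediction = prediction)
confusion

accuracy <- mean(prediction == truth)
accuracy  # 0.95, yet the model never identifies anyone who stopped
```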

Figure 1: The groups defined by a confusion matrix showing the predicted and true drug use categories after medication assisted treatment (MAT) for opioid use disorder (OUD).

Truth (rows) by Prediction (columns) | Predicted to Use: Predicted Positive (PP) | Predicted to Stop: Predicted Negative (PN)
Actually Using: Positive (P) | True Positive (TP) | False Negative (FN)
Actually Stopped: Negative (N) | False Positive (FP) | True Negative (TN)

Having dealt with issues related to predicted and actual cure rates and rare outcomes for centuries, epidemiologists have developed a toolbox of metrics to describe patterns in confusion matrices. These include sensitivity (\(\frac{TP}{P}\)), specificity (\(\frac{TN}{N}\)), positive predictive value (PPV) (\(\frac{TP}{PP}\)) and negative predictive value (NPV) (\(\frac{TN}{PN}\)). Not being trained in epidemiology, the ML community rediscovered and relabeled these metrics. For example, sensitivity is labeled “recall” and PPV is called “precision”. As annoying as the new terms may be, ML practitioners have also popularized useful metrics like \(F_1\) (\(\frac{2TP}{2TP + FP + FN}\)). \(F_1\) is the harmonic mean of sensitivity and PPV, so it summarizes both in a single number. They have also popularized older metrics from the statistics literature like Cohen’s kappa (\(\frac{2 \times (TP \times TN - FN \times FP)}{(TP + FP) \times (FP + TN) + (TP + FN) \times (FN + TN)}\)), which quantifies model performance beyond what is expected by chance. Sometimes the two outcome groups are extremely imbalanced, as in the scenario described above where accuracy is practically useless. In such cases, many ML engineers build models that optimize Cohen’s kappa.
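The formulas above translate directly into code. The hypothetical helper `confusion_metrics()` below takes the four confusion-matrix counts and returns each metric. The first call uses made-up counts. The second revisits the “everyone fails” model from the 95% accuracy example, where “positive” means still using.

```r
# Metrics from the four confusion-matrix counts, using the formulas above
confusion_metrics <- function(tp, fn, fp, tn) {
  p  <- tp + fn  # actually positive
  n  <- fp + tn  # actually negative
  pp <- tp + fp  # predicted positive
  pn <- fn + tn  # predicted negative
  c(
    accuracy    = (tp + tn) / (p + n),
    sensitivity = tp / p,   # a.k.a. recall
    specificity = tn / n,
    ppv         = tp / pp,  # a.k.a. precision
    npv         = tn / pn,  # NaN when nobody is predicted negative
    f1          = 2 * tp / (2 * tp + fp + fn),
    kappa       = 2 * (tp * tn - fn * fp) /
                  ((tp + fp) * (fp + tn) + (tp + fn) * (fn + tn))
  )
}

round(confusion_metrics(tp = 40, fn = 10, fp = 20, tn = 30), 3)

# The "everyone fails" model: 95% accuracy, but a kappa of exactly 0
confusion_metrics(tp = 95, fn = 0, fp = 5, tn = 0)
```

For the first call, observed agreement is 0.70 and chance agreement is 0.50. So kappa is (0.70 - 0.50) / (1 - 0.50) = 0.4, the same value the two-by-two formula gives.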

Modeling practitioners agree that the overall quality of a model can typically1 be well described by looking simultaneously at the true positive rate and the false positive rate.11 This is typically expressed with a graphic called a receiver operating characteristic (ROC) curve and a numeric summary called the C-statistic (or concordance statistic). To understand their value, consider a model that attempts to predict treatment success for participants in a clinical trial designed to help people with an OUD diagnosis stop using opioids. After a model has been built on the training set and applied to the testing dataset, each participant in the testing dataset is assigned a predicted probability of treatment success. The performance of the model can then be checked against the true values for these participants.

Practically speaking, this works by varying the threshold for calling a participant a success. We begin by saying “if anybody has a predicted probability of 0 or higher, label them a success”, and check whether that is the correct value for each person. Because the minimum possible score is 0, this threshold labels everyone a success. The model therefore correctly captures all the true positives but also incorrectly classifies all the true negatives as successes. Next, we can increase the threshold to, say, 0.1 and classify anybody who has a predicted probability of 0.1 or higher as a success. That will (potentially) decrease both the true positives and the false positives. The same process can be repeated across the range of probabilities. At the extreme, only participants with a predicted probability of 1.0, the highest possible score, are called a success. This produces no false positives but will likely miss some of the true positives. That is, in this extreme scenario, someone with a predicted probability of success of 0.99 would be classified as a treatment failure because they did not meet the requirement of a predicted probability of 1.0.
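The threshold sweep described above can be written as a short loop over cutoffs. The ten predicted probabilities and outcomes are invented for illustration, and a coarse grid of cutoffs (0, 0.1, ..., 1) is assumed. Each row of the result is one point on an ROC curve.

```r
# Hypothetical predicted probabilities of treatment success and the truth
# (1 = success, 0 = failure) for ten participants in a testing set
truth <- c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0)
prob  <- c(0.93, 0.88, 0.81, 0.72, 0.66, 0.57, 0.43, 0.34, 0.22, 0.08)

thresholds <- seq(0, 1, by = 0.1)

# For each threshold, "prob >= threshold" counts as a predicted success;
# record the true positive rate and the false positive rate
roc_points <- data.frame(
  threshold = thresholds,
  tpr = sapply(thresholds, function(t) mean(prob[truth == 1] >= t)),
  fpr = sapply(thresholds, function(t) mean(prob[truth == 0] >= t))
)
roc_points
```

The first row (threshold 0) has both rates at 1, and the last row (threshold 1) has both at 0, matching the two extremes described in the text.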

The percentage2 of true successes correctly labeled a success (the true positive rate) can be plotted against the percentage of true failures incorrectly called a success (the false positive rate) (see Figure 2). The plots are typically drawn as a square and the area under the curve is shaded. This area under the curve, which typically falls between 0.5 and 1.0, is the C-statistic (also called the area under the receiver operating characteristic curve, or ROC AUC). If the model performs perfectly, the curve fills the square and the C-statistic is 1.0. If it only performs at chance, the curve follows the 45-degree diagonal and the C-statistic is 0.5. The C-statistic has a practical interpretation. Take two participants, one who will eventually succeed in treatment and another who will fail. The C-statistic is the probability that the model assigns the higher predicted probability of success to the one who succeeds.
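The two descriptions of the C-statistic give the same number. One is the area under the ROC curve; the other is the probability of ranking a random success above a random failure. The sketch below checks this for ten invented participants by computing the value both ways. Packages such as yardstick (`roc_auc()`) provide tested implementations; this base R version is only meant to expose the arithmetic.

```r
truth <- c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0)  # 1 = success, 0 = failure
prob  <- c(0.93, 0.88, 0.81, 0.72, 0.66, 0.57, 0.43, 0.34, 0.22, 0.08)

pos <- prob[truth == 1]
neg <- prob[truth == 0]

# 1) Concordance: share of (success, failure) pairs in which the success
#    gets the higher predicted probability (ties count as half)
auc_pairs <- mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))

# 2) Area under the empirical ROC curve via the trapezoid rule, using every
#    observed probability as a threshold
cuts <- c(Inf, sort(unique(prob), decreasing = TRUE))
tpr  <- sapply(cuts, function(t) mean(pos >= t))
fpr  <- sapply(cuts, function(t) mean(neg >= t))
auc_trapezoid <- sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)

c(auc_pairs = auc_pairs, auc_trapezoid = auc_trapezoid)  # both 20/24
```

Twenty of the 24 success/failure pairs are ordered correctly, so both calculations give about 0.833.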

Figure 2: An example ROC Curve where the area under the curve, the C-statistic, is 0.94. The dotted diagonal line represents a C-statistic of 0.5.

All that said, after building a predictive model on the training data, it is possible to assess how well it will perform on new participants who are comparable to the participants in the original training dataset. The “new” set of participants can be the test dataset mentioned above, or it can be a different assessment set created using one of the methods described below. This focus on quantifying practical model performance on a sample, using metrics like accuracy or ROC AUC, helps address the concern about relying solely on p-values mentioned above. Remember, the ML world focuses on the quality of a model’s performance when it is applied to new data, and different metrics emphasize different kinds of successes or failures. An interested reader should begin with the classic paper by Sokolova and Lapalme12 and explore more modern papers such as those by Opitz13 and Canbek, Temizel and Sagiroglu14. We chose to optimize our predictive models to produce the best ROC AUC when the models are applied to new data.

Assessing Variable Importance

With traditional regression methods, like ordinary least squares regression, logistic regression or survival analysis, investigators typically focus on statistically significant predictors (p < 0.05) and then use details like the size of beta coefficients to make statements. For example, the beta coefficients from an ordinary least squares model can describe the average number of needles brought to a syringe exchange by men compared to women; the beta estimates from a logistic regression model can describe the odds of an overdose event for females relative to males; and the beta estimates from a survival model can describe how the instantaneous risk of death goes up by some factor for someone who has started to inject heroin. While this approach still plays a role in determining the impact of predictors, most ML practitioners use other “model agnostic” methods to assess the role and importance of predictors. There are model agnostic methods to describe how changes in a predictor impact an individual (using techniques like break-down plots, Shapley additive explanation plots and ceteris-paribus profiles) as well as techniques to describe the general importance of predictors using permutation techniques. We hope to apply and describe these techniques in future publications. For now, note that the techniques that describe how an individual’s predicted outcome changes when a predictor is modified are complex, but approachable explanations can be found in publications such as Biecek and Burzykowski.15

A useful and easy to understand concept for determining the importance of a variable is permutation-based variable importance. If a variable included as a predictor in a model actually does not matter and if its values are shuffled then assigned randomly to other participants (i.e., the data are randomly permuted), the model’s performance, as measured by accuracy or the C-statistic, will not change much, if at all. If a variable does matter to a model, then the same random shuffling (i.e., permutation) of the data will decrease the model’s performance on a measure like accuracy or the C-statistic.

This permutation method for judging the importance of variables can be used with nearly all modeling techniques. The harm done to a model’s performance when a variable is permuted is graphed in a variable importance plot (VIP). VIPs are typically drawn in one of two ways. One method is to use a bar graph where a bar for each variable stretches toward the right. In this kind of plot, longer bars represent more harm being done to the model’s performance when a variable is permuted. Modern software will attempt to scale the bars using the key outcome metric (like accuracy or the C-statistic), but there is always some subjectivity in determining what is a “long” bar. Another common VIP uses a vertical line on the right side of the plot to indicate the model’s optimized performance. It then draws bars toward the left to indicate decreases in performance when a variable is permuted. Again, a long bar indicates an important variable because shuffling its values and randomly reassigning them to participants hurts the model performance a lot. An alternative to drawing a single bar is to repeatedly take samples of the data and reshuffle the values again and again to see how much harm is done in the resamples. Figure 2 in the paper, the Variable Importance Plot (VIP) for the Random Forest Model, shows this pattern. Instead of representing a single estimate of a variable’s importance with the length of a bar, a box plot can show the typical amount of harm (the shaded part of the box plot shows the middle 50%) as well as high and low estimates of the importance of each variable using the whiskers and outliers of the box plot.
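The permutation idea can be sketched in a few lines of base R. This is a minimal illustration on simulated data, assuming a logistic regression model and two hypothetical predictors: use_days, which drives the outcome, and noise, which does not. Each predictor is shuffled 25 times in held-out data and the drop in the C-statistic is recorded each time; those 25 drops per variable are the kind of values a box-plot VIP summarizes. Packages such as vip and DALEX automate this; the function names below are our own.

```r
set.seed(2025)
# Simulated (hypothetical) data: use_days drives treatment success, noise does not.
n <- 1000
dat <- data.frame(use_days = rnorm(n), noise = rnorm(n))
dat$success <- rbinom(n, 1, plogis(-1.5 * dat$use_days))
train <- dat[1:500, ]
test  <- dat[501:1000, ]

# C-statistic: chance that a random success gets a higher predicted
# probability than a random failure (ties count as half).
c_statistic <- function(prob, outcome) {
  gaps <- outer(prob[outcome == 1], prob[outcome == 0], FUN = "-")
  mean((gaps > 0) + 0.5 * (gaps == 0))
}

fit <- glm(success ~ use_days + noise, family = binomial, data = train)

# Shuffle one predictor in the held-out data, re-score, and record the drop
# in the C-statistic; repeat to see how variable the drop is.
perm_importance <- function(fit, data, outcome, variable, n_repeats = 25) {
  base_auc <- c_statistic(predict(fit, newdata = data, type = "response"),
                          data[[outcome]])
  replicate(n_repeats, {
    shuffled <- data
    shuffled[[variable]] <- sample(shuffled[[variable]])
    base_auc - c_statistic(predict(fit, newdata = shuffled, type = "response"),
                           data[[outcome]])
  })
}

drops <- sapply(c("use_days", "noise"),
                function(v) perm_importance(fit, test, "success", v))
apply(drops, 2, quantile, probs = c(0.25, 0.5, 0.75))  # the "box" of a box-plot VIP
```

Shuffling use_days should produce a large drop, while shuffling noise should leave the C-statistic nearly unchanged.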

How the ML Workflow Differs from Traditional Methods

As mentioned above, a typical ML approach begins by using methods championed by statisticians for more than a generation. While the philosophy behind these methods is different, with the statistician focusing on the stochastic data model and the ML modeler focusing on the properties of the algorithms, mathematically, they are comparable.16,17

Next, before exploring the coding, we discuss the subtle differences in the data preprocessing typically done in an ML workflow. Then we discuss the role of data splitting and resampling, the cornerstones of both modern statistical methods and ML. Finally, we discuss model tuning, which is the key difference between traditional and ML methods. Along the way, we will provide real examples using k-nearest neighbors (kNN), a conceptually simple but powerful ML algorithm.

Preprocessing Recipes

Traditional statistical modeling workflows involve some preprocessing of the data. For example, problematic data is identified and corrected (e.g., participants who report being 1973 years of age or subjects who died before the start of the study period). Further, outcomes are mathematically modified to ensure the appropriate theoretical assumptions are met (e.g., by taking a logarithm or applying a Box-Cox transformation on variables that cannot be well described by a bell-shaped curve). This processing can differentially impact the model metrics mentioned above (e.g., a model’s accuracy or ROC AUC). The preprocessing steps that evolved for traditional modeling are used in the ML context, but the process is more complicated for two reasons: resampling/subsampling and model-specific preprocessing needs.

Resampling / Subsampling

First, as will be shown shortly, ML models are often built using subsamples of the training dataset. When modeling is done on subsamples, care must be taken to make sure that there is no leakage of information across the samples. For example, some modeling methods require all predictors to be on the same scale (a process which involves calculating a mean and standard deviation). When using resampling techniques, it is important that the entire dataset is not used to calculate these statistics because if all the data were to be used, the subsamples would gain information from the entire dataset (i.e., an outlier not in a subsample would still impact the overall mean). Of particular concern is the mistake of including all the data (i.e., both the training and testing data) in the preprocessing. By leaking information like the mean or standard deviation of the entire dataset into the training data, the faulty code can allow powerful algorithms to pick up on details that should not have been known before the testing data is touched. In other words, the models will give overly optimistic estimates when applied to the “new” testing data. These concerns are addressed by using preprocessing “recipes” which can be easily applied to subsamples.
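The “learn on one set, apply to another” idea behind preprocessing recipes can be sketched in base R. This is a minimal illustration with a hypothetical years_use predictor: the centering and scaling statistics are estimated from the analysis rows only and then reused, unchanged, on the assessment rows. The helpers prep_scaler() and bake_scaler() are hypothetical names that mimic the prep/bake pattern of the tidymodels recipes package, where step_normalize() performs this step.

```r
set.seed(42)
# Hypothetical predictor: years of drug use for 100 training participants.
dat <- data.frame(years_use = rgamma(100, shape = 2, scale = 5))
analysis   <- dat[1:80, , drop = FALSE]    # rows used to fit the model
assessment <- dat[81:100, , drop = FALSE]  # rows used to judge the model

# "Prep": learn centering and scaling statistics from the analysis set ONLY.
prep_scaler <- function(data, cols) {
  list(cols  = cols,
       means = sapply(data[cols], mean),
       sds   = sapply(data[cols], sd))
}

# "Bake": apply the stored statistics to any data set.
bake_scaler <- function(scaler, data) {
  for (col in scaler$cols) {
    data[[col]] <- (data[[col]] - scaler$means[[col]]) / scaler$sds[[col]]
  }
  data
}

scaler <- prep_scaler(analysis, "years_use")
analysis_scaled   <- bake_scaler(scaler, analysis)
assessment_scaled <- bake_scaler(scaler, assessment)  # uses analysis-set mean/sd

# The leaky (wrong) approach would compute the mean from ALL 100 rows,
# letting the assessment rows influence how the analysis rows are scaled.
leaky_mean <- mean(dat$years_use)
c(correct = scaler$means[["years_use"]], leaky = leaky_mean)
```

During resampling, the prep step is simply rerun inside each analysis set, so no fold ever sees statistics computed from its own assessment rows.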

Model-Specific Preprocessing

The second departure from traditional preprocessing methods is driven by the fact that different ML methods benefit from different kinds of preprocessing steps.18 Some methods need categorical variables recoded as yes/no indicator (dummy) variables; among these, some benefit from including an indicator for every level (one-hot encoding), while others need one level dropped to serve as the reference. Other steps help different predictive algorithms to different degrees. These steps include dropping highly correlated predictors, combining uncommon levels of categorical variables, and modifying datasets to account for rare events.
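The two indicator-variable codings can be seen directly with base R’s model.matrix(). This sketch uses a hypothetical three-level race variable. In tidymodels recipes, step_dummy() (with one_hot = TRUE for the all-levels version) and step_other() (for combining uncommon levels) play these roles.

```r
# Hypothetical categorical predictor with three levels.
dat <- data.frame(race = factor(c("White", "Black", "Other", "White", "Black")))

# Reference coding: the first level ("Black", alphabetically) is dropped and
# absorbed into the intercept, as regression models require.
ref_coded <- model.matrix(~ race, data = dat)

# One-hot coding: one indicator column for every level and no intercept.
one_hot <- model.matrix(~ race - 1, data = dat)

colnames(ref_coded)  # "(Intercept)" "raceOther" "raceWhite"
colnames(one_hot)    # "raceBlack" "raceOther" "raceWhite"
```

Every row of the one-hot matrix contains exactly one 1, which is why including all levels alongside an intercept would make a regression model’s columns redundant.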

A host of preprocessing steps have been explored to help deal with rare outcomes. With large to huge datasets, rare outcomes can be modeled with appropriate caution. For instance, if all the training data is used, some ML methods will do extraordinarily well at predicting the “not rare” case and poorly at predicting the “rare” outcome. This is extremely problematic when the rare category is what we most need to predict. For example, if 99% of participants have an overdose reversed with a treatment drug, ML algorithms will do a relatively good job predicting the survivors while failing to predict the fatalities. With huge datasets where the rare event happens to thousands of participants, one viable strategy is to downsample (i.e., throw away cases/records from) the more common class. That is, if there are 3,000 people who have the rare outcome and 300,000 people who have the common outcome, 297,000 records from the common class are not used when building a sample, leaving 3,000 records in each class. Alternatively, it is possible to upsample (i.e., repeat) records from the rare class. A popular variant on upsampling, which goes beyond simply duplicating records, is to create new “synthetic” cases using algorithms like SMOTE or ROSE.20 These techniques typically reduce the quality of some metrics (e.g., decreased overall accuracy) and require more time while modeling, but these losses can be offset by large improvements in the predictions of the rare class (e.g., improved sensitivity).21
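Simple downsampling and upsampling can be sketched in base R as below, using hypothetical data with 30 rare and 3,000 common outcomes. The functions downsample() and upsample() are our own illustrations; in tidymodels, the themis package provides step_downsample(), step_upsample(), step_smote() and step_rose(). Like any preprocessing step, this should be applied only to analysis (training) data, never to the assessment or test data used to judge the model.

```r
set.seed(1)
# Hypothetical imbalanced data: 30 rare-outcome records and 3,000 common ones.
dat <- data.frame(
  outcome = factor(c(rep("rare", 30), rep("common", 3000))),
  x = rnorm(3030)
)

# Downsample: keep every record of the smallest class plus an equal-sized
# random subset of each larger class.
downsample <- function(data, outcome) {
  groups <- split(seq_len(nrow(data)), data[[outcome]])
  n_min  <- min(lengths(groups))
  keep   <- unlist(lapply(groups, function(idx) idx[sample.int(length(idx), n_min)]))
  data[keep, , drop = FALSE]
}

# Upsample: keep every record of the largest class and resample the smaller
# classes with replacement until they are the same size.
upsample <- function(data, outcome) {
  groups <- split(seq_len(nrow(data)), data[[outcome]])
  n_max  <- max(lengths(groups))
  keep   <- unlist(lapply(groups, function(idx) {
    idx[sample.int(length(idx), n_max, replace = length(idx) < n_max)]
  }))
  data[keep, , drop = FALSE]
}

down <- downsample(dat, "outcome")
up   <- upsample(dat, "outcome")
table(down$outcome)  # common 30, rare 30
table(up$outcome)    # common 3000, rare 3000
```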

With modern ML workflows it is relatively easy to apply a host of different preprocessing recipes and choose a model that successfully trades performance on one metric for another. For example, a model can be optimized to find the solution with the best accuracy, the highest ROC AUC, or one of many other metrics of success. This work is all done on the training data. Once a preprocessing recipe is paired with a modeling algorithm, that choice is applied to the testing data once. It is important to remember that the testing data is held as a precious resource that is only used once. This affords us the ability to describe the model performance on a “new” sample and makes the interpretation of p-values less problematic compared to traditional methods where p-values are rarely properly adjusted for all the preliminary modeling that was done.

The paper mentions the preprocessing recipe that we used for the final analysis. We will show you that and several alternatives below in the section titled Preprocessing Recipe.

Model Tuning

A major difference between traditionally used modeling methods like ordinary least squares regression or logistic regression and more modern ML methods is the reliance on “tuning” the model for optimal performance. Consider the machine learning algorithm called k-nearest neighbors (kNN), which we will use and discuss more later. In a case where we are predicting treatment success for a new participant, where success is defined as a yes/no response by clinicians, kNN will find k “similar” participants, calculate the percentage of them who were treatment successes or failures, and use the majority response (i.e., “were the majority of similar participants treatment successes or failures?”) to assign the outcome for the new participant. Of course, that raises two questions: how do you define “similar” and how many similar participants should be used? The “similarity” issue is typically resolved by plotting the data and measuring the distance between points in the plot. For example, think of a scatter plot, like the one shown in Figure 3, which plots two potentially useful predictors of treatment success (years of drug use and drug use events in the last 30 days) on the x- and y-axes, respectively. Further, imagine a training dataset with 300 participants, each of whom would be plotted as a point in the scatter plot with their outcome shown using different colored symbols for treatment success and failure. Finding people similar to our new participant would be like adding a point for their drug use history to the scatter plot and drawing concentric circles (bull’s-eye rings) around it until we had included k participants. If the majority of participants in the circle were no longer using the targeted drug, we would predict treatment success for our new individual.

Figure 3: Example data showing 123 treatment success (green filled dots) and 147 treatment failures (gray open dots) from the analysis set and one participant from the assessment set who is actually a treatment failure (red open diamond). The 3 closest neighbors from the analysis set (k = 3) are inside or touching the inner (black) circle and five closest neighbors (k = 5) are inside or touching the outer (blue) circle. With k = 3, 100% of the neighbors are treatment failures. With k = 5, 80% are failures. So, this participant is correctly classified using either option of k.
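The bull’s-eye procedure can be written as a short base R function. This is a minimal sketch, not the implementation used in the paper: knn_predict() is a hypothetical helper, the eight-participant analysis set is invented, and the two predictors happen to be on similar scales (real data would normally be standardized first). The class package that ships with R offers a ready-made knn() function.

```r
# Hypothetical analysis set: two predictors and each participant's outcome.
analysis <- data.frame(
  years_use = c(2, 3, 4, 5, 12, 14, 15, 16),
  use_days  = c(3, 5, 2, 4, 25, 28, 22, 27),
  outcome   = c(rep("success", 4), rep("failure", 4))
)
predictors <- c("years_use", "use_days")

knn_predict <- function(analysis, new_point, k, predictors) {
  # Euclidean ("as the crow flies") distance to every analysis participant
  diffs <- sweep(as.matrix(analysis[predictors]), 2, new_point)
  dist  <- sqrt(rowSums(diffs^2))
  # Majority vote among the k nearest neighbors
  nearest <- order(dist)[seq_len(k)]
  votes   <- table(analysis$outcome[nearest])
  names(votes)[which.max(votes)]
}

knn_predict(analysis, c(13, 24), k = 3, predictors = predictors)  # "failure"
knn_predict(analysis, c(3, 3),  k = 3, predictors = predictors)   # "success"
```

Notice that nothing is “learned” ahead of time: every prediction is made by searching the stored analysis set.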

Now the more interesting question is, how many similar neighbors should we use? That is, how do we pick the best k? With a training dataset of 300 participants, we have a lot of choices for possible values of k: every number between 1 and 300! At the extremes, we could use the one closest person (k = 1), or we could use k = 300 people and calculate the overall probability of success in the entire training dataset. Throwing away 299 bits of information by using only the single closest person (k = 1) seems unwise. On the other hand, including everybody (k = 300), which assumes that you don’t benefit from using the most similar participants, also seems like a bad idea. What we want to do is grab a “handful” of the most similar participants, test the quality of the predictions for them (perhaps with accuracy or using the ROC AUC/C-statistic), then increase the size of these “handfuls” (try k = 3, k = 5, …).

The machine learning solution to this problem is to split the 300 participants in the training set into subgroups, which practitioners call partitions or folds, and then use the subgroups to evaluate across different values of k. If we decide to do a 10-fold analysis, we would assign each participant as being a member of one of ten subgroups/folds. We can take the first fold (which would be 30 of the 300 participants, say the first 10%) and call them an assessment dataset and we can call the remaining participants who are in the other nine folds (9 × 30 = 270 participants) an analysis set. Notice this is distinct from the initial split separating the data into training and test sets. Here, we are taking the training data and further splitting/partitioning it into a 10% assessment set and a 90% analysis set. We then set k = 3 and see how the model performs (check the accuracy) when we use the 270 participants in the analysis set to predict the 30 participants in the assessment set. That is, take the first participant in the assessment set, add them to the scatter plot with the 270 points in the analysis set, find the three closest participants to the new participant, see if the majority succeed or fail treatment and then compare the prediction vs. the truth for the first participant. This example is shown in Figure 3. We repeat this process for the other 29 participants in the assessment set. We can repeat this entire process on the next 10% fold, then the next fold, until all the participants in the training data set have been included/used in an assessment set. When we have finished, we can average the performance across the 10 groups of 30 (what ML aficionados would call the 10 folds) to calculate an overall assessment of performance with k = 3. We can then set k = 5 and repeat this entire procedure to check the performance across the folds. This process of checking the performance across the folds/resamples is called cross validation.
A machine can iterate over any value of k but typically, we check k = 3, k = 5, etc. out to some fraction of the sample size (a rule of thumb is to use the square root of the sample size; here the square root of 300 is about 17).3 This process of checking the model performance across a range of possible options, like the different values of k here, is called “tuning” a “hyperparameter.” Remember the phrase “tuning hyperparameters” and remember that the tuning is performed using cross validation. Those details are the core that sets ML apart from other modeling methods/mindsets.
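The whole tuning loop (10 folds, odd values of k up to about the square root of 300, accuracy averaged across folds) can be sketched in base R. This is a minimal illustration on simulated data with two hypothetical standardized predictors; knn_predict_all() is our own helper. In tidymodels, rsample’s vfold_cv() creates the folds and tune’s tune_grid() runs this kind of loop.

```r
set.seed(300)
# Simulated (hypothetical) training data: 300 participants, two predictors.
n <- 300
train <- data.frame(years_use = rnorm(n), use_days = rnorm(n))
train$outcome <- ifelse(train$years_use + train$use_days + rnorm(n, sd = 0.8) > 0,
                        "failure", "success")
predictors <- c("years_use", "use_days")

# Majority-vote kNN for every participant in an assessment set.
knn_predict_all <- function(analysis, assessment, k, predictors) {
  a <- as.matrix(analysis[predictors])
  b <- as.matrix(assessment[predictors])
  apply(b, 1, function(pt) {
    dist  <- sqrt(colSums((t(a) - pt)^2))   # Euclidean distance to each analysis row
    votes <- table(analysis$outcome[order(dist)[seq_len(k)]])
    names(votes)[which.max(votes)]
  })
}

folds  <- sample(rep(1:10, length.out = n))  # 10 folds of 30 participants
k_grid <- seq(3, 17, by = 2)                 # odd k up to about sqrt(300)

cv_accuracy <- sapply(k_grid, function(k) {
  fold_accuracy <- sapply(1:10, function(f) {
    analysis   <- train[folds != f, ]        # 270 participants
    assessment <- train[folds == f, ]        # 30 participants
    pred <- knn_predict_all(analysis, assessment, k, predictors)
    mean(pred == assessment$outcome)
  })
  mean(fold_accuracy)                        # average over the 10 folds
})

best_k <- k_grid[which.max(cv_accuracy)]
data.frame(k = k_grid, accuracy = round(cv_accuracy, 3))
```

Odd values of k are used so that a two-class majority vote can never tie.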

In addition to checking the count of the numbers of neighbors, it is possible to assign different “weights” to the k closest participants. That is, it is easy to imagine that the 10 closest participants are valuable for making predictions but within that “circle”, the participants who are closest are even more valuable than the more distant participants. So, it is also possible to “tune” many distance weighting functions at the same time the model is tuning the k.

Mathematicians have also defined ways to measure distance beyond straight lines. For example, traveling in a city does not involve moving in a straight line “as the crow flies”; rather, it involves moving up/down and right/left some number of blocks (i.e., Manhattan distance). So, the way the distance between neighbors is measured can also be tuned. The end result is an optimal k paired with its optimal distance measure and distance weighting function. As you can imagine, checking all the possible k values with many different distance measures and weights leads to a huge number of combinations.
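The distance measures and a distance-weighted vote can be sketched in base R. Inverse-distance weighting is only one of several possible weighting functions and is used here as an assumption for illustration; the five neighbor distances are invented. In tidymodels, parsnip’s nearest_neighbor() exposes these choices through its weight_func and dist_power arguments.

```r
euclidean <- function(a, b) sqrt(sum((a - b)^2))  # "as the crow flies"
manhattan <- function(a, b) sum(abs(a - b))       # city blocks

euclidean(c(0, 0), c(3, 4))  # 5
manhattan(c(0, 0), c(3, 4))  # 7

# Distance-weighted vote: closer neighbors count more (inverse-distance
# weights, one of several possible weighting functions).
weighted_vote <- function(dist, labels, k) {
  nearest <- order(dist)[seq_len(k)]
  w <- 1 / (dist[nearest] + 1e-8)        # tiny constant avoids dividing by zero
  scores <- tapply(w, labels[nearest], sum)
  names(scores)[which.max(scores)]
}

# Five hypothetical neighbors: two very close failures, three farther successes.
dist   <- c(0.5, 0.6, 2.0, 2.2, 2.5)
labels <- c("failure", "failure", "success", "success", "success")
weighted_vote(dist, labels, k = 5)   # "failure"
names(which.max(table(labels)))      # unweighted majority: "success"
```

The same five neighbors give opposite predictions depending on whether closeness is weighted, which is why the weighting function is worth tuning alongside k.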

While it is possible to search every combination of k and weights by creating a grid, typically a large number of random combinations are tried to cover the space of possibilities. A “space-filling” design can be used to make sure the randomly chosen values are not too clumped together. You will see examples of grids in the code below for methods that are not tuning many hyperparameters (e.g., LASSO) as well as space-filling designs for methods that need to tune many parameters (e.g., neural networks). The state-of-the-art in ML focuses on how to choose options and then check other possible values in the neighborhood that look promising using Bayesian optimization methods. Watch for those Bayesian methods in future work.
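A regular grid and a simple space-filling design can be compared in base R. This sketch assumes two hypothetical hyperparameters: k, and weight_power, an exponent for inverse-distance weights (weight = 1/distance^weight_power). The Latin hypercube shown is one basic space-filling design, not necessarily the one any particular package uses.

```r
set.seed(11)
# Regular grid: every combination of candidate values (hypothetical ranges).
regular <- expand.grid(k = seq(3, 17, by = 2), weight_power = c(0, 1, 2))
nrow(regular)  # 8 values of k times 3 weightings = 24 candidates

# Latin hypercube, a simple space-filling design: each dimension is cut into
# n equal-width bins and every bin is used exactly once, so the random points
# cannot all clump together in one corner.
latin_hypercube <- function(n, dims) {
  sapply(seq_len(dims), function(d) (sample(n) - runif(n)) / n)
}

lhs <- latin_hypercube(10, 2)   # 10 points on the unit square
candidates <- data.frame(
  k = 2 * round((lhs[, 1] * 14 + 3 - 1) / 2) + 1,  # rescaled to an odd k in 3..17
  weight_power = lhs[, 2] * 2                     # rescaled to 0..2
)
candidates
```

The grid grows multiplicatively with each added hyperparameter, while the space-filling design lets you choose the number of candidates directly.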

Once we know the best k along with its best weighting function across the training data subsets we can use the test data once to see how our model performs on new data. This is how ML works. Different algorithms require different hyperparameters to be tuned, but at its core, ML is all about checking the model performance on samples of the training data (after it is split into numerous analysis and assessment sets).

There are many ways to create analysis and assessment sets. Splitting the training data into tenths and using 10-fold cross validation is a very popular option for huge datasets. Another popular alternative, which we used in the paper, is 5-fold cross validation. It is popular because it decreases computing time and it helps ensure that each assessment set has enough successes if the dataset is small or if the outcomes are extremely rare. Another very common option is to use bootstrap resamples. With the bootstrap, an analysis set is created by randomly sampling participants with replacement until it has the same number of participants as the original data. With replacement means that after an individual is selected, they are returned to the pool of eligible participants and can be selected again. So, one person can be selected two or more times in a bootstrap sample. Because the bootstrap can select the same person more than once and because the analysis set is the same size as the original data, there will be some participants who were never selected for the analysis set. These participants form the assessment set and can be used to assess the performance of a model fit to the analysis set. It can be shown mathematically that a bootstrap analysis set will contain about 2/3 (about 63.2%) of the original participants, leaving about 1/3 of them for the assessment set. To learn more, see An Introduction to Statistical Learning22.
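The “about 2/3 in, 1/3 out” property can be checked with a quick base R simulation. The chance that a given participant is never drawn in n draws is (1 − 1/n)^n, which approaches 1/e ≈ 0.368, so about 63.2% of participants appear in the analysis set. In tidymodels, rsample’s bootstraps() creates these analysis/assessment pairs.

```r
set.seed(94)
n   <- 10000                    # participants in the training data
ids <- seq_len(n)

# Analysis set: n draws WITH replacement, so it is the same size as the data.
analysis_ids <- sample(ids, size = n, replace = TRUE)

# Assessment ("out-of-bag") set: everyone who was never drawn.
assessment_ids <- setdiff(ids, analysis_ids)

length(analysis_ids)            # 10000 rows, including repeats
mean(ids %in% analysis_ids)     # about 0.632 of participants are in the analysis set
length(assessment_ids) / n      # about 0.368 are left for assessment
1 - exp(-1)                     # theoretical limit: 0.632...
```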

Explanations of the Methods Used in the Paper - Without Code

k-Nearest Neighbors - Prediction Without Learning

With the other methods we explore here, we can make statements about which variables are relatively influential. For some methods, like the regression models which we will cover next, we can make strong statements about what happens to the outcome as predictors change. With kNN, the algorithm itself gives us none of those benefits. We may get a good prediction, but we don’t know why it works beyond saying that “this person responds like similar people.” This lack of transparency in the relationship between the predictors and the outcomes has led kNN to be labeled as a “black box” method. Whereas ML experts have been able to explore the inner workings of many ML algorithms that were once considered inscrutable black box methods, the kNN algorithm remains a closed black box. In essence, kNN will yield predictions, but the algorithm itself tells us nothing about the role and influence of the individual predictors.

Regression Modeling

The ML community has embraced and expanded the linear modeling methods that most people are taught in introductory statistics classes. Unlike the other methods we explore below, traditional ordinary least squares regression and logistic regression models do not require tuning of hyperparameters. However, leaders in the field have expanded and modified the traditional approaches to take advantage of ML methods.22

Traditional Logistic Regression

While the algorithms used to do logistic regression in the context of ML modeling are the same as traditional statistical methods, their practical implementation benefits from an ML framework. Rather than fitting a single model using all the available data, as is often done by classically trained practitioners, those using ML will split the data into training and testing sets, use cross-validation or bootstrapping on the training set, and then evaluate the model on the test data. In this context, logistic regression becomes just another competitor which is evaluated to see if it does the best job predicting the data. While traditional logistic regression cannot readily detect the complex patterns in the data that other methods can find, it has not fallen out of favor because it offers relatively simple explanations in terms of the change in the odds of an outcome as each predictor changes. Models with simple explanations are called parsimonious.

The traditional “manual” modeling process historically taught in epidemiology classes is unpopular in the ML framework. That process, which exploits subject matter knowledge to select variables to include in a prespecified order (typically by looking at the impact of each predictor by itself, then including variables that are statistically significant in other models), followed by manually removing variables (often using a p-value-based criterion), has a tendency to produce model results which do not replicate when applied to new datasets. Instead of relying on p-value-based metrics, some practitioners will use automatic “stepwise” modeling procedures that are optimized based on criteria other than p-values, like the Akaike information criterion (AIC). Even these automatic stepwise procedures are being supplanted by ML “shrinkage” methods like Least Absolute Shrinkage and Selection Operator (LASSO).22,23

Logistic LASSO

LASSO and its cousin, ridge regression, conceptually begin by saying the effects that we note based on this sample are going to be too large. That is, if we were to gather another sample, the same predictors may matter, but extra large effects, which may happen by chance associations in this sample, need to be shrunk toward zero. So, while traditional logistic regression is optimized to find the optimal “betas” to describe the impact of the predictor in this sample,4 LASSO is designed to simultaneously find the most accurate betas for this sample and shrink them down toward results with the magnitude we expect to see in a new sample. The LASSO adjustment “penalty” pushes all the betas toward zero and some of them are shrunk to be zero, meaning that as the predictor changes, the outcome is not influenced. Because LASSO can be used to simultaneously produce conservative estimates of the impact of each predictor and remove unimportant variables (drop some betas to zero), many ML practitioners use LASSO instead of the traditional logistic or automatic stepwise modeling methods. That useful ability of eliminating variables because their betas have been shrunk to zero is called variable selection. That said, because so much of the biomedical literature has used traditional logistic methods, we still find it useful to use both traditional and shrinkage regression methods so we can compare our effect sizes with previous work.
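To see shrinkage and variable selection happen, here is a from-scratch base R sketch of logistic LASSO on simulated data where only x1 and x2 matter and x3 to x5 are noise. It uses proximal gradient descent (a gradient step followed by “soft-thresholding” toward zero) and a fixed, hand-picked penalty of lambda = 0.08; this is an assumption for illustration, not the algorithm or penalty used in the paper. In practice, the glmnet package (with alpha = 1), which uses coordinate descent, is the usual tool.

```r
set.seed(7)
n <- 500
X <- matrix(rnorm(n * 5), n, 5, dimnames = list(NULL, paste0("x", 1:5)))
y <- rbinom(n, 1, plogis(-0.3 + 1.5 * X[, "x1"] - 1.0 * X[, "x2"]))  # x3-x5 are noise

# Soft-thresholding: shrink toward zero and set small values exactly to zero.
soft_threshold <- function(z, gamma) sign(z) * pmax(abs(z) - gamma, 0)

# Proximal-gradient fit minimizing mean log-loss + lambda * sum(|beta_j|);
# the intercept is not penalized.
lasso_logistic <- function(X, y, lambda, iters = 5000) {
  X1 <- cbind("(Intercept)" = 1, X)
  n  <- nrow(X1)
  # Step size from the Lipschitz constant of the log-loss gradient
  lipschitz <- max(eigen(crossprod(X1) / n, symmetric = TRUE,
                         only.values = TRUE)$values) / 4
  step <- 1 / lipschitz
  beta <- setNames(rep(0, ncol(X1)), colnames(X1))
  for (i in seq_len(iters)) {
    p    <- plogis(drop(X1 %*% beta))
    grad <- drop(crossprod(X1, p - y)) / n
    beta <- beta - step * grad
    beta[-1] <- soft_threshold(beta[-1], step * lambda)
  }
  beta
}

dat       <- data.frame(y = y, X)
fit_glm   <- glm(y ~ ., family = binomial, data = dat)  # traditional logistic
fit_lasso <- lasso_logistic(X, y, lambda = 0.08)
round(cbind(glm = coef(fit_glm), lasso = fit_lasso), 3)
```

The LASSO column should show x1 and x2 pulled toward zero relative to the traditional fit, while the noise predictors are set exactly to zero (dropped).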

As with kNN and the logistic modeling, the LASSO modeling process involves splitting the training data into analysis and assessment sets, fitting a series of models on the cross-validation (or bootstrap) sets and evaluating the performance on the assessment sets. The key difference between this and kNN is that with LASSO, the tuning parameter is the amount of shrinkage. That is, the training data is split into many subgroups and the model performance is assessed across a range of shrinkage penalties. Once the optimal amount of shrinkage is found on the training data, a prediction can be generated on the test set and the performance can be evaluated.

Logistic Regression with Resampling

While traditional logistic regression is done by fitting a single model with all the training data, it is possible to fit a LASSO model with the penalty set to be extremely small. By doing this, it effectively just does traditional logistic regression but, because the model is fit using the resampling framework needed for LASSO, we can calculate the cross-validation/resampling estimates of future model performance. That is, the logistic model’s performance can be averaged across the resampling folds to give a resampling estimate of the ROC AUC.
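A resampled estimate of logistic regression performance can be sketched in base R with 5-fold cross validation. Note one simplification: this sketch fits ordinary glm() models in each fold instead of a LASSO model with a near-zero penalty, and the data and predictor names are hypothetical.

```r
set.seed(5)
# Simulated (hypothetical) training data.
n <- 400
dat <- data.frame(age = rnorm(n), use_days = rnorm(n))
dat$success <- rbinom(n, 1, plogis(0.5 * dat$age - 1.0 * dat$use_days))

# C-statistic: chance a random success outranks a random failure.
c_statistic <- function(prob, outcome) {
  gaps <- outer(prob[outcome == 1], prob[outcome == 0], FUN = "-")
  mean((gaps > 0) + 0.5 * (gaps == 0))
}

folds <- sample(rep(1:5, length.out = n))   # 5 folds of 80 participants

fold_auc <- sapply(1:5, function(f) {
  fit <- glm(success ~ age + use_days, family = binomial,
             data = dat[folds != f, ])      # analysis set
  assess <- dat[folds == f, ]               # assessment set
  c_statistic(predict(fit, newdata = assess, type = "response"), assess$success)
})

round(fold_auc, 3)
mean(fold_auc)   # resampling estimate of the ROC AUC
```

The spread of the five fold values also gives a rough sense of how much the estimate could vary on new data.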

MARS

Both traditional and LASSO regression methods assume that there is a linear (or linear in the logit space) impact of continuous predictors. That is, for every one unit increase in a predictor, there is the same impact (increase or decrease of the odds) on the outcome. Many methods have been proposed to relax this constraint. A popular ML option is Multivariate Adaptive Regression Splines (MARS) models. Rather than saying the impact of a predictor must be described by a line, MARS models allow the line to change its slope in a few places. Imagine it as adding a “hinge” to a traditional regression line. This theoretical hinge, which attaches two line segments, acts like a physical hinge attached to a door, allowing it to change the angle by which the door is attached to a wall. This is extremely powerful because we can describe a different impact of a one unit increase in the predictor across different ranges of the predictor. MARS models are designed to find the breakpoints and add a “hinge” to the regression line where the change happens, thus allowing the person evaluating the model to say the odds change by some known amount below the threshold and then by a different amount above the threshold. For example, the chances that someone will stop breathing may go up for each milligram (mg) equivalent of morphine across a broad range of doses, but after hitting a threshold, the chances will skyrocket for each additional mg. In reality, there may not be a sharp physiological breakpoint; the risk may instead curve upward gradually. In theory, an optimal model would add a curve to the regression model where the change happens, but that flexibility comes with the cost of difficulty in interpretation. MARS models sidestep that complexity by saying the line has one slope until a threshold is met and then has a different slope.
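The hinge idea can be sketched in base R. This is a deliberately simplified illustration, not the full MARS algorithm (which the earth package implements with paired hinge functions and forward/backward passes): it uses a continuous, simulated risk score rather than log-odds, a single hypothetical knot, and a brute-force search over candidate knots.

```r
set.seed(21)
dose <- 0:100   # hypothetical morphine milligram equivalents
# Simulated risk score: a gentle slope up to 60 mg, then a much steeper one.
risk <- 2 + 0.1 * dose + 0.9 * pmax(0, dose - 60) + rnorm(length(dose), sd = 0.5)

# A MARS-style hinge: zero below the knot, rising linearly above it.
hinge <- function(x, knot) pmax(0, x - knot)

# Simplified knot search: fit a line plus one hinge at each candidate knot and
# keep the knot with the smallest residual sum of squares.
candidates <- 10:90
rss <- sapply(candidates, function(knot) deviance(lm(risk ~ dose + hinge(dose, knot))))
best_knot <- candidates[which.min(rss)]

fit <- lm(risk ~ dose + hinge(dose, best_knot))
slope_below <- coef(fit)[["dose"]]
slope_above <- slope_below + coef(fit)[[3]]   # the hinge adds to the slope
c(knot = best_knot, below = slope_below, above = slope_above)
```

The fitted model says the risk rises about 0.1 per mg below the knot and about 1.0 per mg above it: two easy-to-state slopes instead of one hard-to-describe curve.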

Support Vector Machines

At their core, Support Vector Machines (SVMs) seek to find an optimal “dividing line” that splits data into groups. That dividing line can be a line through a two-dimensional scatter plot, a plane that slices through a three-dimensional cube of data or a hyperplane that cuts across four or more dimensions. SVMs strive to maximize the space around the dividing line (i.e., the margin between the boundary and the nearest data points on either side of the line). These nearest points, called support vectors, are crucial in defining the dividing line and give the algorithm its name.

The true power of SVMs lies in their ability to handle nonlinear decision boundaries through the “kernel trick.” To understand this, imagine a scatter plot that shows a circular blob in the middle with a lot of minus symbols representing participants who fail treatment, and all around the blob are many plus symbols representing participants who succeed in treatment. Most ML methods will struggle to make correct predictions with this kind of data because there is no straight line (or step function) that can form the boundary. Imagine that the scatter plot described above is printed on a stretchy rubber surface. The SVM will pinch the middle of the blob of treatment failures and stretch it up off the page. After the data has been pulled into the third dimension, it will find the optimal plane to slice the data into successes and failures. In other words, this mathematical technique allows SVMs to implicitly map data into a higher dimensional space where linear separation becomes possible. Common kernel functions to do the “stretching” include polynomial, radial basis function (RBF), and sigmoid, each offering different ways to transform the feature space. (NOTE: The many types of kernel functions are beyond the scope of this supplement.) This flexibility enables SVMs to capture patterns far more complex than the circular blob described here.

In the paper, we use a relatively simple SVM that is tuned for the degree of the polynomial used to transform the space and for the penalty (cost) associated with putting a point on the wrong side of the dividing line.
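The “stretching” can be made concrete in base R. This sketch uses simulated, hypothetical blob-and-ring data and an explicit extra coordinate instead of fitting an actual SVM. It then checks numerically that a degree-2 polynomial kernel equals a dot product in a stretched space, which is what lets an SVM skip computing the stretched coordinates. In tidymodels, parsnip’s svm_poly() exposes the degree and cost tuning parameters mentioned above.

```r
set.seed(3)
# Hypothetical data: failures in a central blob (radius < 1), successes in a
# surrounding ring (radius 2 to 3). No straight line separates them in 2-D.
n <- 100
radius <- c(runif(n, 0, 1), runif(n, 2, 3))
angle  <- runif(2 * n, 0, 2 * pi)
x1 <- radius * cos(angle)
x2 <- radius * sin(angle)
outcome <- rep(c("failure", "success"), each = n)

# "Stretch" the rubber sheet: add a third coordinate equal to the squared
# distance from the center. In (x1, x2, z) space the flat plane z = 2.25
# separates the two groups perfectly.
z <- x1^2 + x2^2
pred <- ifelse(z > 2.25, "success", "failure")
mean(pred == outcome)   # 1

# The kernel trick: the degree-2 polynomial kernel K(a, b) = (a . b)^2 equals
# the dot product after the explicit map phi(v) = (v1^2, sqrt(2) v1 v2, v2^2),
# so an SVM can work in the stretched space without ever computing phi.
phi <- function(v) c(v[1]^2, sqrt(2) * v[1] * v[2], v[2]^2)
a <- c(1, 2)
b <- c(3, -1)
sum(a * b)^2           # kernel value: 1
sum(phi(a) * phi(b))   # explicit map: also 1 (up to rounding)
```

The lifting coordinate z is just the sum of the first and third coordinates of phi, which is why a degree-2 polynomial kernel is enough to find this circular boundary.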

Tree-Based Modeling

Tree-based methods begin with the assumption that there are homogeneous groups of participants, and these groups can be identified by building a series of rules that are framed as yes/no questions. These techniques find the optimum yes/no splits in the predictor variables to group participants with the same outcomes. Here, we will be using tree-based methods to find homogeneous groups of participants, that is, participants with similar values on the predictor variables who share a common outcome (i.e., they all succeeded or all failed in treatment). Be aware that tree-based methods can also be used to find participants with similar numeric scores, like the number of needles exchanged at a harm reduction program, or to identify groups of participants who have a similar pattern of events through time, like the time until an overdose. In other words, tree-based methods can be used to solve classification, regression, and time-until-event/survival problems. Tree-based methods are particularly useful because their construction intrinsically provides insight into which variables are important and they can pick up on complex (non-linear) trends. See Greenwell’s Tree-based methods for statistical learning in R for an extremely well written introduction to tree-based models.24

CART

To understand how the yes/no rules are built, consider how CART (Classification and Regression Trees), one of the most influential early tree-based methods, predicts treatment success with three variables: age in years, race (i.e., White, Black, and Other), and sex (i.e., male or female assigned at birth). CART uses every variable to make every possible yes/no split and then checks the percentage of participants in the two resulting groups who are treatment successes. So, it evaluates a rule that says, “Is this person male?” and checks what percentage of participants in the two groups, male and female, are correctly labeled as treatment successes. It then looks at the age variable and notes that the youngest participant is 15. It makes a rule that says, “Is this person 15 or younger?” and checks the percentage of participants who are treatment successes in the two age groups. Next, it notices there is a 17-year-old, so it makes two groups using the rule, “Is this person 17 or younger?” Again, it checks the percentage of participants who are correctly classified as treatment successes in the two age groups. It repeats this for every possible age split. It then turns to race and makes every possible two-way split. That is, it checks the percentage of participants correctly classified as treatment successes if it splits on Black vs. a group containing both White and Other. Then it checks the percentage correctly classified if it splits on White vs. a group with both Black and Other. Finally, it checks Other vs. a group with both White and Black. Each time, CART notes the percentage who are correctly labeled as treatment successes in each subgroup. Here we are working with only three variables, but CART and other tree-based methods iterate through every variable and make every possible two-way split. Once the process is complete, CART picks the split that makes the most accurate prediction. In its usual form, CART measures this with Gini impurity: it picks the split with the lowest probability of mislabeling a randomly chosen participant if labels were assigned at random according to the outcome proportions in that participant’s group.
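The split search described above can be sketched in base R. This is a minimal illustration, assuming a 0/1 outcome (1 = treatment success) and hypothetical toy data rather than study data: it scores every age cutoff by the weighted Gini impurity of the two resulting groups and keeps the lowest, then counts the two-way splits available for a three-level race variable. The helpers gini() and split_impurity() are our own names, not functions from a CART package.

```r
# Gini impurity of one group of 0/1 outcomes (1 = treatment success)
gini <- function(y) {
  p <- mean(y)
  1 - p^2 - (1 - p)^2
}

# Weighted impurity of the two groups created by a yes/no rule
split_impurity <- function(rule, y) {
  n <- length(y)
  sum(rule) / n * gini(y[rule]) + sum(!rule) / n * gini(y[!rule])
}

# Hypothetical toy data, not study data
age     <- c(15, 17, 22, 30, 41, 48, 55, 60, 63, 70)
success <- c( 0,  0,  0,  1,  0,  1,  1,  1,  1,  1)

# "Is this person <cutoff> or younger?" for every observed age except the oldest
cutoffs  <- head(sort(unique(age)), -1)
impurity <- sapply(cutoffs, function(cutoff) split_impurity(age <= cutoff, success))
best_cut <- cutoffs[which.min(impurity)]
best_cut  # 41, i.e., "Is this person 41 or younger?"

# A three-level factor has 2^(3 - 1) - 1 = 3 distinct two-way splits
races <- c("White", "Black", "Other")
race_splits <- lapply(races, function(r) list(alone = r, pooled = setdiff(races, r)))
length(race_splits)  # 3
```

In this toy example the best rule leaves one group of all successes and one group with a single success out of five, giving a weighted impurity of 0.16.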

Figure 4 shows a hypothetical CART “tree” diagram. Note that decision trees are typically drawn upside down, with the “root” of the tree at the top and the final groups, which are called “leaves”, drawn at the bottom of the graphic. Real classification trees may provide additional details, like the number of participants in each intermediate decision node and the proportion of treatment successes or failures at each node. In this example, the algorithm finds that the split of 54 years or younger vs. 55 and older results in the fewest incorrect labels of treatment success and failure. That becomes the first (top) split in the tree. It then takes only the participants in the younger group and checks the accuracy of the labels if it splits on sex, on each combination of race, and on every possible age split within this younger group. If it can improve the quality of the labels by adding another split, CART does so, choosing the split that reduces labeling errors the most. It then turns to the older group and checks every possible split, and it repeats this process over and over. It is extremely important to note that with tree-based methods like CART, different variables can play an important role in only some subgroups. For example, race may be the most important second split in the older group, while sex matters most for the second split in the younger group. CART and other tree-based algorithms can also reuse the same variable, so you may see an initial split on age, followed by race, then by age again (corresponding, for example, to a different treatment response only among middle-aged Black participants).

Figure 4: A hypothetical tree predicting treatment success for a cohort entering treatment for opioid use disorder.

The ability to inspect splits allows us to look for differential treatment effects. For example, a tree could split a cohort into people who received treatment medication one vs. other medications. The most important predictors, signified by additional splits, for the medication one group could be age, smoking, and IV drug use, whereas the other medications group could instead split on study site, days in detox, and amount of prior opioid use. Tree-based methods allow completely different predictors to be emphasized for each of the treatment medications, and a single tree like CART lets us see this directly.

One obvious way to make homogeneous groups is to keep splitting until each final group (called a leaf of the tree) contains a single participant. Such a tree, however, simply memorizes the training data and predicts new participants poorly. So when should the tree-building process end? There are two general strategies: early stopping rules or pruning full trees. Early stopping rules can control the depth of the tree by fixing the number of splits in advance (say, only allowing three nested subgroups) or can prevent the tree from forming subgroups when there are few participants in a group.

The alternative strategy, pruning a full “bushy” tree, takes advantage of resamples of the training data. After the full tree is built on an analysis subset of the training data, increasingly large portions of the tree (those that contribute least to improving classification) are cut off, and the quality of each pruned tree is checked using the assessment set. This post-processing of a bushy tree avoids missing important signals that can go unnoticed if a tree stops growing after a small number of splits (for example, if a tree is limited to only three splits). While it costs extra time, it is also possible, as we did in the paper, to tune both options.
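Pruning is often formalized as cost-complexity pruning, the approach associated with CART. As a sketch, assuming R(T) is the misclassification error of a tree T, |T̃| is its number of leaves, and α ≥ 0 is a complexity parameter, each candidate subtree is scored by the criterion below, and the subtree with the lowest score is kept.

```latex
R_\alpha(T) = R(T) + \alpha \, |\tilde{T}|
```

With α = 0 the full bushy tree scores best; as α grows, each extra leaf must buy a larger drop in error to be worth keeping, so the preferred subtree gets smaller. Choosing α with resampling is the pruning option described above. In tidymodels, decision_tree() exposes this parameter as cost_complexity, while tree_depth and min_n provide the early-stopping controls.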

In addition to model-agnostic methods, like permutation-based VIP metrics, for determining which variables are important for a prediction, there are model-specific (here, CART-specific) methods for assessing variable importance. The tree-building process tells us which variables are the most important for classifying participants. Splits close to the root of the tree are the most important for classifying treatment success or failure, but variables that are used to split over and over are also important. The ability to use “height in the tree”, or the number of times a variable is “used”, leads people to view tree-based methods as having intrinsic properties that inform us of a variable’s importance. Another extremely useful feature of this process is that trees do not assume the “linear trends” traditionally captured by regression beta parameters. That is, there are no statements about what happens to the outcome if a predictor increases by one point. Rather, there are nonparametric statements about being in one group or another. This means that CART can pick up complex, nonlinear patterns, but it will “miss the point” if there actually is a linear trend, approximating it only crudely as a staircase of steps.

While CART remains popular, especially among statisticians, there are other algorithms for building classification trees. These include the C5.0 algorithm, which is similar to CART but uses entropy-based information gain to measure the quality of group predictions, and the GUIDE algorithm. These algorithms help to avoid a common issue where CART tends to overvalue a categorical predictor that has many levels. CART’s popularity is driven in part by the appeal of a single, easy-to-understand set of rules. However, CART and other single-tree methods generally perform poorly compared to predictions generated using sets of trees, like random forests, boosted trees, and BART.
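For comparison with the Gini criterion, entropy-based information gain can be computed in the same way. This is a minimal sketch assuming a 0/1 outcome and a logical yes/no rule; C5.0’s actual criterion adds refinements (such as normalizing the gain) that are not shown here, and entropy() and information_gain() are hypothetical helpers.

```r
# Entropy (in bits) of a group of 0/1 outcomes; a pure group has entropy 0
entropy <- function(y) {
  p <- mean(y)
  if (p == 0 || p == 1) return(0)
  -(p * log2(p) + (1 - p) * log2(1 - p))
}

# Drop in entropy achieved by splitting on a yes/no rule
information_gain <- function(rule, y) {
  n <- length(y)
  after <- sum(rule) / n * entropy(y[rule]) + sum(!rule) / n * entropy(y[!rule])
  entropy(y) - after
}

y <- c(0, 0, 0, 0, 1, 1, 1, 1)
perfect_rule <- c(TRUE, TRUE, TRUE, TRUE, FALSE, FALSE, FALSE, FALSE)
useless_rule <- c(TRUE, FALSE, TRUE, FALSE, TRUE, FALSE, TRUE, FALSE)

information_gain(perfect_rule, y)  # 1: both groups become pure
information_gain(useless_rule, y)  # 0: both groups keep the 50/50 mix
```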

Random Forest

Whereas single-tree methods like CART use all the predictors and participants to make a prediction, random forests build many trees, each grown on a bootstrap sample of the participants (a process called bagging, i.e., bootstrap aggregating), and each split within a tree considers only a random sample of the predictors. Because each tree sees only a portion of the data, random forests reduce the influence of highly influential participants and correlated predictors, which will not always be available when a tree is built. Whereas single-tree methods typically tune the complexity of trees (e.g., the amount of pruning), random forests, like those in the paper, typically tune the number of variables that are available at each split in the tree as well as the minimum number of participants a group must contain before it can be split further.
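The two sources of randomness in a random forest can be sketched in a few lines of base R. This is a minimal illustration, not how any random forest package is implemented internally: it draws one bootstrap sample of the 1,858 training participants, estimates by simulation the share of participants who land in a bootstrap sample, and draws mtry = 2 candidate predictors for one split. The predictor names are hypothetical. In tidymodels, rand_forest() exposes the two tuned quantities described above as mtry and min_n.

```r
set.seed(305)
n_train <- 1858   # size of the training set used in this supplement

# Bagging: each tree is grown on a bootstrap sample (rows drawn with replacement)
boot_rows <- sample(n_train, size = n_train, replace = TRUE)
oob_rows  <- setdiff(seq_len(n_train), boot_rows)   # "out-of-bag" participants

# On average about 63% (1 - 1/e) of participants appear in a given bootstrap sample
unique_share <- replicate(200, {
  length(unique(sample(n_train, size = n_train, replace = TRUE))) / n_train
})
mean(unique_share)

# Random predictor sampling: each split considers only `mtry` candidates
predictors <- c("age", "sex", "race", "site", "days_detox", "iv_use")  # hypothetical
mtry <- 2
split_candidates <- sample(predictors, size = mtry)
split_candidates
```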

Because a random forest can include hundreds of trees, we lose the benefit of visually inspecting “the tree” to see which variables are important. Instead, tree ensembles rely on model-agnostic methods, like the permutation-based variable importance scores mentioned above, to describe which variables matter to the prediction.

Boosted Trees (XGBoost)

Whereas random forests gain strength over single trees by aggregating many trees built from randomly selected participants and variables, boosting takes an alternative approach: build a tree, check where it does a poor job, and then build additional trees that focus on improving the areas where the model performs poorly. Like random forests, boosted trees can use permutation-based variable importance scores to identify which variables are critical for predictions.

XGBoost is one such modern algorithm that learns from its mistakes as it builds trees. It and its cousins LightGBM and CatBoost are wildly popular because of both their speed and their accuracy in machine learning competitions.25 Effectively, it builds a tree, looks to see where the model is doing a poor job, and then adds a new tree aimed at the cases that are poorly predicted. It is controlled by many of the same hyperparameters discussed above, but it adds a learning rate parameter that controls how much influence each successive tree has on the model as it tries to correct errors.
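The learn-from-mistakes loop can be sketched in base R using one-split trees (“stumps”) and a numeric outcome. This is a simplified least-squares gradient boosting sketch, not XGBoost itself, which adds regularized split criteria, second-order gradients, deeper trees, and many engineering details. The data are simulated, and fit_stump(), predict_stump(), and boost() are hypothetical helpers; learn_rate plays the same role as the learning rate described above (learn_rate in tidymodels’ boost_tree()).

```r
# Fit a one-split tree ("stump") to the current residuals by least squares
fit_stump <- function(x, r) {
  cuts <- head(sort(unique(x)), -1)
  sse <- sapply(cuts, function(cutoff) {
    left <- x <= cutoff
    sum((r[left] - mean(r[left]))^2) + sum((r[!left] - mean(r[!left]))^2)
  })
  cutoff <- cuts[which.min(sse)]
  list(cut = cutoff, left = mean(r[x <= cutoff]), right = mean(r[x > cutoff]))
}

predict_stump <- function(stump, x) ifelse(x <= stump$cut, stump$left, stump$right)

# Each new stump targets what the current model still gets wrong (the residuals);
# learn_rate shrinks how much each stump may change the predictions
boost <- function(x, y, n_trees = 100, learn_rate = 0.1) {
  pred <- rep(mean(y), length(y))
  for (i in seq_len(n_trees)) {
    residual <- y - pred
    stump <- fit_stump(x, residual)
    pred <- pred + learn_rate * predict_stump(stump, x)
  }
  pred
}

set.seed(305)
x <- runif(200, min = 0, max = 10)
y <- sin(x) + rnorm(200, sd = 0.3)   # simulated nonlinear outcome
mse <- function(pred) mean((y - pred)^2)

mse_start <- mse(rep(mean(y), length(y)))
mse_boost <- mse(boost(x, y, n_trees = 100, learn_rate = 0.1))
c(start = mse_start, boosted = mse_boost)
```

Because each stump is a least-squares fit to the current residuals, adding a shrunken copy of it cannot increase the training error; a smaller learn_rate makes each step more cautious, so more trees are usually needed.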

BART

Bayesian Additive Regression Tree (BART) algorithms are at the leading edge of ML. They draw on an entirely different body of statistical theory (Bayesian inference) and use prior distributions that keep each tree small and shrink its contribution to the overall prediction, an idea similar in spirit to LASSO’s shrinkage. Unlike other tree-based methods that require careful tuning of how far to prune back overly bushy trees, BART algorithms “automatically” handle these details. BART models tend to perform well compared to other algorithms, and they are strong candidates when the goal is to estimate causal relationships.26

Neural Networks

Neural Networks (NNs) are a class of algorithms that include popular complex variants like Deep Learning or Deep Neural Networks. They are extremely popular for solving problems where the data do not come as a traditional statistical dataset (e.g., image or audio processing or complex text). They work using interconnected layers of artificial neurons, with each layer building upon the features learned by previous layers, allowing the model to capture increasingly complex patterns. Their ability to automatically extract relevant features from raw data makes neural networks particularly powerful for complex problems. However, they often require large amounts of computational resources, and their complex structure can make them difficult to interpret compared to simpler models.27
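The layer-by-layer computation described above can be sketched as a single forward pass through a tiny network. This is a minimal base-R illustration, assuming one hidden layer of two ReLU neurons and a sigmoid output; the weights and predictor values are hypothetical numbers chosen for illustration, whereas a real network learns its weights from data.

```r
relu    <- function(z) pmax(z, 0)
sigmoid <- function(z) 1 / (1 + exp(-z))

x  <- c(age = 0.5, days_detox = -1.2, iv_use = 1)   # hypothetical normalized predictors
W1 <- matrix(c(0.2, -0.4,  0.1,
               0.7,  0.3, -0.5), nrow = 2, byrow = TRUE)  # 2 hidden neurons x 3 inputs
b1 <- c(0.1, -0.2)
w2 <- c(1.5, -0.8)
b2 <- 0.05

hidden    <- relu(W1 %*% x + b1)             # layer 1: new features built from the inputs
p_relapse <- sigmoid(sum(w2 * hidden) + b2)  # layer 2: combine features into a probability
p_relapse                                    # about 0.77
```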

Application of Methods with Code

The Role of Training and Testing Datasets

Here we used a 75% training (N = 1,858) and 25% testing (N = 620) set split of the 2,478 total participants. This split was stratified to make sure that we had approximately the same proportion of participants who relapsed in both datasets.

set.seed(1)

a_split <- initial_split(analysis, strata = "did_relapse")

a_train <- training(a_split)
save(a_train, file = "data/a_train.RData")

a_test <- testing(a_split)
save(a_test, file = "data/a_test.RData")

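The set.seed() function fixes the starting point of R’s random number generator so that every task involving randomization can be reproduced. The initial_split() function assigns each participant to either the training or the testing dataset (by default, 3/4 of participants go to training), and strata = "did_relapse" makes this assignment separately within the relapse and non-relapse groups. The training() and testing() functions then create the two datasets.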
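The stratified split can be sketched in base R by sampling 75% of the rows separately within each outcome group. This illustrates the idea and is not rsample’s implementation (its rounding rules, and its binning of numeric strata, differ); the relapse indicator is simulated with a hypothetical 60% relapse rate, and stratified_split() is a hypothetical helper.

```r
set.seed(305)
did_relapse <- factor(rbinom(2478, size = 1, prob = 0.6))  # simulated, not study data

stratified_split <- function(outcome, prop = 0.75) {
  rows_by_level <- split(seq_along(outcome), outcome)
  train_rows <- unlist(lapply(rows_by_level, function(rows) {
    rows[sample.int(length(rows), size = floor(prop * length(rows)))]
  }))
  sort(unname(train_rows))
}

train_rows <- stratified_split(did_relapse)
test_rows  <- setdiff(seq_along(did_relapse), train_rows)

# Relapse proportions are nearly identical in the two sets
mean(did_relapse[train_rows] == "1")
mean(did_relapse[test_rows] == "1")
```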

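After splitting the data into the training and testing sets, the training data are divided into five folds for cross-validation. In each of the five resamples, four folds form the analysis set used to fit the models described above, and the remaining fold forms the assessment set used to evaluate them.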

a_fold <- vfold_cv(a_train, v = 5)
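Five-fold cross-validation can be sketched as giving every training participant a random fold label from 1 to 5. This is a minimal base-R illustration, not vfold_cv()’s internal code; each of the five resamples then uses one fold as the assessment set and the other four as the analysis set.

```r
set.seed(305)
n_train <- 1858
fold_id <- sample(rep(1:5, length.out = n_train))  # random fold label per participant
table(fold_id)   # three folds of 372 participants and two of 371

# Resample 1: fold 1 is the assessment set, folds 2 to 5 form the analysis set
assessment_rows <- which(fold_id == 1)
analysis_rows   <- which(fold_id != 1)
```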

Preprocessing Recipe

Prior to modeling, all data were preprocessed using the same steps, which were programmed using the R recipes package (version 1.3.1). That code, which is shown in the Workflow section above, is repeated here.

set.seed(305)

a_split <- initial_split(analysis, strata = "did_relapse")

a_train <- training(a_split)
save(a_train, file = "data/a_train.RData")

a_test <- testing(a_split)
save(a_test, file = "data/a_test.RData")

a_fold <- vfold_cv(a_train, v = 5)

a_recipe <-
  recipe(formula = did_relapse ~ ., data = a_train) |>
  # no update_role because DALEXtra::explain_tidymodels uses who as a predictor
  # update_role(who, new_role = "id variable") |>
  step_nzv(all_predictors()) |>
  # Needed for KNN see https://github.com/tidymodels/recipes/issues/926
  step_string2factor(all_nominal_predictors()) |>
  step_impute_knn(all_predictors()) |>
  step_dummy(all_nominal_predictors()) |>
  step_other(all_nominal_predictors()) |>
  step_corr(all_numeric_predictors()) |>
  step_normalize(all_numeric_predictors())

The recipe we specify is estimated separately on each subset of the data. As discussed above, this approach, rather than doing the preprocessing once on the complete data (i.e., the full original dataset), is critical to ensure that the algorithm does not learn from the data it will later be assessed on. That potential problem, called information leakage, could cause the model to give optimistic estimates of future performance. As mentioned in the paper, the steps we take here are as follows. First, any predictor that has nearly no variability is removed with step_nzv(all_predictors()). Then any character string is converted to a factor with step_string2factor(all_nominal_predictors()). Logically, this step is not needed, but it is in place to work around a software issue with KNN imputation (see the GitHub issue linked in the code comment). Any missing predictor values are imputed using the KNN algorithm with step_impute_knn(all_predictors()). Categorical variables are dummy-coded with step_dummy(all_nominal_predictors()). Extremely rare levels of nominal categorical predictors (levels with less than 5% frequency) are pooled into an “other” category with step_other(all_nominal_predictors()). Next, all numeric variables are examined, and those correlated above 0.9 are removed with step_corr(all_numeric_predictors()), using an algorithm that removes as few variables as possible.28 Finally, numeric variables are normalized with step_normalize(all_numeric_predictors()). While not every predictive method requires all of these steps, we used a common recipe to make sure that every model used the same predictors for the same participants.
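The leakage point above comes down to which data the preprocessing statistics are estimated from. The first part of this sketch, on simulated values, estimates the mean and standard deviation used for normalization from the analysis set only and reuses them on the assessment set, mirroring how a recipe is estimated and then applied within each resample. The second part is a simplified pooling rule in the spirit of step_other() with its 5% threshold; it ignores edge cases that recipes handles, and pool_rare() and the site labels are hypothetical.

```r
# --- Estimate preprocessing statistics on the analysis set only ---
set.seed(305)
values <- rnorm(100, mean = 50, sd = 10)   # hypothetical numeric predictor
fold   <- rep(1:5, each = 20)
analysis_values   <- values[fold != 1]
assessment_values <- values[fold == 1]

center <- mean(analysis_values)            # learned from the analysis data only
scale  <- sd(analysis_values)
analysis_z   <- (analysis_values - center) / scale
assessment_z <- (assessment_values - center) / scale   # reused, never re-estimated

# --- Simplified pooling of rare levels into "other" ---
pool_rare <- function(x, threshold = 0.05) {
  freq <- prop.table(table(x))
  rare <- names(freq)[freq < threshold]
  x <- as.character(x)
  x[x %in% rare] <- "other"
  factor(x)
}
site <- c(rep("A", 50), rep("B", 47), rep("C", 3))   # hypothetical site labels
table(pool_rare(site))   # A: 50, B: 47, other: 3
```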

If you are curious, you can study several alternative recipes that we tried by looking in the other_recipes.R file.

#######################################################################
###### original recipe before any offshoot branches were created ######
#######################################################################
#             79 features retained in the model - Recipe A
#######################################################################
a_recipe <-
  recipe(formula = did_relapse ~ ., data = a_train) |>
  update_role(who, new_role = "id variable") |>
  step_corr(all_numeric_predictors()) |>
  step_other(all_nominal_predictors()) |>
  step_zv(all_predictors()) |>          # reorder for running without a trial
  step_impute_bag(all_predictors()) |>
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors()) |>          # reorder for running without a trial
  step_normalize(all_numeric_predictors())

################################################################
###### model with PCA variable reduction and 5 components ######
################################################################
#        7 features retained in the model - Recipe C
################################################################
a_recipe <-
  recipe(formula = did_relapse ~ ., data = a_train) |>
  update_role(who, new_role = "id variable") |>
  step_other(all_nominal_predictors()) |>
  step_dummy(all_nominal_predictors()) |>
  step_nzv(all_predictors()) |>
  step_impute_bag(all_predictors()) |>
  step_corr(all_numeric_predictors()) |>
  step_normalize(all_numeric_predictors()) |>
  step_pca(all_predictors(), num_comp = 5)


#########################################################
###### model with step_upsample to balance classes ######
#########################################################
#     110 features retained in the model - Recipe D
#########################################################
a_recipe <-
  recipe(formula = did_relapse ~ ., data = a_train) |>
  update_role(who, new_role = "id variable") |>
  themis::step_upsample(did_relapse, over_ratio = 1) |> # default and ensures 50/50
  step_nzv(all_predictors()) |>
  step_impute_knn(all_predictors()) |>
  step_dummy(all_nominal_predictors()) |>
  step_other(all_nominal_predictors()) |>
  step_corr(all_numeric_predictors()) |>
  step_normalize(all_numeric_predictors())

Algorithms

KNN

The k-nearest neighbor analyses were conducted using the kknn package (version 1.4.1). The tuning grid drew candidate values of k (the number of neighbors) between 1 and 50, the 10 default kernel weighting functions provided in tidymodels, and a range of distance powers.

# library(kknn)
# library(vip)
# library(ggthemes)
set.seed(305)

knn_spec <- 
  nearest_neighbor(
    neighbors = tune(), 
    weight_func = tune(), 
    dist_power = tune()
  ) |>
  set_engine("kknn") |>
  set_mode("classification")

knn_grid <-
  grid_latin_hypercube(
    neighbors(c(1, 50)),
    weight_func(),
    dist_power(),
    size = 50
  )

knn_workflow <-
  workflow() |>
  add_recipe(a_recipe) |>
  add_model(knn_spec)

# doParallel::registerDoParallel(cores = parallel::detectCores() - 2)
knn_tune_res <- 
  tune_grid(
    knn_workflow,
    resamples = a_fold,
    grid = knn_grid,
    control = control_grid(save_pred = TRUE)
  )
# doParallel::stopImplicitCluster()

collect_metrics(knn_tune_res)

knn_resample_best <- 
  show_best(knn_tune_res, metric = "roc_auc")
save(knn_resample_best, file = "data/knn_resample_best.RData")

knn_details <- 
  select_best(knn_tune_res, metric = "roc_auc")
knn_final <- 
  finalize_workflow(knn_workflow, knn_details)

knn_final_fit <-
  fit(knn_final, data = a_train)
save(knn_final_fit, file = "data/knn_final_fit.RData")

# Create a local data directory to hold modeling results
if (!dir.exists(paste0(here::here(), "/data"))) {
  dir.create(paste0(here::here(), "/data"))
}

knn_autoplot <-
  autoplot(knn_tune_res) +
  ggthemes::theme_few() +
  theme(text = element_text(size = 30))
save(knn_autoplot, file = "data/knn_autoplot.RData")

# Obtain model metrics
knn_metrics <-
  augment(knn_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(knn_metrics, file = "data/knn_metrics.Rdata")

knn_conf_mat <-
  augment(knn_final_fit, new_data = a_train) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(knn_conf_mat, file = "data/knn_conf_mat.Rdata")


knn_metrics_test <-
  augment(knn_final_fit, new_data = a_test) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(knn_metrics_test, file = "data/knn_metrics_test.Rdata")

knn_conf_mat_test <-
  augment(knn_final_fit, new_data = a_test) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(knn_conf_mat_test, file = "data/knn_conf_mat_test.Rdata")

The first sections of code specify the kind of model we will use and then tune its hyperparameters.

The nearest_neighbor() function specifies the model and marks, with tune(), which of its hyperparameters are to be tuned. Its arguments include the number of neighbors to consult (neighbors), the kernel function used to weight neighbors by their distance (weight_func), and the power of the Minkowski distance used to measure closeness (dist_power, where 1 is Manhattan distance and 2 is Euclidean distance).
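The roles of neighbors and dist_power can be sketched in base R. This is a simplified illustration assuming an unweighted majority vote (the “rectangular” weighting); kknn’s distance-based kernel weighting is more involved and is not reproduced here. The helpers minkowski() and knn_predict() are hypothetical, and the toy data are made up.

```r
# Minkowski distance: p = 1 is Manhattan, p = 2 is Euclidean
minkowski <- function(a, b, p) sum(abs(a - b)^p)^(1 / p)

minkowski(c(0, 0), c(3, 4), p = 1)  # 7
minkowski(c(0, 0), c(3, 4), p = 2)  # 5

# Unweighted ("rectangular") k-nearest-neighbor vote
knn_predict <- function(train_x, train_y, new_x, k, p = 2) {
  d <- apply(train_x, 1, minkowski, b = new_x, p = p)   # distance to every training row
  nearest <- order(d)[seq_len(k)]
  names(which.max(table(train_y[nearest])))
}

train_x <- rbind(c(1, 1), c(1, 2), c(2, 1), c(8, 8), c(8, 9), c(9, 8))
train_y <- c("success", "success", "success", "relapse", "relapse", "relapse")
knn_predict(train_x, train_y, new_x = c(1.5, 1.5), k = 3)  # "success"
```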

We use grid_latin_hypercube() because there are so many possible combinations of neighbors, weight functions, and distance powers. Rather than trying every possible combination, it selects 50 candidate combinations using a Latin hypercube design, a form of random sampling that spreads the candidates across the full range of each hyperparameter so that the choices are representative of all the combinations.
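A Latin hypercube design can be sketched in a few lines: each hyperparameter’s range is cut into 50 equal bins, exactly one candidate value is drawn inside each bin, and the bins are shuffled independently for each hyperparameter. This is a simplified illustration, not the dials implementation; the ranges (1 to 50 neighbors, distance power from 1 to 2) are assumptions for the sketch, weight_func is omitted because it is categorical, and latin_hypercube() is a hypothetical helper.

```r
latin_hypercube <- function(n, ranges) {
  sapply(ranges, function(r) {
    u <- (sample(n) - runif(n)) / n   # one point in each bin ((k - 1)/n, k/n)
    r[1] + u * (r[2] - r[1])          # rescale to the hyperparameter's range
  })
}

set.seed(305)
design <- latin_hypercube(50, list(neighbors = c(1, 50), dist_power = c(1, 2)))
design[, "neighbors"] <- round(design[, "neighbors"])   # k must be a whole number
head(design)
```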

The workflow() function creates knn_workflow, which bundles the preprocessing recipe (a_recipe) with the model specification (knn_spec) so that they are always used together.

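We actually tune the hyperparameters using the tune_grid() function, which fits every candidate combination in knn_grid on each of the five cross-validation folds (a_fold) and records its performance on the assessment sets.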

After running the model, we extract and save the results from the modeling objects. As discussed in the paper, and even though it is not normally done, we save results from applying the final model to the original training data as well as to the testing data. In a typical workflow, we would tune the model, compare the cross-validated metrics, and choose the optimal hyperparameter values; only if this algorithm performed best would we apply the final fitted model to the test data. See below for an explanation of the code.

Logistic

A traditional logistic model was built using the glm() function from the R stats package (version 4.5.0).

# library(vip)
# library(ggthemes)
set.seed(305)

logistic_spec <- 
  logistic_reg() |>
  set_engine(engine = "glm")

logistic_workflow <- 
  workflow() |>
  add_recipe(a_recipe) |>
  add_model(logistic_spec)

logistic_final_fit <- 
  logistic_workflow |>
  fit(data = a_train)
save(logistic_final_fit, file = "data/logistic_final_fit.RData")

# logistic_metrics
logistic_trained <- 
  logistic_final_fit |>
  extract_fit_parsnip()

logistic_results <- 
  logistic_final_fit |>
  augment(new_data = a_train)

logistic_metrics <- 
  bind_rows(
    logistic_results |>
      roc_auc(did_relapse, .pred_0),
    logistic_results |>
      accuracy(did_relapse, .pred_class)
  )
save(logistic_metrics, file = "data/logistic_metrics.RData")

# make a local data directory to hold modeling results
if (!dir.exists(paste0(here::here(), "/data"))) {
  dir.create(paste0(here::here(), "/data"))
}

logistic_betas <- 
  logistic_trained |>
  tidy(conf.int = TRUE, conf.level = 0.95, exponentiate = TRUE) |>
  arrange(desc(abs(estimate))) |>
  select(term, estimate, conf.low, conf.high, std.error, p.value) |>
  mutate(across(estimate:std.error, ~ round(.x, digits = 2))) |>
  mutate(p.value = scales::pvalue(p.value))
save(logistic_betas, file = "data/logistic_betas.RData")

logistic_vip <- 
  vip(
    logistic_trained,
    geom = "point"
  ) +
  theme_few() +
  theme(text = element_text(size = 30))
save(logistic_vip, file = "data/logistic_vip.RData")

logistic_importance <- 
  extract_fit_engine(logistic_trained) |>
  vip::vi()
save(logistic_importance, file = "data/logistic_importance.RData")

logistic_conf_mat <- 
  logistic_results |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(logistic_conf_mat, file = "data/logistic_conf_mat.RData")

logistic_metrics_test <- 
  augment(logistic_final_fit, new_data = a_test) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(logistic_metrics_test, file = "data/logistic_metrics_test.RData")

# Needed to complete table(s) in summarize.R script
# Because no resampling is done in the logistic modeling, it is necessary to create a 
# tibble as if resampling was possible.
logistic_resample_best <- 
  tibble(
    penalty = NULL,
    .metric = "roc_auc",
    .estimator = "binary",
    mean = NA_real_,
    n = NA_real_,
    std_err = NA_real_,
    .config = NULL
  )
save(logistic_resample_best, file = "data/logistic_resample_best.RData")

While the algorithms used to do logistic regression in the context of ML modeling are the same as traditional statistical methods, their practical implementation gains from an ML framework. Rather than fitting a single model using all the available data, as is often done by classically trained practitioners, people doing ML will typically set aside a portion of the data, calling it a “test” set, build a model on the remaining “training” data (or average the results of a set of models generated by using cross-validation or bootstrap techniques), and then evaluate the model on the test data by computing a metric like classification accuracy or ROC AUC. In this context, logistic regression becomes just another competitor, evaluated to see whether it does the best job predicting the data. While traditional logistic regression cannot readily detect complex patterns in the data (such as interactions and nonlinear effects) unless the analyst specifies them in advance, it has not fallen out of favor because it offers relatively simple explanations in terms of the change in the odds of an outcome as each predictor changes.
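The odds-ratio interpretation can be checked in base R with a simulated binary predictor. In this minimal sketch, which uses hypothetical simulated data rather than study data, exponentiating the glm() coefficient reproduces (up to numerical convergence) the odds ratio computed by hand from the 2 × 2 table of predictor by outcome.

```r
set.seed(305)
n <- 500
iv_use  <- rbinom(n, size = 1, prob = 0.4)                        # hypothetical predictor
relapse <- rbinom(n, size = 1, prob = plogis(-0.5 + 1 * iv_use))  # simulated outcome

fit <- glm(relapse ~ iv_use, family = binomial)
or_model <- exp(coef(fit)[["iv_use"]])   # odds ratio from the model

# The same odds ratio computed by hand from the 2 x 2 table
tab <- table(iv_use, relapse)
or_table <- (tab["1", "1"] * tab["0", "0"]) / (tab["1", "0"] * tab["0", "1"])
c(model = or_model, table = or_table)
```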

That said, the traditional “manual” modeling process historically taught in epidemiology classes is unpopular in the ML framework. That process uses subject matter knowledge to decide which variables to include and in what order (typically by looking at the impact of each predictor by itself and then adding the statistically significant ones to a multivariable model), followed by manually removing variables, often based on a p-value criterion. It tends to produce model results that do not replicate when applied to new datasets. Instead of relying on p-values, some practitioners use automatic “stepwise” modeling procedures that optimize other criteria, like AIC. Even these automatic stepwise procedures are being supplanted by ML “shrinkage” methods like LASSO.

LASSO Logistic

LASSO models were built using the glmnet package (version 4.1.8) with 30 different levels of regularization penalty, or “shrinkage”, between 0.001 and 1, evenly spaced on the log10 scale.

# suppressMessages(library(glmnet))
set.seed(305)

lasso_spec <- 
  logistic_reg(penalty = tune(), mixture = 1) |>
  set_engine("glmnet")

lasso_workflow <- 
  workflow() |>
  add_recipe(a_recipe) |>
  add_model(lasso_spec)

penalty_grid <- 
  grid_regular(penalty(range = c(-3, 0)), levels = 30)

# doParallel::registerDoParallel(cores = parallel::detectCores() - 2)
lasso_tune_res <- 
  tune_grid(
    lasso_workflow,
    resamples = a_fold,
    grid = penalty_grid
  )
# doParallel::stopImplicitCluster()

lasso_resample_best <- 
  show_best(lasso_tune_res, metric = "roc_auc")
save(lasso_resample_best, file = "data/lasso_resample_best.RData")

best_penalty <- 
  select_best(lasso_tune_res, metric = "roc_auc")

lasso_final <- 
  finalize_workflow(lasso_workflow, best_penalty)

lasso_final_fit <- 
  fit(lasso_final, data = a_train)
save(lasso_final_fit, file = "data/lasso_final_fit.RData")

# make a local data directory to hold modeling results
if (!dir.exists(paste0(here::here(), "/data"))) {
  dir.create(paste0(here::here(), "/data"))
}

lasso_betas <- 
  lasso_final_fit |>
  tidy(exponentiate = TRUE) |>
  filter(estimate != 0) |>
  select(term, estimate) |>
  arrange(desc(estimate))
save(lasso_betas, file = "data/lasso_betas.RData")

lasso_vip <- 
  vip(
    lasso_final_fit,
    geom = "point"
  ) +
  theme_few() +
  theme(text = element_text(size = 30))
save(lasso_vip, file = "data/lasso_vip.RData")

lasso_importance <- 
  extract_fit_engine(lasso_final_fit) |>
  vip::vi()
save(lasso_importance, file = "data/lasso_importance.RData")

lasso_autoplot <- 
  autoplot(lasso_tune_res) +
  ggthemes::theme_few() +
  theme(text = element_text(size = 30))
save(lasso_autoplot, file = "data/lasso_autoplot.RData")


lasso_metrics <- 
  augment(lasso_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(lasso_metrics, file = "data/lasso_metrics.RData")

lasso_conf_mat <- 
  augment(lasso_final_fit, new_data = a_train) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(lasso_conf_mat, file = "data/lasso_conf_mat.RData")


lasso_metrics_test <- 
  augment(lasso_final_fit, new_data = a_test) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(lasso_metrics_test, file = "data/lasso_metrics_test.RData")

lasso_conf_mat_test <- 
  augment(lasso_final_fit, new_data = a_test) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(lasso_conf_mat_test, file = "data/lasso_conf_mat_test.RData")

Tuning LASSO regression models involves one parameter: penalty (the amount of regularization). The same logistic_reg() specification with the glmnet engine can also fit a ridge regression or elastic net model through its mixture argument. Our code sets mixture = 1 to specify a pure LASSO model; LASSO is also commonly referred to as L1 regularization. Ridge regression (L2 regularization) uses mixture = 0, and elastic net regressions use mixture values between 0 and 1.22

The tune_grid() function computes the ROC AUC for each candidate LASSO shrinkage (penalty) value.29 The penalty parameter represents the amount of shrinkage, and its impact is assessed across 30 levels evenly spaced on the log10 scale from -3 to 0 (that is, penalties from 0.001 to 1). The goal of applying these penalties is to shrink the coefficient values (the beta values) toward zero. As a result, some coefficients may be set exactly to zero, meaning the variable is dropped from the model, while others have their magnitude reduced. In other words, the variables that contribute most to predicting treatment outcome remain in the model, though possibly with dampened effects. In this way, LASSO performs feature selection as part of fitting the model.
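The penalty grid and the effect of shrinkage can be sketched in base R. The first line builds 30 values evenly spaced on the log10 scale, which is what grid_regular(penalty(range = c(-3, 0)), levels = 30) should correspond to. The soft_threshold() helper is a hypothetical illustration of how a LASSO penalty shrinks coefficients and sets small ones exactly to zero; it is the exact LASSO solution only in the idealized case of standardized, uncorrelated predictors in a linear model, whereas glmnet fits the logistic model iteratively.

```r
# The 30 candidate penalties, evenly spaced on the log10 scale from 10^-3 to 10^0
penalty_values <- 10^seq(-3, 0, length.out = 30)
range(penalty_values)   # 0.001 and 1

# Soft-thresholding: shrink toward zero by lambda, stopping at exactly zero
soft_threshold <- function(beta, lambda) sign(beta) * pmax(abs(beta) - lambda, 0)

unpenalized <- c(-0.80, -0.05, 0.02, 0.30, 1.20)   # hypothetical estimates
soft_threshold(unpenalized, lambda = 0.1)         # returns -0.7, 0, 0, 0.2, 1.1
```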

The level of regularization chosen for the model is the one that produces the greatest ROC AUC across the cross-validated resamples; the final workflow is then fit to the full training data. To obtain exponentiated beta estimates corresponding to odds ratios (ORs), the lasso_final_fit object holding the model results is sent to tidy(exponentiate = TRUE). The OR is the multiplicative change in the odds of the outcome for a one-unit increase in the predictor variable; because the recipe normalizes the numeric predictors, a one-unit increase corresponds to one standard deviation. We apply more cleaning steps by keeping only variables with non-zero estimates (filter(estimate != 0)) and selecting the term (named variable) and estimate, which is now the OR. Finally, the results are arranged in descending order of OR, which sorts the highest ORs to the top of the table and draws the reader’s attention to the variables with the greatest impact on the relapse outcome.
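The multiplicative meaning of an odds ratio is easiest to see with a worked example. Assuming a hypothetical OR of 1.5 per one-standard-deviation increase in a normalized predictor and a hypothetical baseline relapse probability of 0.40, the sketch converts the probability to odds, multiplies by the OR, and converts back.

```r
prob_to_odds <- function(p) p / (1 - p)
odds_to_prob <- function(odds) odds / (1 + odds)

baseline_p <- 0.40   # hypothetical baseline probability of relapse
or_per_sd  <- 1.5    # hypothetical OR for a one-SD increase in a normalized predictor

baseline_odds <- prob_to_odds(baseline_p)     # 0.40 / 0.60 = 0.667
new_odds      <- baseline_odds * or_per_sd    # 0.667 * 1.5 = 1.0
new_p         <- odds_to_prob(new_odds)       # 1.0 / 2.0 = 0.50
new_p
```

An OR of 1.5 does not add 1.5 (or 50 percentage points) to the probability; it multiplies the odds, so the resulting change in probability depends on the starting point.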

Logistic with Resampling

We can trick the tidymodels ecosystem into using cross-validated resamples for an ordinary logistic regression by reusing the infrastructure for LASSO while preventing tidymodels from applying any meaningful shrinkage. Basically, we are saying: use the LASSO code, but don’t try different levels of shrinkage; use only a tiny amount.

# suppressMessages(library(glmnet))
# library(vip)
# library(ggthemes)
set.seed(305)

lasso_spec <- 
  logistic_reg(penalty = tune(), mixture = 1) |>
  set_engine("glmnet")

penalty_grid <-
  grid_regular(penalty(range = c(-10, -10)), levels = 1)

lasso_workflow <- 
  workflow() |>
  add_recipe(a_recipe) |>
  add_model(lasso_spec)

# doParallel::registerDoParallel(cores = parallel::detectCores() - 2)
logistic_via_lasso_tune_res <- 
  tune_grid(
    lasso_workflow,
    resamples = a_fold,
    grid = penalty_grid
  )
# doParallel::stopImplicitCluster()


logistic_via_lasso_resample_best <- 
  show_best(logistic_via_lasso_tune_res, metric = "roc_auc")
save(
  logistic_via_lasso_resample_best, 
  file = "data/logistic_via_lasso_resample_best.RData"
)

logistic_via_lasso_best_penalty <- 
  select_best(logistic_via_lasso_tune_res, metric = "roc_auc")

logistic_via_lasso_final <- 
  finalize_workflow(lasso_workflow, logistic_via_lasso_best_penalty)

logistic_via_lasso_final_fit <- 
  fit(logistic_via_lasso_final, data = a_train)
save(
  logistic_via_lasso_final_fit, 
  file = "data/logistic_via_lasso_final_fit.RData"
)

# make a local data directory to hold modeling results
if (!dir.exists(paste0(here::here(), "/data"))) {
  dir.create(paste0(here::here(), "/data"))
}

logistic_via_lasso_betas <- 
  logistic_via_lasso_final_fit |>
  tidy(exponentiate = TRUE) |>
  filter(estimate != 0) |>
  select(term, estimate) |>
  arrange(desc(estimate))
save(logistic_via_lasso_betas, file = "data/logistic_via_lasso_betas.RData")

logistic_via_lasso_vip <- 
  vip(
    logistic_via_lasso_final_fit,
    geom = "point"
  ) +
  theme_few() +
  theme(text = element_text(size = 30))
save(logistic_via_lasso_vip, file = "data/logistic_via_lasso_vip.RData")

logistic_via_lasso_importance <- 
  extract_fit_engine(logistic_via_lasso_final_fit) |>
  vip::vi()
save(
  logistic_via_lasso_importance, 
  file = "data/logistic_via_lasso_importance.RData"
)

logistic_via_lasso_metrics <- 
  augment(logistic_via_lasso_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(logistic_via_lasso_metrics, file = "data/logistic_via_lasso_metrics.RData")

logistic_via_lasso_conf_mat <- 
  augment(logistic_via_lasso_final_fit, new_data = a_train) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(
  logistic_via_lasso_conf_mat, 
  file = "data/logistic_via_lasso_conf_mat.RData"
)

logistic_via_lasso_metrics_test <- 
  augment(logistic_via_lasso_final_fit, new_data = a_test) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(
  logistic_via_lasso_metrics_test, 
  file = "data/logistic_via_lasso_metrics_test.RData"
)

logistic_via_lasso_conf_mat_test <- 
  augment(
    logistic_via_lasso_final_fit, 
    new_data = a_test
  ) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(
  logistic_via_lasso_conf_mat_test, 
  file = "data/logistic_via_lasso_conf_mat_test.RData"
)

As discussed above, tuning LASSO regression models involves one parameter, penalty (the amount of regularization), with mixture = 1. Our modeling pipeline relies on tune_grid() to produce cross-validated results, so to fit an ordinary logistic regression through that same pipeline, we prevent the penalty from having any meaningful effect rather than tuning it. We achieve this by using LASSO’s glmnet engine and creating a grid_regular() with one level and a one-value range (range = c(-10, -10)). Because the penalty range is specified on the log10 scale, the only candidate model has mixture = 1 and penalty = 10^-10. The net effect is that tidymodels regards our model as a LASSO model, but because the penalty is effectively zero, the resampled fits behave like ordinary logistic regressions.
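To see why a penalty of 10^-10 leaves the model essentially unpenalized, the same idealized soft-thresholding rule used to illustrate LASSO shrinkage can be applied with that penalty. This is a hypothetical illustration (soft_threshold() and the coefficient values are made up), not glmnet’s algorithm; it shows that each coefficient moves by only about 10^-10 and none is set to zero.

```r
penalty <- 10^-10   # the value the one-point grid above should produce

soft_threshold <- function(beta, lambda) sign(beta) * pmax(abs(beta) - lambda, 0)

unpenalized <- c(-0.80, 0.02, 1.20)            # hypothetical unpenalized estimates
shrunk <- soft_threshold(unpenalized, penalty)
max(abs(shrunk - unpenalized))                 # about 1e-10: effectively no shrinkage
```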

The chosen model is the one that produces the greatest cross-validated ROC AUC; the finalized workflow is then fit to the training data. To obtain exponentiated beta estimates, which correspond to odds ratios (ORs), we pipe (|>) the lasso_final_fit object to tidy(exponentiate = TRUE). An OR is the factor by which the odds of the outcome are multiplied for a one-unit increase in the predictor variable. We then apply more cleaning steps: filtering to only the variables with non-zero estimates (filter(estimate != 0)) and selecting the term (the variable name) and estimate, which is now the OR. Finally, the results are arranged in descending order of OR, which sorts the highest ORs to the top of the table and draws the reader’s attention to the variables associated with the largest increases in the odds of the outcome.
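The odds-ratio logic can be sketched in base R with glm() on simulated data. The predictors x_cont and x_bin are hypothetical, and this is not the glmnet/tidy() pipeline we used. The sketch also checks the interpretation given above: the ratio of the odds at x_cont = 1 versus x_cont = 0 equals exp(beta). One caution the sketch makes explicit is that exp(0) = 1, so zero coefficients are easiest to drop on the log-odds scale, before exponentiating.

```r
set.seed(305)
n <- 500
# Hypothetical simulated data: one continuous and one binary predictor
sim <- data.frame(x_cont = rnorm(n), x_bin = rbinom(n, 1, 0.4))
sim$y <- rbinom(n, 1, plogis(-0.3 + 0.7 * sim$x_cont + 1.1 * sim$x_bin))

fit <- glm(y ~ x_cont + x_bin, family = binomial, data = sim)

# Build an OR table: drop the intercept and any zero coefficients on the
# log-odds scale first, because exp(0) = 1 (a dropped term would show as OR = 1)
beta <- coef(fit)
or_table <- data.frame(term = names(beta), log_odds = unname(beta))
or_table <- or_table[or_table$term != "(Intercept)" & or_table$log_odds != 0, ]
or_table$estimate <- exp(or_table$log_odds)               # odds ratios
or_table <- or_table[order(or_table$estimate, decreasing = TRUE), c("term", "estimate")]
or_table

# An OR is the multiplicative change in the odds for a one-unit increase
odds <- function(p) p / (1 - p)
p0 <- predict(fit, newdata = data.frame(x_cont = 0, x_bin = 0), type = "response")
p1 <- predict(fit, newdata = data.frame(x_cont = 1, x_bin = 0), type = "response")
odds(p1) / odds(p0)   # matches exp(coef(fit)["x_cont"])
```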

SVM

# if (!require('kernlab')) install.packages('kernlab')
# NOTE: see more details here: https://parsnip.tidymodels.org/reference/svm_poly.html
set.seed(305)

# make a local data directory to hold modeling results
if (!dir.exists(paste0(here::here(), "/data"))) {
  dir.create(paste0(here::here(), "/data"))
}

svm_spec <-
  svm_poly(
    mode = "classification",
    cost = tune(),
    degree = tune()
  ) |>
  set_engine("kernlab")

# regular grid over cost and polynomial degree
svm_grid <- 
  grid_regular(
    cost(),
    degree(range = c(1, 3)),
    levels = 10
  )

svm_workflow <-
  workflow() |>
  add_recipe(a_recipe) |>
  add_model(svm_spec)

# doParallel::registerDoParallel(cores = parallel::detectCores() - 2)
svm_tune_res <-
  tune_grid(
    svm_workflow,
    resamples = a_fold,
    grid = svm_grid
  )
# doParallel::stopImplicitCluster()

svm_resample_best <-
  show_best(svm_tune_res, metric = "roc_auc")
save(svm_resample_best, file = "data/svm_resample_best.RData")

# best cost/degree combination (an SVM has no penalty parameter)
svm_best <-
  select_best(svm_tune_res, metric = "roc_auc")
svm_final <-
  finalize_workflow(svm_workflow, svm_best)

svm_final_fit <-
  fit(svm_final, data = a_train)
save(svm_final_fit, file = "data/svm_final_fit.RData")

svm_autoplot <-
  autoplot(svm_tune_res) +
  ggthemes::theme_few() +
  theme(text = element_text(size = 30))
save(svm_autoplot, file = "data/svm_autoplot.RData")

svm_metrics <-
  augment(svm_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(svm_metrics, file = "data/svm_metrics.RData")

svm_conf_mat <-
  augment(svm_final_fit, new_data = a_train) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(svm_conf_mat, file = "data/svm_conf_mat.RData")


svm_metrics_test <-
  augment(svm_final_fit, new_data = a_test) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(svm_metrics_test, file = "data/svm_metrics_test.RData")

svm_conf_mat_test <-
  augment(svm_final_fit, new_data = a_test) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(svm_conf_mat_test, file = "data/svm_conf_mat_test.RData")


# variable importance ----
svm_importance <-
  svm_final_fit |>
  vip::vi(
    method = "permute",
    nsim = 5,
    target = "did_relapse",
    metric = "roc_auc",
    event_level = "first",
    pred_wrapper = 
      function(object, newdata) predict(object, 
                                        newdata, 
                                        type = "prob")[[".pred_0"]],
    # can use the training data since the function in the pred_wrapper arg is 
    # using the normal predict function, and expects non-transformed or 
    # baked data
    train = a_train )
svm_importance
save(svm_importance, file = "data/svm_importance.RData")

svm_vip <-
  svm_importance |>
  arrange(desc(Importance)) |>
  head(10) |>
  ggplot(aes(Importance, reorder(Variable, Importance))) +
  geom_point() +
  theme_few()
svm_vip
save(svm_vip, file = "data/svm_vip.RData")

CART

CART models were built using the rpart package (version 4.1.23), with trees tuned on tree depth (1 to 15), the minimum number of data points in a node required for a further split (2 to 40), and cost complexity (between -10 and -1 on the log10 scale).

cart_spec <-
  decision_tree(
    tree_depth = tune(), min_n = tune(), cost_complexity = tune()
  ) |>
  set_engine("rpart") |>
  set_mode("classification")

The decision_tree() function specifies a tree-based model that can be used for both classification and regression. Three parameters (tree_depth, min_n, and cost_complexity) were tuned to obtain a decision tree model that better captures the underlying patterns of the data. Passing “rpart” to set_engine() selects the rpart package, which builds classification and regression trees by recursive partitioning.
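As far as we understand parsnip’s translation for the rpart engine, tree_depth, min_n, and cost_complexity are passed to rpart as maxdepth, minsplit, and cp. The base-R sketch below fits a single candidate from the tuning ranges directly with rpart (a recommended package that ships with R) on simulated data. It then checks that no node is deeper than maxdepth.

```r
library(rpart)  # recommended package distributed with R

set.seed(305)
n <- 400
# Hypothetical simulated data with a binary outcome
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$y <- factor(ifelse(sim$x1 + 0.5 * sim$x2 + rnorm(n, sd = 0.5) > 0, "1", "0"))

# One candidate from the tuning ranges described above
ctrl <- rpart.control(
  maxdepth = 3,   # tree_depth
  minsplit = 20,  # min_n
  cp = 1e-4       # cost_complexity (-4 on the log10 scale)
)
tree <- rpart(y ~ x1 + x2, data = sim, method = "class", control = ctrl)

# rpart numbers nodes so that node k has children 2k and 2k + 1,
# so a node's depth (root = 0) is floor(log2(k))
node_ids <- as.integer(rownames(tree$frame))
max(floor(log2(node_ids)))
```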

cart_workflow <- workflow() |>
  add_recipe(a_recipe) |>
  add_model(cart_spec)

The steps added to the a_recipe object include: a near-zero variance filter, KNN imputation on all predictors, dummy coding on all nominal predictors, a step that creates an “other” category to house infrequently occurring values, a high correlation filter, and a step to normalize numeric data to have a standard deviation of one and a mean of zero.
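Two of these steps are easy to picture without recipes. This hypothetical base-R sketch pools categories whose frequency falls below 5% into an “other” level. The 5% cut-off is the default threshold of step_other(), used here as an assumption. The sketch also normalizes a numeric column to mean zero and standard deviation one. Unlike recipes, it does not learn these settings on the training data and reapply them to new data; it only illustrates the arithmetic.

```r
# Hypothetical toy data with two rare categories and a numeric column
df <- data.frame(
  site = c(rep("A", 45), rep("B", 50), rep("C", 3), rep("D", 2)),
  dose = seq(10, 109),
  stringsAsFactors = FALSE
)

# Pool levels whose relative frequency is below the threshold into "other"
pool_rare <- function(x, threshold = 0.05) {
  freq <- table(x) / length(x)
  rare <- names(freq)[freq < threshold]
  x[x %in% rare] <- "other"
  x
}
df$site <- pool_rare(df$site)

# Normalize: subtract the mean, then divide by the standard deviation
df$dose <- (df$dose - mean(df$dose)) / sd(df$dose)

table(df$site)
c(mean = mean(df$dose), sd = sd(df$dose))
```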

Random Forest

# suppressPackageStartupMessages(library(randomForest))
# library(vip)
# library(ggthemes)
set.seed(305)

# make a local data directory to hold modeling results
if (!dir.exists(paste0(here::here(), "/data"))) {
  dir.create(paste0(here::here(), "/data"))
}

rf_spec <-
  rand_forest(mtry = tune(), min_n = tune()) |>
  set_engine("randomForest") |>
  set_mode("classification")

rf_grid <- 
  grid_latin_hypercube(
    min_n(),
    finalize(mtry(), a_train),
    size = 50
  )

rf_workflow <-
  workflow() |>
  add_recipe(a_recipe) |>
  add_model(rf_spec)

# doParallel::registerDoParallel(cores = parallel::detectCores() - 2)
rf_tune_res <-
  tune_grid(
    rf_workflow,
    resamples = a_fold,
    grid = rf_grid,
    control = control_grid(save_pred = TRUE)
  )
# doParallel::stopImplicitCluster()

collect_metrics(rf_tune_res)

rf_resample_best <-
  show_best(rf_tune_res, metric = "roc_auc")
save(rf_resample_best, file = "data/rf_resample_best.RData")

rf_details <-
  select_best(rf_tune_res, metric = "roc_auc")
final_rf <-
  finalize_workflow(rf_workflow, rf_details)

rf_final_fit <-
  fit(final_rf, data = a_train)
save(rf_final_fit, file = "data/rf_final_fit.RData")

rf_autoplot <-
  autoplot(rf_tune_res) +
  ggthemes::theme_few() +
  theme(text = element_text(size = 30))
save(rf_autoplot, file = "data/rf_autoplot.RData")

rf_final <-
  last_fit(final_rf, a_split)
rf_final_metrics <-
  rf_final |> collect_metrics()

rf_metrics <-
  augment(rf_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(rf_metrics, file = "data/rf_metrics.RData")

rf_conf_mat <-
  augment(rf_final_fit, new_data = a_train) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(rf_conf_mat, file = "data/rf_conf_mat.RData")

rf_metrics_test <-
  augment(rf_final_fit, new_data = a_test) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(rf_metrics_test, file = "data/rf_metrics_test.RData")

rf_conf_mat_test <-
  augment(rf_final_fit, new_data = a_test) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(rf_conf_mat_test, file = "data/rf_conf_mat_test.RData")

last_results <-
  last_fit(final_rf, a_split)

rf_last_metrics <-
  collect_metrics(last_results) |>
  select(-c(.config, .estimator)) |>
  rename(Method = .metric, Estimate = .estimate)
save(rf_last_metrics, file = "data/rf_last_metrics.RData")

# variable importance -----
rf_vip <-
  rf_final_fit |>
  extract_fit_parsnip() |>
  vip(geom = "point") +
  theme_few() +
  theme(text = element_text(size = 30))
save(rf_vip, file = "data/rf_vip.Rdata")

rf_importance <-
  extract_fit_engine(rf_final_fit) |>
  vip::vi()
save(rf_importance, file = "data/rf_importance.RData")

BART

# library(dbarts)
set.seed(305)

# Create a local data directory to hold modeling results
if (!dir.exists(paste0(here::here(), "/data"))) {
  dir.create(paste0(here::here(), "/data"))
}

bart_spec <- 
  bart(
    trees = tune(), prior_terminal_node_coef = tune(),
    prior_terminal_node_expo = tune(), prior_outcome_range = tune()
  ) |>
  set_engine("dbarts") |>
  set_mode("classification")

bart_grid <- 
  grid_latin_hypercube(
    trees(), prior_terminal_node_coef(),
    prior_terminal_node_expo(), prior_outcome_range(),
    size = 50
  )

bart_workflow <-
  workflow() |>
  add_recipe(a_recipe) |>
  add_model(bart_spec)

# doParallel::registerDoParallel(cores = parallel::detectCores() - 2)
bart_tune_res <-
  tune_grid(
    bart_workflow,
    resamples = a_fold,
    grid = bart_grid,
    control = control_grid(save_pred = TRUE)
  )
# doParallel::stopImplicitCluster()

bart_resample_best <-
  show_best(bart_tune_res, metric = "roc_auc")
save(bart_resample_best, file = "data/bart_resample_best.RData")

bart_details <-
  select_best(bart_tune_res, metric = "roc_auc")
final_bart <-
  finalize_workflow(bart_workflow, bart_details)

bart_final_fit <-
  fit(final_bart, data = a_train)
save(bart_final_fit, file = "data/bart_final_fit.RData")

bart_autoplot <-
  autoplot(bart_tune_res) +
  ggthemes::theme_few() +
  theme(text = element_text(size = 30))
save(bart_autoplot, file = "data/bart_autoplot.RData")

bart_final <-
  last_fit(final_bart, a_split)
bart_final_metrics <- bart_final |> collect_metrics()

# Obtain model metrics
bart_metrics <- 
  augment(bart_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(bart_metrics, file = "data/bart_metrics.RData")

bart_conf_mat <-
  augment(bart_final_fit, new_data = a_train) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(bart_conf_mat, file = "data/bart_conf_mat.RData")

bart_metrics_test <-
  augment(bart_final_fit, new_data = a_test) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(bart_metrics_test, file = "data/bart_metrics_test.RData")

bart_conf_mat_test <-
  augment(bart_final_fit, new_data = a_test) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(bart_conf_mat_test, file = "data/bart_conf_mat_test.RData")

# Final fitting and predictions
last_results <- 
  last_fit(final_bart, a_split)
bart_last_metrics <-
  collect_metrics(last_results) |>
  select(-c(.config, .estimator)) |>
  rename(Method = .metric, Estimate = .estimate)
save(bart_last_metrics, file = "data/bart_last_metrics.RData")

Extreme Gradient Boost (XGBoost)

# library(xgboost)
# library(vip)
# library(ggthemes)
set.seed(305)

# make a local data directory to hold modeling results
if (!dir.exists(paste0(here::here(), "/data"))) {
  dir.create(paste0(here::here(), "/data"))
}

boost_tree_xgboost_spec <-
  boost_tree(
    trees = 1000,
    tree_depth = tune(), min_n = tune(),
    loss_reduction = tune(),
    sample_size = tune(), mtry = tune(),
    learn_rate = tune()
  ) |>
  set_engine("xgboost") |>
  set_mode("classification")

xgb_grid <- 
  grid_latin_hypercube(
    tree_depth(),
    min_n(),
    loss_reduction(),
    sample_size = sample_prop(),
    finalize(mtry(), a_train),
    learn_rate(),
    size = 50
  )

xgb_workflow <-
  workflow() |>
  add_recipe(a_recipe) |>
  add_model(boost_tree_xgboost_spec)

# doParallel::registerDoParallel(cores = parallel::detectCores() - 2)
xgb_tune_res <-
  tune_grid(
    xgb_workflow,
    resamples = a_fold,
    grid = xgb_grid,
    control = control_grid(save_pred = TRUE)
  )
# doParallel::stopImplicitCluster()

collect_metrics(xgb_tune_res)

xgb_resample_best <-
  show_best(xgb_tune_res, metric = "roc_auc")
save(xgb_resample_best, file = "data/xgb_resample_best.RData")

# best hyperparameter combination by ROC AUC
xgb_best <-
  select_best(xgb_tune_res, metric = "roc_auc")
xgb_final <-
  finalize_workflow(xgb_workflow, xgb_best)
xgb_final_fit <-
  fit(xgb_final, data = a_train)
save(xgb_final_fit, file = "data/xgb_final_fit.RData")

xgb_metrics <-
  augment(xgb_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(xgb_metrics, file = "data/xgb_metrics.RData")

xgb_conf_mat <-
  augment(xgb_final_fit, new_data = a_train) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(xgb_conf_mat, file = "data/xgb_conf_mat.RData")

xgb_metrics_test <-
  augment(xgb_final_fit, new_data = a_test) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(xgb_metrics_test, file = "data/xgb_metrics_test.RData")

xgb_conf_mat_test <-
  augment(xgb_final_fit, new_data = a_test) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(xgb_conf_mat_test, file = "data/xgb_conf_mat_test.RData")

# variable importance -----
xgb_vip <-
  xgb_final_fit |>
  extract_fit_parsnip() |>
  vip(geom = "point") +
  theme_few() +
  theme(text = element_text(size = 30))
save(xgb_vip, file = "data/xgb_vip.RData")

xgb_importance <-
  extract_fit_engine(xgb_final_fit) |>
  vip::vi()
save(xgb_importance, file = "data/xgb_importance.RData")

Neural Networks

# suppressWarnings(library(AppliedPredictiveModeling))
# suppressWarnings(library(keras))
# suppressWarnings(library(torch))
# suppressWarnings(library(brulee))
# from here: https://www.tidymodels.org/learn/models/parsnip-nnet/
set.seed(305)

# make a local data directory to hold modeling results
if (!dir.exists(paste0(here::here(), "/data"))) {
  dir.create(paste0(here::here(), "/data"))
}

nnet_spec <-
  mlp(
    epochs = 100,
    hidden_units = tune(),
    penalty = tune(),
    learn_rate = tune()
  ) |>
  set_engine("brulee") |>
  set_mode("classification")

nnet_grid <- 
  grid_latin_hypercube(
    hidden_units(range = c(10, 100)),
    penalty(range = c(-5, 0)),     # log10 scale for penalty
    learn_rate(range = c(-5, -1)), # log10 scale for learning rate
    size = 30                      # number of hyperparameter combinations to try
  )

nnet_workflow <- 
  workflow() |>
  add_recipe(a_recipe) |>
  add_model(nnet_spec)

nnet_tune_res <- 
  tune_grid(
    nnet_workflow,
    resamples = a_fold,
    grid = nnet_grid,
    metrics = metric_set(roc_auc, accuracy, kap)
  )

# Show the top models
nnet_resample_best <- 
  show_best(nnet_tune_res, metric = "roc_auc")
save(nnet_resample_best, file = "data/nnet_resample_best.RData")

# Select the best model
best_ann_tune <- 
  select_best(nnet_tune_res, metric = "roc_auc")
nnet_final <- 
  finalize_workflow(nnet_workflow, best_ann_tune)

nnet_final_fit <- 
  fit(nnet_final, data = a_train)
save(nnet_final_fit, file = "data/nnet_final_fit.RData")

nnet_autoplot <-
  autoplot(nnet_tune_res) +
  ggthemes::theme_few() +
  theme(text = element_text(size = 30))
save(nnet_autoplot, file = "data/nnet_autoplot.RData")

nnet_metrics <-
  augment(nnet_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(nnet_metrics, file = "data/nnet_metrics.RData")

nnet_conf_mat <-
  augment(nnet_final_fit, new_data = a_train) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(nnet_conf_mat, file = "data/nnet_conf_mat.RData")

augmented <- augment(nnet_final_fit, new_data = a_test)

nnet_metrics_test <-
  augmented |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(nnet_metrics_test, file = "data/nnet_metrics_test.RData")

nnet_conf_mat_test <-
  augmented |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(nnet_conf_mat_test, file = "data/nnet_conf_mat_test.RData")

nnet_importance <- 
  nnet_final_fit |>
  vip::vi(
    method = "permute",
    nsim = 5,
    target = "did_relapse",
    metric = "roc_auc",
    event_level = "first",
    pred_wrapper = 
      function(object, newdata) predict(object, newdata, type = "prob")[[".pred_0"]],
    # can use the training data since the function in the pred_wrapper arg is 
    # using the normal predict function, and expects non-transformed or baked 
    # data
    train = a_train 
  )
nnet_importance
save(nnet_importance, file = "data/nnet_importance.RData")

nnet_vip <- 
  nnet_importance |>
  arrange(desc(Importance)) |>
  head(10) |>
  ggplot(aes(Importance, reorder(Variable, Importance))) +
  geom_point() +
  theme_few()
nnet_vip
save(nnet_vip, file = "data/nnet_vip.RData")

Tuning neural network models in tidymodels requires the selection of several key parameters: hidden_units (the number of nodes in the hidden layer), penalty (the amount of regularization), and learn_rate (the step size for the optimization algorithm). The penalty parameter controls the model’s complexity by adding a cost to large weights, helping to prevent overfitting. A higher penalty leads to simpler models, while a lower penalty allows for more complex patterns to be captured. The learn_rate determines how much the model adjusts its weights in response to the estimated error each time the model weights are updated. A higher learning rate means faster learning but risks overshooting the optimal solution, while a lower rate learns more slowly but can lead to more precise convergence.30
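The interplay between penalty and learn_rate can be seen on a toy problem with a single weight. This sketch is not the brulee implementation. It uses gradient descent to minimize (w - 3)^2 + lambda * w^2, where lambda plays the role of an L2-style penalty (an assumption for illustration). Without a penalty the weight converges to 3. With lambda = 1 it is shrunk to 3 / (1 + lambda) = 1.5. A learning rate that is too large makes each step overshoot, so the weight diverges.

```r
# Gradient descent on f(w) = (w - 3)^2 + lambda * w^2 (a toy objective)
gradient_descent <- function(lambda, learn_rate, steps = 200, w = 0) {
  for (i in seq_len(steps)) {
    grad <- 2 * (w - 3) + 2 * lambda * w   # derivative of f at the current w
    w <- w - learn_rate * grad             # step against the gradient
  }
  w
}

gradient_descent(lambda = 0, learn_rate = 0.1)              # about 3: no penalty
gradient_descent(lambda = 1, learn_rate = 0.1)              # about 1.5: penalty shrinks w
gradient_descent(lambda = 0, learn_rate = 1.1, steps = 50)  # overshoots and diverges
```

For this objective, each step multiplies the distance to the optimum by 1 - 2 * learn_rate * (1 + lambda). That is why learning rates above 1 / (1 + lambda) diverge.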

The nodes in the hidden layer of the neural network lie between the input and output layers. Each node computes a weighted combination of its inputs and passes it through a nonlinear activation function, which lets the network capture nonlinear relationships and interactions in the data. We use the mlp() function to specify a multilayer perceptron model, a type of feedforward artificial neural network. In this method, the outputs of the hidden units are passed forward to the next layer.
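A forward pass through a small multilayer perceptron takes only a few lines of base R. In this sketch the layer sizes, the random weights, and the ReLU activation are assumptions for illustration. Training, which brulee handles for us, would adjust the weights. Each hidden node forms a weighted sum of the inputs and applies the nonlinearity. The output node then turns the hidden values into a probability with the logistic (sigmoid) function.

```r
sigmoid <- function(z) 1 / (1 + exp(-z))
relu <- function(z) pmax(z, 0)

set.seed(305)
n_inputs <- 4   # predictors after preprocessing (hypothetical)
n_hidden <- 3   # plays the role of hidden_units

# Randomly initialized weights and biases (training would adjust these)
W1 <- matrix(rnorm(n_hidden * n_inputs, sd = 0.5), nrow = n_hidden)
b1 <- rnorm(n_hidden, sd = 0.1)
w2 <- rnorm(n_hidden, sd = 0.5)
b2 <- 0

forward <- function(x) {
  h <- relu(W1 %*% x + b1)    # hidden layer: weighted sums + nonlinearity
  sigmoid(sum(w2 * h) + b2)   # output layer: probability of one class
}

x <- c(0.2, -1.1, 0.5, 1.3)   # one hypothetical, already-normalized participant
forward(x)
```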

The epochs parameter is set to 100, which determines the number of times the learning algorithm will work through the entire training dataset. The hidden_units, penalty, and learn_rate parameters are set to tune(), indicating that we want to find optimal values for these parameters. We use the “brulee” engine to model a neural network and set the mode to “classification” for our binary outcome prediction.

The grid_latin_hypercube() function creates a tuning grid with 30 different hyperparameter combinations, balancing exploration of the parameter space against computational cost. A Latin hypercube is a space-filling design, so the 30 candidates are spread across the full range of each parameter: hidden_units ranges from 10 to 100, penalty from 10^-5 to 10^0 (-5 to 0 on the log10 scale), and learn_rate from 10^-5 to 10^-1 (-5 to -1 on the log10 scale). We use log scales for penalty and learn_rate so that values spanning several orders of magnitude are explored efficiently.
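The idea behind a Latin hypercube can be sketched in base R. Each parameter’s range is cut into 30 equal strata, one value is drawn in each stratum, and the strata are shuffled independently for each parameter, so every parameter is covered evenly. This is a generic sketch, not the dials implementation. The helper latin_hypercube() is hypothetical, and the log10-scale parameters are back-transformed at the end.

```r
# Generic Latin hypercube: one draw per stratum for each parameter,
# with strata shuffled independently across parameters
latin_hypercube <- function(size, ranges) {
  u <- sapply(ranges, function(r) (sample(size) - runif(size)) / size)
  grid <- as.data.frame(u)
  for (p in names(ranges)) {
    r <- ranges[[p]]
    grid[[p]] <- r[1] + grid[[p]] * (r[2] - r[1])   # rescale [0, 1] to the range
  }
  grid
}

set.seed(305)
grid <- latin_hypercube(30, list(
  hidden_units     = c(10, 100),
  log10_penalty    = c(-5, 0),
  log10_learn_rate = c(-5, -1)
))
grid$hidden_units <- round(grid$hidden_units)
grid$penalty <- 10^grid$log10_penalty          # back-transform from log10 scale
grid$learn_rate <- 10^grid$log10_learn_rate
head(grid[, c("hidden_units", "penalty", "learn_rate")])
```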

The workflow combines our recipe (a_recipe) with the neural network model specification. The tune_grid() function then trains models for each combination of hyperparameters in the grid, using cross-validation (a_fold). We evaluate the models using multiple metrics: ROC AUC, accuracy (the proportion of correct predictions), and Cohen’s kappa (which measures agreement between predicted and actual classifications, adjusting for agreement expected by chance). The result of this tuning process is a set of models with different hyperparameter combinations, evaluated across these metrics. The best performing model can then be selected based on its ROC AUC. A confusion matrix and ROC plot are then generated to visualize the performance of the neural network model.
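Accuracy and Cohen’s kappa can both be computed directly from a 2x2 confusion matrix. The hypothetical helper cohen_kappa() below compares the observed agreement (accuracy) with the agreement expected by chance from the row and column totals. The second example shows why the adjustment matters: a model that always predicts the majority class can have high accuracy but a kappa of zero. The counts are made up for illustration.

```r
# Rows = predicted class, columns = true class (the layout conf_mat() prints)
cohen_kappa <- function(cm) {
  n <- sum(cm)
  p_observed <- sum(diag(cm)) / n                    # accuracy
  p_chance <- sum(rowSums(cm) * colSums(cm)) / n^2   # agreement expected by chance
  c(accuracy = p_observed, kappa = (p_observed - p_chance) / (1 - p_chance))
}

labels <- list(Prediction = c("0", "1"), Truth = c("0", "1"))

# Balanced, hypothetical example: 80% accuracy
cm <- matrix(c(40, 10,
               10, 40), nrow = 2, byrow = TRUE, dimnames = labels)
cohen_kappa(cm)

# Always predicting "1": the same accuracy, but no agreement beyond chance
cm_majority <- matrix(c( 0,  0,
                        20, 80), nrow = 2, byrow = TRUE, dimnames = labels)
cohen_kappa(cm_majority)
```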

There are, however, caveats to using a modeling technique such as a neural network. Foremost, neural networks are difficult to interpret: they are often regarded as “black box” models because it is much harder to understand how they arrived at their predictions than it is for logistic regression or other linear modeling methods. Second, neural networks are computationally expensive, so training times are typically longer than for random forest or gradient boosted models, especially with larger datasets. Third, neural networks require tuning several parameters to achieve optimal performance. Comparatively, our LASSO regression was tuned only for the optimal penalty, whereas our neural network demanded a deeper understanding of three complex parameters.

We chose to include this complex model because we wished to evaluate its performance against the more conventional methods used in our area of research, but we were unable to achieve better performance metrics with a neural network. This result reiterates the importance of careful model tuning, and of interpreting a complex model’s performance alongside simpler comparison models.

Results

One of the major benefits of using the tidymodels ecosystem of R packages is that the same code can be used to evaluate the performance of every algorithm. Normally the cross-validation results are used to pick a “winning” algorithm that has the best performance (on a metric like accuracy or ROC AUC), and then the testing data is used once to give an estimate of future model performance on new data. Here, because we wanted to use the data for teaching, we applied the tuned algorithms back to all the training data. As shown in the paper, the trained models pick up on the true signal but also on idiosyncratic details of the training data, which is why the ROC AUC computed on the full training data is higher than the cross-validation results. We also show what would have happened if we applied all the models to the test data. Outside of a teaching environment this is unwise, because it could lead people to select a model based on the testing results, which is likely to produce optimistic estimates of future performance. Below we show the code that was used to extract the modeling details and results from the various algorithms we tried.
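An exaggerated toy example shows why metrics computed on the data used for fitting are optimistic. In this base-R sketch, the outcome is pure noise. Even so, a 1-nearest-neighbor classifier (class::knn(), from a recommended package that ships with R) scores perfectly on its own training data, because each training point’s nearest neighbor is itself. On new data its accuracy is near chance. The data are simulated, not from the study.

```r
library(class)  # recommended package distributed with R

set.seed(305)
# Pure noise: the outcome is unrelated to the predictors
make_data <- function(n) {
  data.frame(x1 = rnorm(n), x2 = rnorm(n), y = factor(rbinom(n, 1, 0.5)))
}
train <- make_data(200)
test <- make_data(200)

# 1-nearest neighbor: on the training data, each point "finds" itself
pred_train <- knn(train[, c("x1", "x2")], train[, c("x1", "x2")], train$y, k = 1)
pred_test  <- knn(train[, c("x1", "x2")], test[, c("x1", "x2")], train$y, k = 1)

c(train_accuracy = mean(pred_train == train$y),
  test_accuracy  = mean(pred_test == test$y))
```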

Be aware that the tidymodels ecosystem has additional functionality called workflowsets that further streamline the coding experience by combining many possible recipes and many modeling methods into a cohesive structure. That framework also allows for rigorous statistical comparisons to highlight which methods outperform others. If there is interest, a future project will show that code base.

Training

After tuning the model using the training data, we extract and save the results from the modeling objects. Even though it is not normally done, we’ve chosen to save the results from applying the model to the original training data as well as the testing data. In a normal workflow, we would fit the model and collect the metrics from the training data. Ultimately, we would apply that tuned model, with the best hyperparameters, to the test data one time. See below for an explanation of the code.

Above we showed the complete code for KNN and it is repeated here:

# library(kknn)
# library(vip)
# library(ggthemes)
set.seed(305)

# Create a local data directory to hold modeling results
if (!dir.exists(paste0(here::here(), "/data"))) {
  dir.create(paste0(here::here(), "/data"))
}

knn_spec <- 
  nearest_neighbor(
    neighbors = tune(), 
    weight_func = tune(), 
    dist_power = tune()
  ) |>
  set_engine("kknn") |>
  set_mode("classification")

knn_grid <-
  grid_latin_hypercube(
    neighbors(c(1, 50)),
    weight_func(),
    dist_power(),
    size = 50
  )

knn_workflow <-
  workflow() |>
  add_recipe(a_recipe) |>
  add_model(knn_spec)

# doParallel::registerDoParallel(cores = parallel::detectCores() - 2)
knn_tune_res <- 
  tune_grid(
    knn_workflow,
    resamples = a_fold,
    grid = knn_grid,
    control = control_grid(save_pred = TRUE)
  )
# doParallel::stopImplicitCluster()

collect_metrics(knn_tune_res)

knn_resample_best <- 
  show_best(knn_tune_res, metric = "roc_auc")
save(knn_resample_best, file = "data/knn_resample_best.RData")

knn_details <- 
  select_best(knn_tune_res, metric = "roc_auc")
knn_final <- 
  finalize_workflow(knn_workflow, knn_details)

knn_final_fit <-
  fit(knn_final, data = a_train)
save(knn_final_fit, file = "data/knn_final_fit.RData")

knn_autoplot <-
  autoplot(knn_tune_res) +
  ggthemes::theme_few() +
  theme(text = element_text(size = 30))
save(knn_autoplot, file = "data/knn_autoplot.RData")

# Obtain model metrics
knn_metrics <-
  augment(knn_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(knn_metrics, file = "data/knn_metrics.Rdata")

knn_conf_mat <-
  augment(knn_final_fit, new_data = a_train) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(knn_conf_mat, file = "data/knn_conf_mat.Rdata")


knn_metrics_test <-
  augment(knn_final_fit, new_data = a_test) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(knn_metrics_test, file = "data/knn_metrics_test.Rdata")

knn_conf_mat_test <-
  augment(knn_final_fit, new_data = a_test) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(knn_conf_mat_test, file = "data/knn_conf_mat_test.Rdata")

The collect_metrics() function from the tune package returns a tibble summarizing knn_tune_res, the KNN tuning results. For each hyperparameter combination it reports the mean and standard error of each metric (such as accuracy and ROC AUC) across the resamples, which lets us inspect how the candidate models performed during cross-validation.

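The show_best(knn_tune_res, metric = "roc_auc") function displays the top-performing hyperparameter combinations (five by default), ranked by their mean cross-validated ROC AUC, along with their performance estimates; we save this table as knn_resample_best. The select_best(knn_tune_res, metric = "roc_auc") function returns the hyperparameter values of the single best combination, which we save as knn_details.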

Next, we specify that the remaining work will use the tuned hyperparameters from our best performing model.

The finalize_workflow(knn_workflow, knn_details) call replaces the tune() placeholders in the workflow with the selected hyperparameter values; saving the result as knn_final is how we accept and choose the best performing model. Our knn_final object is the model that we will use to obtain performance metrics on the training and testing data.
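The tune-then-select pattern can be mimicked in base R to show what collect_metrics(), show_best(), and select_best() summarize. This sketch tunes only the number of neighbors for class::knn() with 5-fold cross-validation on simulated data. For brevity it ranks candidates by accuracy rather than ROC AUC, and the candidate values are arbitrary.

```r
library(class)  # recommended package distributed with R

set.seed(305)
n <- 300
# Hypothetical simulated data
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$y <- factor(ifelse(sim$x1 - sim$x2 + rnorm(n) > 0, "1", "0"))

folds <- sample(rep(1:5, length.out = n))
candidates <- data.frame(neighbors = c(1, 5, 11, 21, 31, 41))

# Mean cross-validated accuracy for each candidate (like collect_metrics())
candidates$mean_accuracy <- sapply(candidates$neighbors, function(k) {
  mean(sapply(1:5, function(i) {
    pred <- knn(sim[folds != i, 1:2], sim[folds == i, 1:2], sim$y[folds != i], k = k)
    mean(pred == sim$y[folds == i])
  }))
})

ranked <- candidates[order(candidates$mean_accuracy, decreasing = TRUE), ]
head(ranked, 5)                  # like show_best(): the top five candidates
best <- ranked[1, "neighbors"]   # like select_best(): the single winner
best
```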

Full training data

Finally, we apply the model to our training data and evaluate its performance. Remember, the hyperparameters were tuned using resamples of the training data, so this is the first time the finalized model is fit to the full training data set.

knn_final_fit <-
  fit(knn_final, data = a_train)

knn_autoplot <-
  autoplot(knn_tune_res) +
  ggthemes::theme_few() +
  theme(text = element_text(size = 30))
save(knn_autoplot, file = "data/knn_autoplot.RData")

# Obtain model metrics
knn_metrics <-
  augment(knn_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(knn_metrics, file = "data/knn_metrics.Rdata")

knn_conf_mat <-
  augment(knn_final_fit, new_data = a_train) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(knn_conf_mat, file = "data/knn_conf_mat.Rdata")

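The fit() function takes the best performing workflow (knn_final) and fits it to our training data (a_train). The knn_final_fit object holds the fitted model. For KNN there are no coefficients to estimate: the fitted object essentially stores the preprocessed training data and the chosen hyperparameters, and predictions are made by finding each new observation’s nearest neighbors.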

By using the augment() function, we return our training data with extra columns: a predicted probability for failure (.pred_0) and for success (.pred_1), and a predicted 0/1 classification (.pred_class). These are in turn used to obtain the ROC AUC and the confusion matrices for the data sets.
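To show the shape of augment()’s output without tidymodels, this sketch builds the same three columns by hand from a glm() fit on simulated data. The outcome is named did_relapse only to echo our data; its values are simulated. The 0.5 cut-off for .pred_class is an assumption that matches choosing the more probable class in the binary case.

```r
set.seed(305)
n <- 200
# Hypothetical simulated data with a 0/1 factor outcome
sim <- data.frame(x = rnorm(n))
sim$did_relapse <- factor(rbinom(n, 1, plogis(0.8 * sim$x)), levels = c("0", "1"))

fit <- glm(did_relapse ~ x, family = binomial, data = sim)
p1 <- predict(fit, newdata = sim, type = "response")   # glm models P(second level, "1")

augmented <- data.frame(
  sim,
  .pred_0 = 1 - p1,
  .pred_1 = p1,
  .pred_class = factor(ifelse(p1 > 0.5, "1", "0"), levels = c("0", "1"))
)
head(augmented)
```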

# Obtain model ROC AUC on training data
knn_metrics <-
  augment(knn_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)

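Notice that we further modify our augmented training data by creating a column, estimate, that holds .pred_0, the predicted probability of failure of treatment. By default, roc_auc() treats the first level of the outcome factor (here, 0) as the event of interest, so the probability column we supply must be the one for that level. If we were instead to use .pred_1, the ROC AUC would be one minus the correct value and would not describe the model’s ability to predict failure of treatment. Knowing which outcome level is treated as the event is important in your own prediction modeling.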
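A small base-R sketch makes the point about levels concrete. The hypothetical helper roc_auc_base() computes ROC AUC for a chosen event level with the Mann-Whitney rank formula, that is, the probability that a randomly chosen event case gets a higher score than a randomly chosen non-event case. Scoring event 0 with the mismatched, made-up pred_1 column gives exactly one minus the correct value.

```r
# ROC AUC for a chosen event level, via the Mann-Whitney rank formula
roc_auc_base <- function(truth, score, event) {
  is_event <- truth == event
  r <- rank(score)   # average ranks, so ties count as 1/2
  n1 <- sum(is_event)
  n0 <- sum(!is_event)
  (sum(r[is_event]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

truth  <- factor(c("0", "0", "0", "1", "1", "1", "1"), levels = c("0", "1"))
pred_0 <- c(0.9, 0.7, 0.4, 0.6, 0.3, 0.2, 0.1)   # hypothetical probabilities of "0"
pred_1 <- 1 - pred_0

roc_auc_base(truth, pred_0, event = "0")   # correct pairing for event "0"
roc_auc_base(truth, pred_1, event = "0")   # wrong column: 1 minus the correct value
```

For these seven made-up cases, the correct value is 11/12 and the mismatched one is 1/12.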

autoplot() renders a plot, from the knn_tune_res object, that shows how accuracy, the Brier score (brier_class), and ROC AUC vary across the combinations of hyperparameters on the resampled data.

knn_autoplot <-
  autoplot(knn_tune_res) +
  ggthemes::theme_few() +
  theme(text = element_text(size = 30))
save(knn_autoplot, file = "data/knn_autoplot.RData")

Confusion matrix

Next we use the model and display how it performs using a confusion matrix.

# Create a confusion matrix of model predictive performance
knn_conf_mat <-
  augment(knn_final_fit, new_data = a_train) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(knn_conf_mat, file = "data/knn_conf_mat.Rdata")

Again, we use the augment function to obtain the new prediction probabilities and predicted class of the training data. Similar to the roc_auc function, the conf_mat function requires the declaration of the observed outcome (truth = did_relapse) though here we state that the estimate is the .pred_class column. This creates a 2x2 cross-tabulated matrix for the observed and predicted classes.

          Truth
Prediction    0    1
         0   87    1
         1  428 1342

The upper left cell is the true negative cell: the model correctly predicted that the participant did not relapse. The lower right cell is the true positive cell, representing the number of times the model correctly predicted those participants who did relapse. The off-diagonal cells are the incorrect predictions.

The KNN model correctly predicted the outcome for 87 participants who did not relapse and 1342 who did, but incorrectly predicted the outcome for 429 participants: 428 who did not relapse were predicted to relapse, and 1 who relapsed was predicted not to.
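The counts quoted above can be checked directly from the training-data confusion matrix. This base-R sketch rebuilds that matrix and computes the accuracy and, for each true class, the share predicted correctly. The per-class figures show that most participants whose true class is 0 were predicted to be 1.

```r
# The training-data confusion matrix printed above (rows = prediction)
cm <- matrix(c( 87,    1,
               428, 1342), nrow = 2, byrow = TRUE,
             dimnames = list(Prediction = c("0", "1"), Truth = c("0", "1")))

n_correct   <- sum(diag(cm))          # 87 + 1342
n_incorrect <- sum(cm) - n_correct    # 428 + 1
accuracy    <- n_correct / sum(cm)

# Share of each true class that was predicted correctly
recall_0 <- cm["0", "0"] / sum(cm[, "0"])   # 87 / 515
recall_1 <- cm["1", "1"] / sum(cm[, "1"])   # 1342 / 1343
round(c(accuracy = accuracy, recall_0 = recall_0, recall_1 = recall_1), 3)
```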

ROC

# Obtain model ROC AUC on training data
knn_metrics <-
  augment(knn_final_fit, new_data = a_train) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(knn_metrics, file = "data/knn_metrics.Rdata")

The roc_auc() function receives the training data with the new predictions. We specify that the observed outcome (truth) is in the did_relapse column and assess it against estimate (the predicted probability of failure).

.metric .estimator .estimate
roc_auc binary 0.928

Our ROC AUC statistic is very optimistic (0.928) when applied to the training data. We will need to compare that against our cross-validated results and, later, against the testing data.

Test

As we have demonstrated above in the Confusion matrix and ROC sections, obtaining the results of the testing data is achieved in the same way as for obtaining those results on the training data, except now we supply new_data = a_test to the augment function. This allows us to apply the same final fitted model (in this case, knn_final_fit) to the testing data.

knn_metrics_test <-
  augment(knn_final_fit, new_data = a_test) |>
  mutate(estimate = .pred_0) |>
  roc_auc(truth = did_relapse, estimate)
save(knn_metrics_test, file = "data/knn_metrics_test.Rdata")

knn_conf_mat_test <-
  augment(knn_final_fit, new_data = a_test) |>
  conf_mat(truth = did_relapse, estimate = .pred_class)
save(knn_conf_mat_test, file = "data/knn_conf_mat_test.Rdata")

Important

You will notice that we used the same code to obtain test results as we had for training results, but only changed the new_data argument to use a_test, our test data. This is the beauty of tidymodels: reuse your code structures, but change only a few pieces.

References

1.
Cacciola JS, Alterman AI, McLellan AT, Lin Y-T, Lynch KG. Initial evidence for the reliability and validity of a Lite version of the Addiction Severity Index. Drug and Alcohol Dependence [Internet] 2007;87(2-3):297–302. Available from: http://dx.doi.org/10.1016/j.drugalcdep.2006.09.002
2.
Pan Y, Feaster DJ, Odom G, et al. Specific polysubstance use patterns predict relapse among patients entering opioid use disorder treatment. Drug and Alcohol Dependence Reports [Internet] 2022;5:100128. Available from: http://dx.doi.org/10.1016/j.dadr.2022.100128
3.
Kuhn M, Johnson K. Feature engineering and selection [Internet]. Chapman & Hall/CRC; 2019. Available from: http://dx.doi.org/10.1201/9781315108230
4.
Lee JD, Nunes EV, Novo P, et al. Comparative effectiveness of extended-release naltrexone versus buprenorphine-naloxone for opioid relapse prevention (x:BOT): A multicentre, open-label, randomised controlled trial. Lancet (London, England) [Internet] 2018;391(10118):309–18. Available from: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5806119/
5.
Odom G, Brandt L, Balise R, Bouzoubaa L. CTNote: CTN outcomes, treatments, and endpoints [Internet]. Available from: https://CRAN.R-project.org/package=CTNote
6.
Brandt L, Odom G, Hu M-C, Castro C, Balise R. Empirically contrasting opioid use disorder treatment outcome definitions. Under Review
7.
Wasserstein RL, Lazar NA. The ASA Statement on p-Values: Context, Process, and Purpose. The American Statistician [Internet] 2016;70(2):129–33. Available from: http://dx.doi.org/10.1080/00031305.2016.1154108
8.
Efron B, Tibshirani RJ. An introduction to the bootstrap [Internet]. Chapman & Hall/CRC; 1994. Available from: http://dx.doi.org/10.1201/9780429246593
9.
Balise R, Hu M-C, Calderon A, et al. Data cleaning and harmonization of clinical trial data: Medication-assisted treatment for opioid use disorder. PLOS ONE. In press.
10.
Agresti A. An introduction to categorical data analysis. Third edition. Hoboken, NJ: John Wiley & Sons; 2019.
11.
Swets JA. Measuring the Accuracy of Diagnostic Systems. Science [Internet] 1988;240(4857):1285–93. Available from: http://dx.doi.org/10.1126/science.3287615
12.
Sokolova M, Lapalme G. A systematic analysis of performance measures for classification tasks. Information Processing & Management [Internet] 2009;45(4):427–37. Available from: http://dx.doi.org/10.1016/j.ipm.2009.03.002
13.
Opitz J. A Closer Look at Classification Evaluation Metrics and a Critical Reflection of Common Evaluation Practice. Transactions of the Association for Computational Linguistics [Internet] 2024;12:820–36. Available from: http://dx.doi.org/10.1162/tacl_a_00675
14.
Canbek G, Taskaya Temizel T, Sagiroglu S. PToPI: A Comprehensive Review, Analysis, and Knowledge Representation of Binary Classification Performance Measures/Metrics. SN Computer Science [Internet] 2022;4(1). Available from: http://dx.doi.org/10.1007/s42979-022-01409-1
15.
Biecek P, Burzykowski T. Explanatory model analysis [Internet]. Chapman & Hall/CRC; 2021. Available from: http://dx.doi.org/10.1201/9780429027192
16.
Breiman L. Statistical modeling: The two cultures (with comments and a rejoinder by the author). Statistical Science [Internet] 2001;16(3). Available from: http://dx.doi.org/10.1214/ss/1009213726
17.
Molnar C. Modeling Mindsets: The many cultures of learning from data. Germany: Self-published; 2022.
18.
Kuhn M, Silge J. Appendix A, Recommended preprocessing. In: Tidy modeling with R [Internet]. O’Reilly Media; 2023. Available from: https://www.tmwr.org/pre-proc-table
19.
Chawla NV, Bowyer KW, Hall LO, Kegelmeyer WP. SMOTE: Synthetic minority over-sampling technique. Journal of Artificial Intelligence Research [Internet] 2002;16:321–57. Available from: http://dx.doi.org/10.1613/jair.953
20.
Menardi G, Torelli N. Training and assessing classification rules with imbalanced data. Data Mining and Knowledge Discovery [Internet] 2012;28(1):92–122. Available from: http://dx.doi.org/10.1007/s10618-012-0295-5
21.
Rahman MM, Davis DN. Addressing the class imbalance problem in medical datasets. International Journal of Machine Learning and Computing [Internet] 2013;224–8. Available from: http://dx.doi.org/10.7763/IJMLC.2013.V3.307
22.
James G, Witten D, Hastie T, Tibshirani R. An introduction to statistical learning [Internet]. Springer US; 2021. Available from: http://dx.doi.org/10.1007/978-1-0716-1418-1
23.
Smith G. Step away from stepwise. Journal of Big Data [Internet] 2018;5(1). Available from: http://dx.doi.org/10.1186/s40537-018-0143-6
24.
Greenwell B. Tree-based methods for statistical learning in R. First edition. Boca Raton, FL: CRC Press, Taylor & Francis Group; 2022.
25.
26.
Hill J, Linero A, Murray J. Bayesian Additive Regression Trees: A Review and Look Forward. Annual Review of Statistics and Its Application [Internet] 2020;7(1):251–78. Available from: http://dx.doi.org/10.1146/annurev-statistics-031219-041110
27.
Burkov A. The hundred-page machine learning book [Internet]. Andriy Burkov; 2019. Available from: https://themlbook.com/
28.
Kuhn M, Wickham H, Hvitfeldt E. recipes: Preprocessing and feature engineering steps for modeling [Internet]. 2023. Available from: https://CRAN.R-project.org/package=recipes
29.
Hastie T, Qian J, Tay K. An introduction to glmnet [Internet]. 2023. Available from: https://glmnet.stanford.edu/articles/glmnet.html
30.
Kuhn M, Wickham H. Tidymodels: A collection of packages for modeling and machine learning using tidyverse principles [Internet]. 2020. Available from: https://www.tidymodels.org

Footnotes

  1. Excluding the extreme class imbalance scenario just described.

  2. Really, the probability of

  3. We use odd numbers to avoid ties when the votes are counted to make a prediction. With a two-class outcome, an odd number of votes cannot split evenly.

  4. Betas are the coefficients that describe the effect of each predictor in a model. For example, a regression model of the number of needles exchanged in a harm reduction program might estimate that the number of needles exchanged is, on average, 2.5 higher for men than for the reference group, holding the other predictors constant. In that case, the beta for being a man is 2.5.
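Footnote 4’s example can be reproduced with a short simulation in R. This is an illustration under explicit assumptions. The data are simulated, and the variables `man` and `needles` are hypothetical. The true effect of being a man is set to 2.5 needles, so the fitted coefficient (the beta) should land close to 2.5.

```r
# Simulated harm reduction data (hypothetical; not from the study)
set.seed(42)
n <- 2000
man <- rbinom(n, size = 1, prob = 0.5)         # 1 = man, 0 = reference group
needles <- 10 + 2.5 * man + rnorm(n, sd = 2)   # true beta for man is 2.5

# Fit a linear regression and extract the estimated beta for man
needle_fit <- lm(needles ~ man)
beta_man <- unname(coef(needle_fit)["man"])
beta_man
```

The intercept estimates the average for the reference group (about 10 here). The beta for `man` estimates how much higher, on average, the count is for men.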