Load packages
library(rstanarm)
options(mc.cores = parallel::detectCores())
library(loo)
library(tidyverse)
library(bayesplot)
library(projpred)
library(fivethirtyeight)
SEED=150702646
This notebook was inspired by Joshua Loftus’ two blog posts, “Model selection bias invalidates significance tests” and “A conditional approach to inference after model selection”.
In this notebook we illustrate Bayesian inference for model selection, including PSIS-LOO (Vehtari, Gelman and Gabry, 2017) and the projection predictive approach (Piironen and Vehtari, 2017a; Piironen, Paasiniemi and Vehtari, 2020), which provides decision-theoretically justified inference after model selection.
We use the candy rankings data from the fivethirtyeight package. The dataset was originally used in a FiveThirtyEight story.
df <- candy_rankings %>%
  select(-competitorname) %>%
  mutate_if(is.logical, as.numeric)
head(df)
# A tibble: 6 × 12
chocolate fruity caramel peanutyalmondy nougat crispedricewafer hard bar pluribus
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 1 0 1 0 0 1 0 1 0
2 1 0 0 0 1 0 0 1 0
3 0 0 0 0 0 0 0 0 0
4 0 0 0 0 0 0 0 0 0
5 0 1 0 0 0 0 0 0 0
6 1 0 0 1 0 0 0 1 0
# … with 3 more variables: sugarpercent <dbl>, pricepercent <dbl>, winpercent <dbl>
We first analyse a “null” data set, in which winpercent has been replaced with random draws from a standard normal distribution, so that the covariates do not have any predictive information.
dfr <- df %>% select(-winpercent)
n <- nrow(dfr)
p <- ncol(dfr)
prednames <- colnames(dfr)
set.seed(SEED)
ry <- rnorm(n)
dfr$ry <- ry
(reg_formula <- formula(paste("ry ~", paste(prednames, collapse = " + "))))
ry ~ chocolate + fruity + caramel + peanutyalmondy + nougat +
crispedricewafer + hard + bar + pluribus + sugarpercent +
pricepercent
The rstanarm package provides stan_glm, which accepts the same arguments as glm but performs full Bayesian inference using Stan (mc-stan.org). When doing variable selection we are already assuming that some of the variables are not relevant, so it is sensible to use priors which assume that some of the covariate effects are close to zero. We use the regularized horseshoe prior (Piironen and Vehtari, 2017b), which has a lot of prior mass near 0 but also thick tails, so that relevant effects are not shrunk towards 0.
p0 <- 5 # prior guess for the number of relevant variables
tau0 <- p0/(p-p0) * 1/sqrt(n)
hs_prior <- hs(df=1, global_df=1, global_scale=tau0)
t_prior <- student_t(df = 7, location = 0, scale = 2.5)
fitrhs <- stan_glm(reg_formula, data = dfr,
                   prior = hs_prior, prior_intercept = t_prior,
                   seed = SEED, refresh = 0)
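To get a feel for what “a lot of prior mass near 0 but thick tails” means with the global scale tau0 used above, we can draw from the marginal prior of a single coefficient. This is a minimal sketch under a stated simplification: it uses the plain horseshoe (half-Cauchy global and local scales, matching df=1 and global_df=1), and it omits the slab regularization that hs() adds, so it is not exactly the prior rstanarm uses. The names n_obs, n_cov and p0_guess are illustrative stand-ins for n, p and p0.

```r
# Marginal prior of one coefficient under a plain horseshoe with the same
# global scale as above. Assumption: the slab regularization of hs() is omitted.
set.seed(4)
n_obs <- 85      # number of observations in the candy data
n_cov <- 11      # number of covariates
p0_guess <- 5    # prior guess for the number of relevant covariates
tau0_guess <- p0_guess / (n_cov - p0_guess) / sqrt(n_obs)   # about 0.09

n_draws <- 1e5
tau <- abs(rcauchy(n_draws, location = 0, scale = tau0_guess))  # global scale
lambda <- abs(rcauchy(n_draws))                                  # local scales
beta <- rnorm(n_draws, mean = 0, sd = tau * lambda)

tau0_guess
mean(abs(beta) < 0.1)               # a large share of the prior mass is near 0
quantile(abs(beta), c(0.5, 0.99))   # while the upper tail reaches far out
```

The median of |beta| is small, but the 99% quantile is much larger: this is the combination that lets irrelevant effects shrink strongly while a few large effects escape the shrinkage.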
Let’s look at the summary:
summary(fitrhs)
Model Info:
function: stan_glm
family: gaussian [identity]
formula: ry ~ chocolate + fruity + caramel + peanutyalmondy + nougat +
crispedricewafer + hard + bar + pluribus + sugarpercent +
pricepercent
algorithm: sampling
sample: 4000 (posterior sample size)
priors: see help('prior_summary')
observations: 85
predictors: 12
Estimates:
mean sd 10% 50% 90%
(Intercept) 0.0 0.2 -0.3 0.0 0.2
chocolate 0.0 0.1 -0.1 0.0 0.1
fruity 0.0 0.1 0.0 0.0 0.2
caramel -0.1 0.2 -0.3 0.0 0.0
peanutyalmondy 0.0 0.1 -0.1 0.0 0.1
nougat 0.0 0.2 -0.2 0.0 0.1
crispedricewafer 0.0 0.2 -0.2 0.0 0.1
hard 0.1 0.2 0.0 0.0 0.2
bar -0.1 0.2 -0.5 0.0 0.0
pluribus 0.0 0.1 -0.1 0.0 0.1
sugarpercent 0.0 0.1 -0.1 0.0 0.1
pricepercent -0.1 0.3 -0.4 0.0 0.0
sigma 1.1 0.1 1.0 1.1 1.2
Fit Diagnostics:
mean sd 10% 50% 90%
mean_PPD -0.1 0.2 -0.3 -0.1 0.2
The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).
MCMC diagnostics
mcse Rhat n_eff
(Intercept) 0.0 1.0 3896
chocolate 0.0 1.0 3157
fruity 0.0 1.0 3155
caramel 0.0 1.0 3225
peanutyalmondy 0.0 1.0 3961
nougat 0.0 1.0 3907
crispedricewafer 0.0 1.0 3754
hard 0.0 1.0 3315
bar 0.0 1.0 1989
pluribus 0.0 1.0 4290
sugarpercent 0.0 1.0 4808
pricepercent 0.0 1.0 2568
sigma 0.0 1.0 5380
mean_PPD 0.0 1.0 4791
log-posterior 0.1 1.0 1256
For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
We didn’t get any divergences, all Rhat values are below 1.1, and the n_eff values are large enough to be useful (see, e.g., the RStan workflow).
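The Rhat reported above is a potential scale reduction factor computed on split chains. As an illustration, here is a minimal sketch of the classic split-Rhat from Bayesian Data Analysis (3rd ed.). It is an assumption that this matches what a given rstanarm version prints: recent versions of rstan and the posterior package use a rank-normalized variant. The function name split_rhat is hypothetical.

```r
# Classic split-Rhat (no rank normalization).
# draws: matrix with one column per chain and one row per iteration.
split_rhat <- function(draws) {
  n_iter <- nrow(draws)
  half <- floor(n_iter / 2)
  # Split every chain into its first and second half (middle draw dropped if odd)
  split <- cbind(draws[seq_len(half), , drop = FALSE],
                 draws[(n_iter - half + 1):n_iter, , drop = FALSE])
  n <- nrow(split)
  B <- n * var(colMeans(split))        # between-chain variance
  W <- mean(apply(split, 2, var))      # mean within-chain variance
  var_plus <- (n - 1) / n * W + B / n  # pooled posterior variance estimate
  sqrt(var_plus / W)
}

set.seed(1)
good <- matrix(rnorm(1000 * 4), ncol = 4)                          # 4 well-mixed chains
bad <- good + matrix(rep(c(0, 0, 0, 3), each = 1000), ncol = 4)    # chain 4 sits elsewhere
split_rhat(good)   # close to 1
split_rhat(bad)    # clearly above 1.1
```

Splitting each chain in half also makes the diagnostic sensitive to drift within a single chain, not only to disagreement between chains.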
mcmc_areas(as.matrix(fitrhs), prob_outer = .95)
All 95% posterior intervals overlap 0. The regularized horseshoe prior makes the posteriors concentrate near 0, but some uncertainty remains.
We can easily test whether any of the covariates are useful by using cross-validation to compare against a null model:
fit0 <- stan_glm(ry ~ 1, data = dfr, seed=SEED, refresh=0)
(loorhs <- loo(fitrhs))
Computed from 4000 by 85 log-likelihood matrix
Estimate SE
elpd_loo -130.2 6.3
p_loo 4.3 0.8
looic 260.4 12.7
------
Monte Carlo SE of elpd_loo is 0.0.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
(loo0 <- loo(fit0))
Computed from 4000 by 85 log-likelihood matrix
Estimate SE
elpd_loo -130.1 6.2
p_loo 1.8 0.4
looic 260.2 12.4
------
Monte Carlo SE of elpd_loo is 0.0.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
loo_compare(loo0, loorhs)
elpd_diff se_diff
fit0 0.0 0.0
fitrhs -0.1 1.1
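loo_compare reports each model's elpd relative to the best model, together with the standard error of that difference, computed from the pointwise elpd values. A minimal sketch of that arithmetic follows, using hypothetical pointwise values rather than the candy fits (with real fits they are available as loorhs$pointwise[, "elpd_loo"]). It also shows that looic is simply -2 times elpd_loo.

```r
# How elpd_diff and se_diff are formed from pointwise elpd values.
# Toy numbers only; -1.53 is roughly -130.2 / 85, the average pointwise elpd above.
set.seed(3)
n_obs <- 85
elpd_a <- rnorm(n_obs, mean = -1.53, sd = 0.5)        # hypothetical model A
elpd_b <- elpd_a + rnorm(n_obs, mean = 0, sd = 0.1)   # hypothetical model B, close to A
d <- elpd_b - elpd_a                                  # pointwise differences
elpd_diff <- sum(d)
se_diff <- sqrt(n_obs) * sd(d)                        # equals sqrt(n * var(d))
c(elpd_diff = elpd_diff, se_diff = se_diff)

# looic is -2 * elpd_loo, e.g. for the horseshoe fit on the null data:
c(looic = -2 * (-130.2))                              # 260.4, as printed above
```

Because the differences are computed observation by observation, the standard error of the difference is usually much smaller than the standard errors of the individual elpd_loo estimates: here 1.1 versus about 6.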
Based on cross-validation, the covariates together do not contain any useful information, and there is no need to continue with variable selection. This step of checking whether the full model has any predictive power is often skipped, especially when non-Bayesian methods are used. If LOO (or AIC, as Joshua Loftus demonstrated) were used for stepwise variable selection, the selection process over a large number of models could overfit the data.
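The overfitting problem can be reproduced in a few lines of base R, in the spirit of Loftus’ AIC example. This is a sketch under explicit assumptions: it uses simulated null data of the same size as the candy data (85 observations, 11 independent Gaussian covariates), backward stepwise selection with stats::step(), and naive t-test p-values from the selected lm fit. The helper names simulate_null and min_p_after_step are hypothetical.

```r
# Selection bias on null data: stepwise AIC selection, then naive t-tests.
set.seed(2022)
n_obs <- 85   # same size as the candy data
n_cov <- 11   # same number of covariates

simulate_null <- function(n_obs, n_cov) {
  X <- matrix(rnorm(n_obs * n_cov), n_obs, n_cov,
              dimnames = list(NULL, paste0("x", seq_len(n_cov))))
  data.frame(y = rnorm(n_obs), X)   # y is independent of every covariate
}

# Backward selection by AIC; return the smallest naive p-value among the
# covariates that survive (NA if only the intercept is left).
min_p_after_step <- function(d) {
  full <- lm(y ~ ., data = d)
  sel <- step(full, direction = "backward", trace = 0)
  coefs <- summary(sel)$coefficients
  pvals <- coefs[rownames(coefs) != "(Intercept)", "Pr(>|t|)"]
  if (length(pvals) == 0) NA_real_ else min(pvals)
}

min_p <- replicate(200, min_p_after_step(simulate_null(n_obs, n_cov)))
# Share of null data sets where the selected model contains at least one
# covariate that looks "significant" at the 5% level
rate <- mean(!is.na(min_p) & min_p < 0.05)
rate
```

Even though no covariate has any effect, the selected model very often reports at least one “significant” covariate, far more often than the nominal 5%.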
To illustrate the robustness of projpred, we run the projection predictive variable selection using the model fitted to the “null” data. A fast leave-one-out cross-validation approach (Vehtari, Gelman and Gabry, 2017) is used to choose the model size.
fitrhs_cv <- cv_varsel(fitrhs, method = 'forward', cv_method = 'loo', nloo = n)
solution_terms(fitrhs_cv)  # $vind returns NULL in projpred >= 2.0; solution_terms() gives the selection order
We can now look at the estimated predictive performance of smaller models compared to the full model.
plot(fitrhs_cv, stats = c('elpd', 'rmse'))
We can see that the differences from the full model are very small.
We also get a LOO-based recommendation for the model size:
(nv <- suggest_size(fitrhs_cv, alpha=0.1))
[1] 1
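The suggested size is one variable rather than zero. As the full model itself predicts no better than the null model (see the LOO comparison above), this single variable should not be expected to carry useful information.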
Next we form the projected posterior for the chosen model.
projrhs <- project(fitrhs_cv, nterms = nv, ndraws = 4000)  # projpred >= 2.0 argument names
round(colMeans(as.matrix(projrhs)),1)
Intercept bar sigma
0.0 -0.3 1.1
round(posterior_interval(as.matrix(projrhs)),1)
5% 95%
Intercept -0.2 0.2
bar -0.7 0.1
sigma 1.1 1.2
This looks good: the true values for the “null” data are intercept=0 and sigma=1, and the 90% interval for bar includes 0.
Next we repeat the above analysis with the original target variable winpercent.
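For a Gaussian model the projection step has a closed form: for each posterior draw of the reference model, the submodel coefficients are the least-squares fit of the reference model's linear predictor on the submodel covariates, and the projected residual variance is the reference sigma² plus the mean squared difference between the two fits (Piironen, Paasiniemi and Vehtari, 2020). The sketch below implements only that draw-by-draw rule on simulated reference draws. It is not projpred's code (projpred also clusters or thins draws and supports other families), and project_gaussian is a hypothetical helper.

```r
# Draw-by-draw projection for a Gaussian linear model.
# X_sub: n x k submodel design matrix
# mu_ref: S x n matrix of reference-model linear predictors (one row per draw)
# sigma_ref: length-S vector of reference-model residual standard deviations
project_gaussian <- function(X_sub, mu_ref, sigma_ref) {
  S <- nrow(mu_ref)
  beta_perp <- matrix(NA_real_, S, ncol(X_sub),
                      dimnames = list(NULL, colnames(X_sub)))
  sigma_perp <- numeric(S)
  for (s in seq_len(S)) {
    fit <- lm.fit(X_sub, mu_ref[s, ])   # least squares = KL projection here
    beta_perp[s, ] <- fit$coefficients
    sigma_perp[s] <- sqrt(sigma_ref[s]^2 + mean(fit$residuals^2))
  }
  list(beta = beta_perp, sigma = sigma_perp)
}

# Simulated "reference posterior": intercept, x1 and x2 all in the model
set.seed(1)
n_obs <- 85
X <- cbind(Intercept = 1, x1 = rnorm(n_obs), x2 = rnorm(n_obs))
S <- 200
beta_draws <- cbind(rnorm(S, 0, 0.1), rnorm(S, 1, 0.1), rnorm(S, 0.5, 0.1))
sigma_draws <- abs(rnorm(S, 1, 0.05))
mu_ref <- beta_draws %*% t(X)            # S x n matrix of linear predictors

# Project onto the submodel that drops x2
proj <- project_gaussian(X[, c("Intercept", "x1")], mu_ref, sigma_draws)
colMeans(proj$beta)
mean(proj$sigma)   # larger than mean(sigma_draws): the x2 signal becomes noise
```

Dropping a relevant covariate therefore shows up as a larger projected sigma, while dropping an irrelevant one leaves sigma almost unchanged, which is why the projected posterior for the “null” data has sigma close to the reference value.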
reg_formula <- formula(paste("winpercent ~", paste(prednames, collapse = " + ")))
p0 <- 5 # prior guess for the number of relevant variables
tau0 <- p0/(p-p0) * 1/sqrt(n)
hs_prior <- hs(df=1, global_df=1, global_scale=tau0)
t_prior <- student_t(df = 7, location = 0, scale = 2.5)
fitrhs <- stan_glm(reg_formula, data = df,
                   prior = hs_prior, prior_intercept = t_prior,
                   seed = SEED, refresh = 0)
Let’s look at the summary:
summary(fitrhs)
Model Info:
function: stan_glm
family: gaussian [identity]
formula: winpercent ~ chocolate + fruity + caramel + peanutyalmondy +
nougat + crispedricewafer + hard + bar + pluribus + sugarpercent +
pricepercent
algorithm: sampling
sample: 4000 (posterior sample size)
priors: see help('prior_summary')
observations: 85
predictors: 12
Estimates:
mean sd 10% 50% 90%
(Intercept) 40.6 3.7 35.7 40.9 45.2
chocolate 13.5 3.8 8.8 13.5 18.3
fruity 1.9 3.1 -1.2 1.1 6.2
caramel 0.9 2.4 -1.6 0.5 4.1
peanutyalmondy 5.2 3.6 0.2 5.2 9.8
nougat 0.4 2.7 -2.7 0.1 3.8
crispedricewafer 3.3 3.7 -0.4 2.5 8.6
hard -2.6 2.9 -6.6 -2.1 0.4
bar 1.4 2.7 -1.4 0.8 5.1
pluribus -0.6 2.0 -3.2 -0.3 1.6
sugarpercent 3.8 3.8 -0.2 3.3 8.9
pricepercent 0.0 3.0 -3.5 0.0 3.5
sigma 11.2 1.0 10.0 11.2 12.5
Fit Diagnostics:
mean sd 10% 50% 90%
mean_PPD 50.1 1.7 47.9 50.1 52.3
The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).
MCMC diagnostics
mcse Rhat n_eff
(Intercept) 0.1 1.0 3717
chocolate 0.1 1.0 3672
fruity 0.1 1.0 3184
caramel 0.0 1.0 4573
peanutyalmondy 0.1 1.0 2504
nougat 0.0 1.0 4443
crispedricewafer 0.1 1.0 3666
hard 0.1 1.0 3173
bar 0.0 1.0 4432
pluribus 0.0 1.0 5030
sugarpercent 0.1 1.0 3575
pricepercent 0.0 1.0 5237
sigma 0.0 1.0 4174
mean_PPD 0.0 1.0 4982
log-posterior 0.2 1.0 1118
For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
We didn’t get any divergences, all Rhat values are below 1.1, and the n_eff values are large enough to be useful.
mcmc_areas(as.matrix(fitrhs), prob_outer = .95)
The 95% posterior interval for chocolate does not overlap 0, so there may be something useful here.
With collinear covariates, the marginal posteriors may overlap 0 even though the covariates are still useful for prediction. With many variables it is difficult to analyse the joint posterior to see which variables are jointly relevant. We can easily test whether any of the covariates are useful by using cross-validation to compare against a null model:
fit0 <- stan_glm(winpercent ~ 1, data = df, seed=SEED, refresh=0)
(loorhs <- loo(fitrhs))
Computed from 4000 by 85 log-likelihood matrix
Estimate SE
elpd_loo -329.7 5.8
p_loo 7.9 1.1
looic 659.5 11.5
------
Monte Carlo SE of elpd_loo is 0.1.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
(loo0 <- loo(fit0))
Computed from 4000 by 85 log-likelihood matrix
Estimate SE
elpd_loo -350.6 5.5
p_loo 1.7 0.3
looic 701.2 11.1
------
Monte Carlo SE of elpd_loo is 0.0.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
loo_compare(loo0, loorhs)
elpd_diff se_diff
fitrhs 0.0 0.0
fit0 -20.9 4.5
Based on cross-validation, the covariates together do contain useful information: the elpd difference of 20.9 is more than four times its standard error of 4.5. If we need just the predictions we can stop here, but if we want to learn more about the relevance of the covariates we can continue with variable selection.
We run the projection predictive variable selection using the model fitted above to winpercent. A fast leave-one-out cross-validation approach is used to choose the model size.
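The earlier point about collinear covariates, that marginal intervals can overlap 0 while the covariates are jointly useful, can be seen in a small simulation. This sketch assumes ordinary least squares on simulated data as a stand-in for the Bayesian model; only the qualitative behaviour carries over. Here x2 is almost an exact copy of x1, so the data cannot tell the two coefficients apart, although their sum is well determined.

```r
# Collinearity: wide marginal intervals, yet clear joint predictive value.
set.seed(5)
n_obs <- 85
x1 <- rnorm(n_obs)
x2 <- x1 + rnorm(n_obs, sd = 0.01)         # nearly a copy of x1
y <- 0.5 * x1 + 0.5 * x2 + rnorm(n_obs)

fit_both <- lm(y ~ x1 + x2)
fit_null <- lm(y ~ 1)
confint(fit_both)                          # very wide intervals for x1 and x2
anova(fit_null, fit_both)                  # but together they clearly help

# Standard errors with both covariates versus with x1 alone
se_both <- summary(fit_both)$coefficients[c("x1", "x2"), "Std. Error"]
se_single <- summary(lm(y ~ x1))$coefficients["x1", "Std. Error"]
se_both / se_single
```

Looking only at the marginal intervals of x1 and x2 would suggest neither matters, whereas comparing against the null model shows that the pair is highly predictive. The LOO comparison and projpred both look at joint predictive performance, which is why they are not misled in this way.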
fitrhs_cv <- cv_varsel(fitrhs, method = 'forward', cv_method = 'loo', nloo = n)
solution_terms(fitrhs_cv)  # $vind returns NULL in projpred >= 2.0; solution_terms() gives the selection order
We can now look at the estimated predictive performance of smaller models compared to the full model.
plot(fitrhs_cv, stats = c('elpd', 'rmse'))
Only one variable seems to be needed to get the same performance as the full model.
We also get a LOO-based recommendation for the model size:
(nsel <- suggest_size(fitrhs_cv, alpha=0.1))
[1] 1
(vsel <- solution_terms(fitrhs_cv)[1:nsel])
[1] "chocolate"
projpred recommends using just one variable, chocolate.
Next we form the projected posterior for the chosen model.
projrhs <- project(fitrhs_cv, nterms = nsel, ndraws = 4000)  # projpred >= 2.0 argument names
projdraws <- as.matrix(projrhs)
colnames(projdraws) <- c("Intercept",vsel,"sigma")
round(colMeans(projdraws),1)
Intercept chocolate sigma
43.0 16.3 12.0
round(posterior_interval(projdraws),1)
5% 95%
Intercept 40.5 45.7
chocolate 11.9 20.6
sigma 11.2 12.8
mcmc_areas(projdraws)
In our LOO and projpred analysis, we find that chocolate has predictive information. Other variables may have predictive power, too, but conditionally on chocolate they do not provide additional information.
sessionInfo()
R version 4.1.2 (2021-11-01)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.3 LTS
Matrix products: default
BLAS/LAPACK: /opt/OpenBLAS/lib/libopenblas_haswellp-r0.3.17.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C LC_TIME=fi_FI.UTF-8
[4] LC_COLLATE=en_US.UTF-8 LC_MONETARY=fi_FI.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=fi_FI.UTF-8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=fi_FI.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] fivethirtyeight_0.6.2 projpred_2.0.2 bayesplot_1.8.1 forcats_0.5.1
[5] stringr_1.4.0 dplyr_1.0.8 purrr_0.3.4 readr_2.1.2
[9] tidyr_1.2.0 tibble_3.1.6 ggplot2_3.3.5 tidyverse_1.3.1
[13] loo_2.4.1 rstanarm_2.21.1 Rcpp_1.0.8 rmarkdown_2.11
[17] knitr_1.37
loaded via a namespace (and not attached):
[1] readxl_1.3.1 backports_1.4.1 plyr_1.8.6 igraph_1.2.11
[5] splines_4.1.2 crosstalk_1.2.0 usethis_2.1.5 rstantools_2.1.1
[9] inline_0.3.19 digest_0.6.29 htmltools_0.5.2 rsconnect_0.8.25
[13] fansi_1.0.2 magrittr_2.0.2 checkmate_2.0.0 memoise_2.0.1
[17] tzdb_0.2.0 remotes_2.4.2 modelr_0.1.8 RcppParallel_5.1.5
[21] matrixStats_0.61.0 xts_0.12.1 prettyunits_1.1.1 colorspace_2.0-2
[25] rvest_1.0.2 haven_2.4.3 xfun_0.29 callr_3.7.0
[29] crayon_1.4.2 jsonlite_1.7.3 lme4_1.1-28 survival_3.2-13
[33] zoo_1.8-9 glue_1.6.1 gtable_0.3.0 distributional_0.3.0
[37] pkgbuild_1.3.1 rstan_2.21.3 abind_1.4-5 scales_1.1.1
[41] DBI_1.1.2 miniUI_0.1.1.1 xtable_1.8-4 stats4_4.1.2
[45] StanHeaders_2.21.0-7 DT_0.20 htmlwidgets_1.5.4 httr_1.4.2
[49] threejs_0.3.3 RColorBrewer_1.1-2 posterior_1.2.0 ellipsis_0.3.2
[53] pkgconfig_2.0.3 farver_2.1.0 sass_0.4.0 dbplyr_2.1.1
[57] utf8_1.2.2 labeling_0.4.2 tidyselect_1.1.1 rlang_1.0.1
[61] reshape2_1.4.4 later_1.3.0 munsell_0.5.0 cellranger_1.1.0
[65] tools_4.1.2 cachem_1.0.6 cli_3.1.1 generics_0.1.2
[69] devtools_2.4.3 broom_0.7.12 ggridges_0.5.3 evaluate_0.14
[73] fastmap_1.1.0 yaml_2.2.2 processx_3.5.2 fs_1.5.2
[77] nlme_3.1-155 mime_0.12 xml2_1.3.3 brio_1.1.3
[81] compiler_4.1.2 shinythemes_1.2.0 rstudioapi_0.13 gamm4_0.2-6
[85] testthat_3.1.2 reprex_2.0.1 bslib_0.3.1 stringi_1.7.6
[89] highr_0.9 ps_1.6.0 desc_1.4.0 lattice_0.20-45
[93] Matrix_1.4-0 nloptr_1.2.2.3 markdown_1.1 shinyjs_2.1.0
[97] tensorA_0.36.2 vctrs_0.3.8 pillar_1.7.0 lifecycle_1.0.1
[101] jquerylib_0.1.4 httpuv_1.6.5 R6_2.5.1 promises_1.2.0.1
[105] gridExtra_2.3 sessioninfo_1.2.2 codetools_0.2-18 boot_1.3-28
[109] colourpicker_1.1.1 MASS_7.3-55 gtools_3.9.2 assertthat_0.2.1
[113] pkgload_1.2.4 rprojroot_2.0.2 withr_2.4.3 shinystan_2.5.0
[117] mgcv_1.8-38 parallel_4.1.2 hms_1.1.1 grid_4.1.2
[121] minqa_1.2.4 shiny_1.7.1 lubridate_1.8.0 base64enc_0.1-3
[125] dygraphs_1.1.1.6