Load packages
library(rstan)
library(rstanarm)
library(loo)
library(bayesplot)
theme_set(bayesplot::theme_default())
library(projpred)
SEED <- 170701694 # seed for reproducibility
This notebook was inspired by Eric Novik’s slides “Deconstructing Stan Manual Part 1: Linear”. The idea is to demonstrate how easy it is to do good variable selection with rstanarm, loo, and projpred.
In this notebook we illustrate Bayesian inference for model selection, including PSIS-LOO (Vehtari, Gelman and Gabry, 2017) and the projection predictive approach (Piironen and Vehtari, 2017a; Piironen, Paasiniemi and Vehtari, 2020), which provides decision-theoretically justified inference after model selection.
We use the red wine subset of the Wine Quality data set from the UCI Machine Learning Repository:
d <- read.delim("winequality-red.csv", sep = ";")
dim(d)
[1] 1599 12
Remove duplicated rows:
d <- d[!duplicated(d), ] # remove the duplicates
(p <- ncol(d))
[1] 12
(n <- nrow(d))
[1] 1359
names(d)
[1] "fixed.acidity" "volatile.acidity" "citric.acid" "residual.sugar"
[5] "chlorides" "free.sulfur.dioxide" "total.sulfur.dioxide" "density"
[9] "pH" "sulphates" "alcohol" "quality"
prednames <- names(d)[1:(p-1)] # all columns except the outcome "quality"
We scale the covariates so that the marginal posteriors of the effects are on the same scale.
ds <- scale(d)
df <- as.data.frame(ds)
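As a quick check that the scaling did what we expect, each column should now have mean approximately 0 and standard deviation 1:
round(colMeans(df), 2) # means should be ~0
round(apply(df, 2, sd), 2) # standard deviations should be ~1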
The rstanarm package provides stan_glm, which accepts the same arguments as glm but performs full Bayesian inference using Stan (mc-stan.org). By default, a weakly informative Gaussian prior is used for the regression coefficients.
formula <- formula(paste("quality ~", paste(prednames, collapse = " + ")))
# QR=TRUE decorrelates the design matrix, which improves sampling efficiency
# when the covariates are correlated
fitg <- stan_glm(formula, data = df, QR=TRUE, seed=SEED, refresh=0)
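rstanarm autoscales its default priors; we can inspect the priors that were actually used with rstanarm’s prior_summary:
prior_summary(fitg)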
Let’s look at the summary:
summary(fitg)
Model Info:
function: stan_glm
family: gaussian [identity]
formula: quality ~ fixed.acidity + volatile.acidity + citric.acid + residual.sugar +
chlorides + free.sulfur.dioxide + total.sulfur.dioxide +
density + pH + sulphates + alcohol
algorithm: sampling
sample: 4000 (posterior sample size)
priors: see help('prior_summary')
observations: 1359
predictors: 12
Estimates:
mean sd 10% 50% 90%
(Intercept) 0.0 0.0 0.0 0.0 0.0
fixed.acidity 0.0 0.1 -0.1 0.0 0.1
volatile.acidity -0.2 0.0 -0.3 -0.2 -0.2
citric.acid 0.0 0.0 -0.1 0.0 0.0
residual.sugar 0.0 0.0 0.0 0.0 0.0
chlorides -0.1 0.0 -0.1 -0.1 -0.1
free.sulfur.dioxide 0.0 0.0 0.0 0.0 0.1
total.sulfur.dioxide -0.1 0.0 -0.2 -0.1 -0.1
density 0.0 0.1 -0.1 0.0 0.0
pH -0.1 0.0 -0.1 -0.1 0.0
sulphates 0.2 0.0 0.2 0.2 0.2
alcohol 0.4 0.0 0.3 0.4 0.4
sigma 0.8 0.0 0.8 0.8 0.8
Fit Diagnostics:
mean sd 10% 50% 90%
mean_PPD 0.0 0.0 0.0 0.0 0.0
The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).
MCMC diagnostics
mcse Rhat n_eff
(Intercept) 0.0 1.0 5009
fixed.acidity 0.0 1.0 5523
volatile.acidity 0.0 1.0 5615
citric.acid 0.0 1.0 5884
residual.sugar 0.0 1.0 4943
chlorides 0.0 1.0 5330
free.sulfur.dioxide 0.0 1.0 5237
total.sulfur.dioxide 0.0 1.0 5496
density 0.0 1.0 4811
pH 0.0 1.0 5149
sulphates 0.0 1.0 5378
alcohol 0.0 1.0 5213
sigma 0.0 1.0 4864
mean_PPD 0.0 1.0 4282
log-posterior 0.1 1.0 1578
For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
We didn’t get any divergences, all Rhat values are less than 1.1, and the n_eff values are large enough to be useful (see, e.g., RStan workflow).
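These diagnostics can also be inspected visually; rhat and neff_ratio extract the values from the fitted object, and mcmc_rhat and mcmc_neff are the corresponding bayesplot plots:
mcmc_rhat(rhat(fitg)) # all Rhat values should be close to 1
mcmc_neff(neff_ratio(fitg)) # ratios of effective to total sample size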
mcmc_areas(as.matrix(fitg), pars=prednames, prob_outer = .95)
Several of the 95% posterior intervals do not overlap 0, so there may be something useful here.
In the case of collinear variables it is possible that the marginal posteriors overlap 0 while the covariates are still useful for prediction. With many variables it is difficult to inspect the joint posterior to see which variables are jointly relevant. We can easily test whether any of the covariates are useful by using cross-validation to compare against a null model:
fitg0 <- stan_glm(quality ~ 1, data = df, seed=SEED, refresh=0)
We use fast Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO; Vehtari, Gelman and Gabry, 2017):
(loog <- loo(fitg))
Computed from 4000 by 1359 log-likelihood matrix
Estimate SE
elpd_loo -1635.7 30.6
p_loo 16.6 1.6
looic 3271.4 61.2
------
Monte Carlo SE of elpd_loo is 0.1.
Pareto k diagnostic values:
Count Pct. Min. n_eff
(-Inf, 0.5] (good) 1358 99.9% 1600
(0.5, 0.7] (ok) 1 0.1% 1270
(0.7, 1] (bad) 0 0.0% <NA>
(1, Inf) (very bad) 0 0.0% <NA>
All Pareto k estimates are ok (k < 0.7).
See help('pareto-k-diagnostic') for details.
(loog0 <- loo(fitg0))
Computed from 4000 by 1359 log-likelihood matrix
Estimate SE
elpd_loo -1929.9 28.3
p_loo 2.2 0.2
looic 3859.9 56.5
------
Monte Carlo SE of elpd_loo is 0.0.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
loo_compare(loog0, loog)
elpd_diff se_diff
fitg 0.0 0.0
fitg0 -294.3 22.9
Based on cross-validation, the covariates together have high predictive power. If we only need the predictions we can stop here, but if we want to learn more about the relevance of the individual covariates we can continue with variable selection.
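As a minimal sketch of the stop-here option, we could draw posterior predictive replicates from the full model and compare them to the observed outcomes (posterior_predict is from rstanarm and ppc_dens_overlay from bayesplot, both already loaded):
yrep <- posterior_predict(fitg) # matrix of draws x observations
ppc_dens_overlay(df$quality, yrep[1:50, ]) # overlay 50 replicates on the data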
We perform the projection predictive variable selection (Piironen and Vehtari, 2017a; Piironen, Paasiniemi and Vehtari, 2020) using the projpred package. A fast PSIS-LOO (Vehtari, Gelman and Gabry, 2017) is used to choose the model size.
# forward search, with PSIS-LOO computed for all n observations
fitg_cv <- cv_varsel(fitg, method='forward', cv_method='loo', n_loo=nrow(df))
We can now look at the estimated predictive performance of smaller models compared to the full model.
plot(fitg_cv, stats = c('elpd', 'rmse'))
Three or four variables seem to be needed to get the same predictive performance as with the full model. We can get a LOO-CV based recommendation for the model size to choose:
(nsel <- suggest_size(fitg_cv, alpha=0.1))
[1] 4
(vsel <- solution_terms(fitg_cv)[1:nsel])
[1] "alcohol" "volatile.acidity" "sulphates" "chlorides"
projpred recommends using four variables: alcohol, volatile.acidity, sulphates, and chlorides.
Next we form the projected posterior for the chosen model. This projected model can later be used to make predictions using only the selected variables.
projg <- project(fitg_cv, nterms = nsel, ndraws = 4000)
round(colMeans(as.matrix(projg)), 1)
Intercept alcohol volatile.acidity sulphates
0.0 0.4 -0.3 0.2
chlorides total.sulfur.dioxide sigma
-0.1 -0.1 0.8
round(posterior_interval(as.matrix(projg)), 1)
5% 95%
Intercept 0.0 0.0
alcohol 0.3 0.4
volatile.acidity -0.3 -0.2
sulphates 0.2 0.2
chlorides -0.1 -0.1
total.sulfur.dioxide -0.1 -0.1
sigma 0.8 0.8
The marginals of the projected posterior look like this:
mcmc_areas(as.matrix(projg), pars = vsel)
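As a sketch of how the projected submodel could be used for new predictions with only the selected variables, we can draw from its predictive distribution; proj_predict is projpred’s prediction function, and here we pass the first rows of the training data purely as example input:
subpred <- proj_predict(projg, newdata = head(df))
dim(subpred) # posterior predictive draws x observations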
We also test the regularized horseshoe prior (Piironen and Vehtari, 2017b), which has more prior mass near 0:
p0 <- 5 # prior guess for the number of relevant variables
# global scale following Piironen and Vehtari (2017b):
# tau0 = p0/(p-p0) * sigma/sqrt(n), with sigma ~ 1 for the scaled data
tau0 <- p0/(p-p0) * 1/sqrt(n)
hs_prior <- hs(df=1, global_df=1, global_scale=tau0)
fitrhs <- stan_glm(formula, data = df, prior=hs_prior,
seed=SEED, refresh=0)
mcmc_areas(as.matrix(fitrhs), pars=prednames, prob_outer = .95)
Many of the coefficients are shrunk more towards 0, but based on these marginals alone it is still not as easy to select the most useful variables as it is with projpred.
The posteriors with the normal and regularized horseshoe priors are clearly different, but does this have an effect on the predictions? In the case of collinearity the prior may have a strong effect on the posterior but a weak effect on the posterior predictions. We can use loo to compare:
(loorhs <- loo(fitrhs))
Computed from 4000 by 1359 log-likelihood matrix
Estimate SE
elpd_loo -1634.4 30.5
p_loo 14.1 1.3
looic 3268.7 61.0
------
Monte Carlo SE of elpd_loo is 0.1.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
loo_compare(loog, loorhs)
elpd_diff se_diff
fitrhs 0.0 0.0
fitg -1.3 1.4
There is no practical difference in predictive performance, and thus we don’t need to repeat the projpred variable selection for the model with the regularized horseshoe prior.
sessionInfo()
R version 4.1.2 (2021-11-01)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.3 LTS
Matrix products: default
BLAS/LAPACK: /opt/OpenBLAS/lib/libopenblas_haswellp-r0.3.17.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C LC_TIME=fi_FI.UTF-8
[4] LC_COLLATE=en_US.UTF-8 LC_MONETARY=fi_FI.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=fi_FI.UTF-8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=fi_FI.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] splines stats graphics grDevices utils datasets methods base
other attached packages:
[1] rstan_2.21.3 StanHeaders_2.21.0-7 arm_1.12-2 lme4_1.1-28
[5] Matrix_1.4-0 reliabilitydiag_0.2.0 MASS_7.3-55 corrplot_0.92
[9] caret_6.0-90 lattice_0.20-45 GGally_2.1.2 fivethirtyeight_0.6.2
[13] projpred_2.0.2 bayesplot_1.8.1 forcats_0.5.1 stringr_1.4.0
[17] dplyr_1.0.8 purrr_0.3.4 readr_2.1.2 tidyr_1.2.0
[21] tibble_3.1.6 ggplot2_3.3.5 tidyverse_1.3.1 loo_2.4.1
[25] rstanarm_2.21.1 Rcpp_1.0.8 rmarkdown_2.11 knitr_1.37
loaded via a namespace (and not attached):
[1] readxl_1.3.1 backports_1.4.1 plyr_1.8.6 igraph_1.2.11
[5] listenv_0.8.0 crosstalk_1.2.0 usethis_2.1.5 rstantools_2.1.1
[9] inline_0.3.19 digest_0.6.29 foreach_1.5.2 htmltools_0.5.2
[13] rsconnect_0.8.25 fansi_1.0.2 magrittr_2.0.2 checkmate_2.0.0
[17] memoise_2.0.1 tzdb_0.2.0 remotes_2.4.2 globals_0.14.0
[21] recipes_0.1.17 gower_1.0.0 modelr_0.1.8 RcppParallel_5.1.5
[25] matrixStats_0.61.0 xts_0.12.1 prettyunits_1.1.1 colorspace_2.0-2
[29] rvest_1.0.2 haven_2.4.3 xfun_0.29 callr_3.7.0
[33] crayon_1.4.2 jsonlite_1.7.3 iterators_1.0.14 survival_3.2-13
[37] zoo_1.8-9 glue_1.6.1 gtable_0.3.0 ipred_0.9-12
[41] distributional_0.3.0 pkgbuild_1.3.1 future.apply_1.8.1 abind_1.4-5
[45] scales_1.1.1 DBI_1.1.2 miniUI_0.1.1.1 progress_1.2.2
[49] xtable_1.8-4 lava_1.6.10 prodlim_2019.11.13 stats4_4.1.2
[53] DT_0.20 htmlwidgets_1.5.4 httr_1.4.2 threejs_0.3.3
[57] RColorBrewer_1.1-2 posterior_1.2.0 ellipsis_0.3.2 reshape_0.8.8
[61] pkgconfig_2.0.3 farver_2.1.0 nnet_7.3-17 sass_0.4.0
[65] dbplyr_2.1.1 utf8_1.2.2 labeling_0.4.2 tidyselect_1.1.1
[69] rlang_1.0.1 reshape2_1.4.4 later_1.3.0 munsell_0.5.0
[73] cellranger_1.1.0 tools_4.1.2 cachem_1.0.6 cli_3.1.1
[77] generics_0.1.2 devtools_2.4.3 broom_0.7.12 ggridges_0.5.3
[81] evaluate_0.14 fastmap_1.1.0 yaml_2.2.2 ModelMetrics_1.2.2.2
[85] processx_3.5.2 fs_1.5.2 future_1.23.0 nlme_3.1-155
[89] mime_0.12 xml2_1.3.3 brio_1.1.3 compiler_4.1.2
[93] shinythemes_1.2.0 rstudioapi_0.13 gamm4_0.2-6 testthat_3.1.2
[97] reprex_2.0.1 bslib_0.3.1 stringi_1.7.6 highr_0.9
[101] ps_1.6.0 desc_1.4.0 nloptr_1.2.2.3 markdown_1.1
[105] shinyjs_2.1.0 tensorA_0.36.2 vctrs_0.3.8 pillar_1.7.0
[109] lifecycle_1.0.1 jquerylib_0.1.4 data.table_1.14.2 httpuv_1.6.5
[113] R6_2.5.1 promises_1.2.0.1 gridExtra_2.3 parallelly_1.30.0
[117] sessioninfo_1.2.2 codetools_0.2-18 boot_1.3-28 colourpicker_1.1.1
[121] gtools_3.9.2 assertthat_0.2.1 pkgload_1.2.4 rprojroot_2.0.2
[125] withr_2.4.3 shinystan_2.5.0 mgcv_1.8-38 parallel_4.1.2
[129] hms_1.1.1 rpart_4.1.16 timeDate_3043.102 grid_4.1.2
[133] coda_0.19-4 class_7.3-20 minqa_1.2.4 pROC_1.18.0
[137] shiny_1.7.1 lubridate_1.8.0 base64enc_0.1-3 dygraphs_1.1.1.6