Setup

Load packages

library(rstan)
library(rstanarm)
library(loo)
library(bayesplot)
theme_set(bayesplot::theme_default())
library(projpred)
SEED=170701694

1 Introduction

This notebook was inspired by Eric Novik’s slides “Deconstructing Stan Manual Part 1: Linear”. The idea is to demonstrate how easy it is to do good variable selection with rstanarm, loo, and projpred.

In this notebook we illustrate Bayesian inference for model selection, including PSIS-LOO (Vehtari, Gelman and Gabry, 2017) and the projection predictive approach (Piironen and Vehtari, 2017a; Piironen, Paasiniemi and Vehtari, 2020), which makes decision-theoretically justified inference after model selection.

2 Wine quality data

We use the Wine Quality data set from the UCI Machine Learning Repository.

d <- read.delim("winequality-red.csv", sep = ";")
dim(d)
[1] 1599   12

Remove duplicated observations

d <- d[!duplicated(d), ] # remove the duplicates
(p <- ncol(d))
[1] 12
(n <- nrow(d))
[1] 1359
names(d)
 [1] "fixed.acidity"        "volatile.acidity"     "citric.acid"          "residual.sugar"      
 [5] "chlorides"            "free.sulfur.dioxide"  "total.sulfur.dioxide" "density"             
 [9] "pH"                   "sulphates"            "alcohol"              "quality"             
prednames <- names(d)[1:(p-1)]
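
As a small illustration of what the deduplication step does, the sketch below applies the same !duplicated() idiom to a made-up three-row data frame (the toy object and its values are hypothetical, not part of the wine data). It also shows setdiff() as one possible way to pick the predictor names by excluding the target column by name rather than by position.

```r
# Hypothetical toy data: rows 1 and 3 are identical
toy <- data.frame(alcohol = c(9.4, 9.8, 9.4),
                  pH      = c(3.5, 3.2, 3.5),
                  quality = c(5, 5, 5))

# duplicated() flags later copies of a row, so the first occurrence is kept
toy_unique <- toy[!duplicated(toy), ]
n_removed <- nrow(toy) - nrow(toy_unique)

# Select predictors by name instead of assuming the target is the last column
toy_prednames <- setdiff(names(toy), "quality")
```

For the wine data, the same step reduces the 1599 rows read from the file to the 1359 rows reported above.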

We scale the covariates so that the marginal posteriors of the effects are on the same scale. Note that scale(d) standardizes every column, so the target quality is standardized as well.

ds <- scale(d)
df <- as.data.frame(ds)
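
The effect of the scaling step can be checked on a small example. The sketch below uses a hypothetical toy data frame (made-up values) and verifies that, after scale(), every column has mean zero and standard deviation one.

```r
# Hypothetical toy data standing in for the wine data frame
toy <- data.frame(alcohol = c(9.4, 9.8, 10.5, 11.2),
                  quality = c(5, 5, 6, 7))

# scale() centers every column and divides by its standard deviation;
# it returns a matrix, hence the conversion back to a data frame
toy_s <- as.data.frame(scale(toy))

col_means <- colMeans(toy_s)    # all numerically zero
col_sds   <- sapply(toy_s, sd)  # all one
```

Because the target is standardized too, regression coefficients and sigma in the fitted model are expressed in standard deviations of quality, and the intercept is expected to be close to zero.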

3 Fit regression model

The rstanarm package provides stan_glm, which accepts the same arguments as glm but performs full Bayesian inference using Stan (mc-stan.org). By default, a weakly informative Gaussian prior is used for the regression coefficients.

formula <- formula(paste("quality ~", paste(prednames, collapse = " + ")))
fitg <- stan_glm(formula, data = df, QR=TRUE, seed=SEED, refresh=0)
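
The formula above is assembled by pasting the predictor names together. As a sketch of an alternative, base R's reformulate() builds the same kind of formula directly from a character vector. The predictor names below are a hypothetical subset chosen only for illustration.

```r
# Hypothetical predictor names; in the notebook these come from names(d)
toy_prednames <- c("alcohol", "volatile.acidity", "sulphates")

# Approach used in the text: paste the right-hand side together
f_paste <- formula(paste("quality ~", paste(toy_prednames, collapse = " + ")))

# Base R alternative: reformulate() builds the formula from a character vector
f_reform <- reformulate(toy_prednames, response = "quality")

# Both formulas refer to the same variables in the same order
same_terms <- identical(all.vars(f_paste), all.vars(f_reform))
```

Either object can be passed as the first argument of a model-fitting function that accepts a formula.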

Let’s look at the summary:

summary(fitg)

Model Info:
 function:     stan_glm
 family:       gaussian [identity]
 formula:      quality ~ fixed.acidity + volatile.acidity + citric.acid + residual.sugar + 
       chlorides + free.sulfur.dioxide + total.sulfur.dioxide + 
       density + pH + sulphates + alcohol
 algorithm:    sampling
 sample:       4000 (posterior sample size)
 priors:       see help('prior_summary')
 observations: 1359
 predictors:   12

Estimates:
                       mean   sd   10%   50%   90%
(Intercept)           0.0    0.0  0.0   0.0   0.0 
fixed.acidity         0.0    0.1 -0.1   0.0   0.1 
volatile.acidity     -0.2    0.0 -0.3  -0.2  -0.2 
citric.acid           0.0    0.0 -0.1   0.0   0.0 
residual.sugar        0.0    0.0  0.0   0.0   0.0 
chlorides            -0.1    0.0 -0.1  -0.1  -0.1 
free.sulfur.dioxide   0.0    0.0  0.0   0.0   0.1 
total.sulfur.dioxide -0.1    0.0 -0.2  -0.1  -0.1 
density               0.0    0.1 -0.1   0.0   0.0 
pH                   -0.1    0.0 -0.1  -0.1   0.0 
sulphates             0.2    0.0  0.2   0.2   0.2 
alcohol               0.4    0.0  0.3   0.4   0.4 
sigma                 0.8    0.0  0.8   0.8   0.8 

Fit Diagnostics:
           mean   sd   10%   50%   90%
mean_PPD 0.0    0.0  0.0   0.0   0.0  

The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).

MCMC diagnostics
                     mcse Rhat n_eff
(Intercept)          0.0  1.0  5009 
fixed.acidity        0.0  1.0  5523 
volatile.acidity     0.0  1.0  5615 
citric.acid          0.0  1.0  5884 
residual.sugar       0.0  1.0  4943 
chlorides            0.0  1.0  5330 
free.sulfur.dioxide  0.0  1.0  5237 
total.sulfur.dioxide 0.0  1.0  5496 
density              0.0  1.0  4811 
pH                   0.0  1.0  5149 
sulphates            0.0  1.0  5378 
alcohol              0.0  1.0  5213 
sigma                0.0  1.0  4864 
mean_PPD             0.0  1.0  4282 
log-posterior        0.1  1.0  1578 

For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).

We didn’t get divergences, all Rhat values are less than 1.1, and the n_eff values are large enough to be useful (see, e.g., RStan workflow).
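
To make the relation between these columns concrete: for a posterior mean, the Monte Carlo standard error is approximately the posterior standard deviation divided by the square root of the effective sample size. The sketch below applies this approximation to the rounded values printed above for fixed.acidity, so the result is only indicative. The helper name approx_mcse is made up for this illustration and is not part of rstanarm.

```r
# Hypothetical helper: approximate MCSE of a posterior mean
approx_mcse <- function(post_sd, n_eff) {
  post_sd / sqrt(n_eff)
}

# Rounded values for fixed.acidity taken from the summary output above
mcse_fixed_acidity <- approx_mcse(post_sd = 0.1, n_eff = 5523)

# Rounded to one decimal, as in the printed table
mcse_rounded <- round(mcse_fixed_acidity, 1)
```

The approximate value is far below the one-decimal precision of the table, which is consistent with the mcse column showing 0.0 for every coefficient.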

mcmc_areas(as.matrix(fitg), pars=prednames, prob_outer = .95)

Several 95% posterior intervals do not overlap 0, so maybe there is something useful here.

4 Cross-validation checking

In case of collinear variables it is possible that marginal posteriors overlap 0, but the covariates can still be useful for prediction. With many variables it is difficult to analyse the joint posterior to see which variables are jointly relevant. We can easily test whether any of the covariates are useful by using cross-validation to compare to a null model:

fitg0 <- stan_glm(quality ~ 1, data = df, seed=SEED, refresh=0)

We use fast Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO; Vehtari, Gelman and Gabry, 2017):

(loog <- loo(fitg))

Computed from 4000 by 1359 log-likelihood matrix

         Estimate   SE
elpd_loo  -1635.7 30.6
p_loo        16.6  1.6
looic      3271.4 61.2
------
Monte Carlo SE of elpd_loo is 0.1.

Pareto k diagnostic values:
                         Count Pct.    Min. n_eff
(-Inf, 0.5]   (good)     1358  99.9%   1600      
 (0.5, 0.7]   (ok)          1   0.1%   1270      
   (0.7, 1]   (bad)         0   0.0%   <NA>      
   (1, Inf)   (very bad)    0   0.0%   <NA>      

All Pareto k estimates are ok (k < 0.7).
See help('pareto-k-diagnostic') for details.
(loog0 <- loo(fitg0))

Computed from 4000 by 1359 log-likelihood matrix

         Estimate   SE
elpd_loo  -1929.9 28.3
p_loo         2.2  0.2
looic      3859.9 56.5
------
Monte Carlo SE of elpd_loo is 0.0.

All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
loo_compare(loog0, loog)
      elpd_diff se_diff
fitg     0.0       0.0 
fitg0 -294.3      22.9 

Based on cross-validation, the covariates together have high predictive power. If we need just the predictions we can stop here, but if we want to learn more about the relevance of the covariates we can continue with variable selection.
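
The numbers in the loo output are related in a simple way, which the following sketch works through using values copied from the printed output above. It shows that looic is the elpd on the deviance scale and expresses the difference to the null model in units of its standard error. All variable names are introduced here for illustration only.

```r
# Values copied from the loo output above
elpd_full <- -1635.7
se_full   <- 30.6
elpd_null <- -1929.9

# looic is the elpd on the deviance scale: looic = -2 * elpd_loo
looic_full    <- -2 * elpd_full
looic_se_full <- 2 * se_full

# Difference of the rounded elpd values (null minus full)
elpd_diff_manual <- elpd_null - elpd_full

# Values reported by loo_compare()
elpd_diff <- -294.3
se_diff   <- 22.9

# Size of the difference in units of its standard error
diff_in_se <- abs(elpd_diff) / se_diff
```

The manually computed difference agrees with the loo_compare value up to rounding of the printed estimates. Note that se_diff is computed by loo_compare from the pointwise differences between the models, so it cannot be recovered from the two separate SE values.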

5 Projection predictive variable selection

We perform projection predictive variable selection (Piironen and Vehtari, 2017a; Piironen, Paasiniemi and Vehtari, 2020) using the projpred package. Fast PSIS-LOO (Vehtari, Gelman and Gabry, 2017) is used to choose the model size.

fitg_cv <- cv_varsel(fitg, method='forward', cv_method='loo', n_loo=nrow(df))

We can now look at the estimated predictive performance of smaller models compared to the full model.

plot(fitg_cv, stats = c('elpd', 'rmse'))

Three or four variables seem to be needed to get the same performance as the full model. We can get a LOO-CV based recommendation for the model size to choose.

(nsel <- suggest_size(fitg_cv, alpha=0.1))
[1] 4
(vsel <- solution_terms(fitg_cv)[1:nsel])
[1] "alcohol"          "volatile.acidity" "sulphates"        "chlorides"       

projpred recommends using four variables: alcohol, volatile.acidity, sulphates, and chlorides.

Next we form the projected posterior for the chosen model. This projected model can be used in the future to make predictions by using only the selected variables.

projg <- project(fitg_cv, nv = nsel, ns = 4000)
round(colMeans(as.matrix(projg)), 1)
           Intercept              alcohol     volatile.acidity            sulphates 
                 0.0                  0.4                 -0.3                  0.2 
           chlorides total.sulfur.dioxide                sigma 
                -0.1                 -0.1                  0.8 
round(posterior_interval(as.matrix(projg)), 1)
                       5%  95%
Intercept             0.0  0.0
alcohol               0.3  0.4
volatile.acidity     -0.3 -0.2
sulphates             0.2  0.2
chlorides            -0.1 -0.1
total.sulfur.dioxide -0.1 -0.1
sigma                 0.8  0.8

The marginals of projected posteriors look like this.

mcmc_areas(as.matrix(projg), pars = vsel)

6 Alternative regularized horseshoe prior

We also test the regularized horseshoe prior (Piironen and Vehtari, 2017b), which has more prior mass near 0.

p0 <- 5 # prior guess for the number of relevant variables
tau0 <- p0/(p-p0) * 1/sqrt(n)
hs_prior <- hs(df=1, global_df=1, global_scale=tau0)
fitrhs <- stan_glm(formula, data = df, prior=hs_prior,
                   seed=SEED, refresh=0)
mcmc_areas(as.matrix(fitrhs), pars=prednames, prob_outer = .95)

Many of the variables are shrunk more towards 0, but still based on these marginals it is not as easy to select the most useful variables as it is with projpred.
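
The global scale used for the horseshoe prior follows the recommendation of Piironen and Vehtari (2017b): tau0 = p0 / (D - p0) * sigma / sqrt(n), where D is the number of candidate predictors, p0 is the prior guess for the number of relevant ones, n is the number of observations, and sigma is the noise scale (written as 1 in the code above). The sketch below wraps this in a helper whose name, hs_global_scale, is hypothetical. Since p in the code above is ncol(d) = 12, which counts the target column, the sketch evaluates the formula both with that value and with the 11 predictors so that the two can be compared.

```r
# Hypothetical helper implementing tau0 = p0 / (D - p0) * sigma / sqrt(n)
hs_global_scale <- function(p0, D, n, sigma = 1) {
  stopifnot(p0 > 0, D > p0, n > 0)
  p0 / (D - p0) * sigma / sqrt(n)
}

n_obs <- 1359  # rows after removing duplicates

# As written in the notebook: p = ncol(d) = 12 (includes the target column)
tau0_as_written <- hs_global_scale(p0 = 5, D = 12, n = n_obs)

# Counting only the 11 predictors
tau0_predictors <- hs_global_scale(p0 = 5, D = 11, n = n_obs)
```

Both values are of similar magnitude (roughly 0.02), so the choice is not expected to matter much here, but counting only the predictors matches the definition of D in the formula.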

The posteriors with normal and regularized horseshoe priors are clearly different, but does this have an effect on the predictions? In case of collinearity the prior may have a strong effect on the posterior, but a weak effect on the posterior predictions. We can use loo to compare:

(loorhs <- loo(fitrhs))

Computed from 4000 by 1359 log-likelihood matrix

         Estimate   SE
elpd_loo  -1634.4 30.5
p_loo        14.1  1.3
looic      3268.7 61.0
------
Monte Carlo SE of elpd_loo is 0.1.

All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
loo_compare(loog, loorhs)
       elpd_diff se_diff
fitrhs  0.0       0.0   
fitg   -1.3       1.4   

The difference in predictive performance is negligible (elpd_diff of -1.3 with se_diff of 1.4), and thus we don’t need to repeat the projpred variable selection for the model with the regularized horseshoe prior.


References

Piironen, J., Paasiniemi, M. and Vehtari, A. (2020) ‘Projective inference in high-dimensional problems: Prediction and feature selection’, Electronic Journal of Statistics, 14(1), pp. 2155–2197.
Piironen, J. and Vehtari, A. (2017a) ‘Comparison of Bayesian predictive methods for model selection’, Statistics and Computing, 27(3), pp. 711–735. doi: 10.1007/s11222-016-9649-y.
Piironen, J. and Vehtari, A. (2017b) ‘Sparsity information and regularization in the horseshoe and other shrinkage priors’, Electronic Journal of Statistics, 11(2), pp. 5018–5051. doi: 10.1214/17-EJS1337SI.
Vehtari, A., Gelman, A. and Gabry, J. (2017) ‘Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC’, Statistics and Computing, 27(5), pp. 1413–1432. doi: 10.1007/s11222-016-9696-4.

Licenses

  • Code © 2017-2018, Aki Vehtari, licensed under BSD-3.
  • Text © 2017-2018, Aki Vehtari, licensed under CC-BY-NC 4.0.

Original Computing Environment

sessionInfo()
R version 4.1.2 (2021-11-01)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.3 LTS

Matrix products: default
BLAS/LAPACK: /opt/OpenBLAS/lib/libopenblas_haswellp-r0.3.17.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=fi_FI.UTF-8       
 [4] LC_COLLATE=en_US.UTF-8     LC_MONETARY=fi_FI.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=fi_FI.UTF-8       LC_NAME=C                  LC_ADDRESS=C              
[10] LC_TELEPHONE=C             LC_MEASUREMENT=fi_FI.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] splines   stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] rstan_2.21.3          StanHeaders_2.21.0-7  arm_1.12-2            lme4_1.1-28          
 [5] Matrix_1.4-0          reliabilitydiag_0.2.0 MASS_7.3-55           corrplot_0.92        
 [9] caret_6.0-90          lattice_0.20-45       GGally_2.1.2          fivethirtyeight_0.6.2
[13] projpred_2.0.2        bayesplot_1.8.1       forcats_0.5.1         stringr_1.4.0        
[17] dplyr_1.0.8           purrr_0.3.4           readr_2.1.2           tidyr_1.2.0          
[21] tibble_3.1.6          ggplot2_3.3.5         tidyverse_1.3.1       loo_2.4.1            
[25] rstanarm_2.21.1       Rcpp_1.0.8            rmarkdown_2.11        knitr_1.37           

loaded via a namespace (and not attached):
  [1] readxl_1.3.1         backports_1.4.1      plyr_1.8.6           igraph_1.2.11       
  [5] listenv_0.8.0        crosstalk_1.2.0      usethis_2.1.5        rstantools_2.1.1    
  [9] inline_0.3.19        digest_0.6.29        foreach_1.5.2        htmltools_0.5.2     
 [13] rsconnect_0.8.25     fansi_1.0.2          magrittr_2.0.2       checkmate_2.0.0     
 [17] memoise_2.0.1        tzdb_0.2.0           remotes_2.4.2        globals_0.14.0      
 [21] recipes_0.1.17       gower_1.0.0          modelr_0.1.8         RcppParallel_5.1.5  
 [25] matrixStats_0.61.0   xts_0.12.1           prettyunits_1.1.1    colorspace_2.0-2    
 [29] rvest_1.0.2          haven_2.4.3          xfun_0.29            callr_3.7.0         
 [33] crayon_1.4.2         jsonlite_1.7.3       iterators_1.0.14     survival_3.2-13     
 [37] zoo_1.8-9            glue_1.6.1           gtable_0.3.0         ipred_0.9-12        
 [41] distributional_0.3.0 pkgbuild_1.3.1       future.apply_1.8.1   abind_1.4-5         
 [45] scales_1.1.1         DBI_1.1.2            miniUI_0.1.1.1       progress_1.2.2      
 [49] xtable_1.8-4         lava_1.6.10          prodlim_2019.11.13   stats4_4.1.2        
 [53] DT_0.20              htmlwidgets_1.5.4    httr_1.4.2           threejs_0.3.3       
 [57] RColorBrewer_1.1-2   posterior_1.2.0      ellipsis_0.3.2       reshape_0.8.8       
 [61] pkgconfig_2.0.3      farver_2.1.0         nnet_7.3-17          sass_0.4.0          
 [65] dbplyr_2.1.1         utf8_1.2.2           labeling_0.4.2       tidyselect_1.1.1    
 [69] rlang_1.0.1          reshape2_1.4.4       later_1.3.0          munsell_0.5.0       
 [73] cellranger_1.1.0     tools_4.1.2          cachem_1.0.6         cli_3.1.1           
 [77] generics_0.1.2       devtools_2.4.3       broom_0.7.12         ggridges_0.5.3      
 [81] evaluate_0.14        fastmap_1.1.0        yaml_2.2.2           ModelMetrics_1.2.2.2
 [85] processx_3.5.2       fs_1.5.2             future_1.23.0        nlme_3.1-155        
 [89] mime_0.12            xml2_1.3.3           brio_1.1.3           compiler_4.1.2      
 [93] shinythemes_1.2.0    rstudioapi_0.13      gamm4_0.2-6          testthat_3.1.2      
 [97] reprex_2.0.1         bslib_0.3.1          stringi_1.7.6        highr_0.9           
[101] ps_1.6.0             desc_1.4.0           nloptr_1.2.2.3       markdown_1.1        
[105] shinyjs_2.1.0        tensorA_0.36.2       vctrs_0.3.8          pillar_1.7.0        
[109] lifecycle_1.0.1      jquerylib_0.1.4      data.table_1.14.2    httpuv_1.6.5        
[113] R6_2.5.1             promises_1.2.0.1     gridExtra_2.3        parallelly_1.30.0   
[117] sessioninfo_1.2.2    codetools_0.2-18     boot_1.3-28          colourpicker_1.1.1  
[121] gtools_3.9.2         assertthat_0.2.1     pkgload_1.2.4        rprojroot_2.0.2     
[125] withr_2.4.3          shinystan_2.5.0      mgcv_1.8-38          parallel_4.1.2      
[129] hms_1.1.1            rpart_4.1.16         timeDate_3043.102    grid_4.1.2          
[133] coda_0.19-4          class_7.3-20         minqa_1.2.4          pROC_1.18.0         
[137] shiny_1.7.1          lubridate_1.8.0      base64enc_0.1-3      dygraphs_1.1.1.6