Setup

Load packages - Requires projpred 2.0.0+

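library(rstanarm)
library(arm)          # display() for compact lm summaries
options(mc.cores = parallel::detectCores())
library(loo)
library(ggplot2)
library(bayesplot)
theme_set(bayesplot::theme_default(base_family = "sans"))
library(GGally)       # ggpairs()
library(projpred)
library(dplyr)        # %>% and mutate_if() used when reading the data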

1 Introduction

This notebook demonstrates collinearity in multipredictor regression. The example of predicting the yields of mesquite bushes comes from Gelman and Hill (2007). The outcome variable is the total weight (in grams) of photosynthetic material, as derived from actual harvesting of the bush. The predictor variables are:

  • Diam1: diameter of the canopy (the leafy area of the bush) in meters, measured along the longer axis of the bush
  • Diam2: canopy diameter measured along the shorter axis
  • CanHt: height of the canopy
  • TotHt: total height of the bush
  • Dens: plant unit density (# of primary stems per plant unit)
  • Group: group of measurements (two groups, coded 1 and 2 after the numeric conversion in the Data section)

2 Data

# Read the data; convert character columns (e.g. Group) to factors and then to numeric codes
dat <- read.table("mesquite.dat", header=TRUE) %>% 
  mutate_if(is.character, as.factor) %>%
  mutate_if(is.factor, as.numeric)
summary(dat)
      Obs            Group           Diam1           Diam2           TotHt      
 Min.   : 1.00   Min.   :1.000   Min.   :0.800   Min.   :0.400   Min.   :0.650  
 1st Qu.:12.25   1st Qu.:1.000   1st Qu.:1.400   1st Qu.:1.000   1st Qu.:1.200  
 Median :23.50   Median :2.000   Median :1.950   Median :1.525   Median :1.500  
 Mean   :23.50   Mean   :1.565   Mean   :2.099   Mean   :1.572   Mean   :1.482  
 3rd Qu.:34.75   3rd Qu.:2.000   3rd Qu.:2.475   3rd Qu.:1.900   3rd Qu.:1.700  
 Max.   :46.00   Max.   :2.000   Max.   :5.200   Max.   :4.000   Max.   :3.000  
     CanHt             Dens           LeafWt      
 Min.   :0.5000   Min.   :1.000   Min.   :  60.2  
 1st Qu.:0.8625   1st Qu.:1.000   1st Qu.: 219.6  
 Median :1.1000   Median :1.000   Median : 361.9  
 Mean   :1.1107   Mean   :1.674   Mean   : 559.7  
 3rd Qu.:1.3000   3rd Qu.:2.000   3rd Qu.: 688.7  
 Max.   :2.5000   Max.   :9.000   Max.   :4052.0  

Plot data

ggpairs(dat, diag=list(continuous="barDiag"))

Additional transformed variables

dat$CanVol <- dat$Diam1 * dat$Diam2 * dat$CanHt  # canopy volume
dat$CanAre <- dat$Diam1 * dat$Diam2              # canopy area
dat$CanSha <- dat$Diam1 / dat$Diam2              # canopy shape (ratio of the diameters)
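Taking logarithms turns these products and ratios into sums and differences, so the log-scale predictors used below are exact linear combinations of log(Diam1), log(Diam2) and log(CanHt); in particular log(CanVol) = log(CanAre) + log(CanHt). This is one source of the collinearity studied in this notebook. A minimal check on simulated dimensions (the ranges loosely follow the summary above; the values themselves are hypothetical, not the mesquite data):

```r
# Hypothetical bush dimensions (meters); not taken from mesquite.dat
set.seed(1)
n <- 46
Diam1 <- runif(n, 0.8, 5.2)
Diam2 <- Diam1 * runif(n, 0.5, 1)   # shorter axis, at most the longer one
CanHt <- runif(n, 0.5, 2.5)

CanVol <- Diam1 * Diam2 * CanHt
CanAre <- Diam1 * Diam2
CanSha <- Diam1 / Diam2

# On the log scale the derived variables are sums/differences of log dimensions
all.equal(log(CanVol), log(CanAre) + log(CanHt))
all.equal(log(CanAre), log(Diam1) + log(Diam2))
all.equal(log(CanSha), log(Diam1) - log(Diam2))

# log(CanVol) and log(CanAre) share the log(Diam1) + log(Diam2) component,
# so they tend to be strongly correlated
cor(log(CanVol), log(CanAre))
```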

It may be reasonable to fit on the logarithmic scale, so that effects are multiplicative rather than additive (we’ll return to checking this assumption in another notebook).
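In a model for log(y), a coefficient b on log(x) means that multiplying x by a factor c multiplies the predicted (median) y by c^b, whatever the starting value of x. A small sketch on simulated data (the power 0.8 and the noise level are hypothetical choices):

```r
# Simulated data (hypothetical): y = 3 * x^0.8, times multiplicative log-normal noise
set.seed(2)
x <- runif(200, 1, 10)
y <- 3 * x^0.8 * exp(rnorm(200, sd = 0.1))

fit <- lm(log(y) ~ log(x))
b <- unname(coef(fit)["log(x)"])

# exp() of a log-scale prediction estimates the median of y (not its mean)
pred_median <- function(x0) unname(exp(predict(fit, newdata = data.frame(x = x0))))

# Doubling x multiplies the predicted median by 2^b, whatever the starting value
ratio_low  <- pred_median(2) / pred_median(1)
ratio_high <- pred_median(10) / pred_median(5)
round(c(b = b, ratio_low = ratio_low, ratio_high = ratio_high, two_to_b = 2^b), 3)
```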

3 Maximum likelihood estimate

We first illustrate the problem with the maximum likelihood estimate.

lm1 <- lm(formula = log(LeafWt) ~ log(CanVol) + log(CanAre) + log(CanSha) + log(TotHt) + log(Dens) + Group, data = dat)
display(lm1)
lm(formula = log(LeafWt) ~ log(CanVol) + log(CanAre) + log(CanSha) + 
    log(TotHt) + log(Dens) + Group, data = dat)
            coef.est coef.se
(Intercept)  4.18     0.23  
log(CanVol)  0.37     0.28  
log(CanAre)  0.40     0.29  
log(CanSha) -0.38     0.23  
log(TotHt)   0.39     0.31  
log(Dens)    0.11     0.12  
Group        0.58     0.13  
---
n = 46, k = 7
residual sd = 0.33, R-Squared = 0.89

Group seems to be the only variable whose coefficient is far from zero relative to its standard error. Let's try a model with just the group variable.

lm2 <- lm(formula = log(LeafWt) ~ Group, data = dat)
display(lm2)
lm(formula = log(LeafWt) ~ Group, data = dat)
            coef.est coef.se
(Intercept)  6.49     0.44  
Group       -0.36     0.27  
---
n = 46, k = 2
residual sd = 0.91, R-Squared = 0.04

Hmmm… R-squared dropped a lot (from 0.89 to 0.04), so it seems that the other variables are useful even though their estimated effects and standard errors suggest that they are not relevant. There are approaches for investigating this with maximum likelihood estimated models, but we'll now switch to Bayesian inference using rstanarm.
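One such maximum likelihood approach (our choice of illustration, not one used in this notebook) is a nested-model F test with anova(), which asks whether a group of predictors is jointly useful even when each coefficient alone looks uncertain. A sketch with two simulated, nearly collinear predictors (hypothetical data):

```r
set.seed(3)
n <- 46
z <- rnorm(n)
x1 <- z + rnorm(n, sd = 0.05)   # x1 and x2 are nearly collinear
x2 <- z + rnorm(n, sd = 0.05)
y <- 1 + x1 + x2 + rnorm(n, sd = 0.5)

full <- lm(y ~ x1 + x2)
null <- lm(y ~ 1)

summary(full)$coefficients   # individual standard errors are large
summary(full)$r.squared      # but the joint fit is good
anova(null, full)            # F test: are x1 and x2 jointly useful?
```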

4 Bayesian inference

We fit the corresponding model with stan_glm from rstanarm.

fitg <- stan_glm(formula = log(LeafWt) ~ log(CanVol) + log(CanAre) + log(CanSha) + log(TotHt) + log(Dens) + Group, data = dat, refresh=0)

Print summary for some diagnostics.

summary(fitg)

Model Info:
 function:     stan_glm
 family:       gaussian [identity]
 formula:      log(LeafWt) ~ log(CanVol) + log(CanAre) + log(CanSha) + log(TotHt) + 
       log(Dens) + Group
 algorithm:    sampling
 sample:       4000 (posterior sample size)
 priors:       see help('prior_summary')
 observations: 46
 predictors:   7

Estimates:
              mean   sd   10%   50%   90%
(Intercept)  4.2    0.2  3.9   4.2   4.5 
log(CanVol)  0.4    0.3  0.0   0.4   0.7 
log(CanAre)  0.4    0.3  0.0   0.4   0.8 
log(CanSha) -0.4    0.2 -0.7  -0.4  -0.1 
log(TotHt)   0.4    0.3  0.0   0.4   0.8 
log(Dens)    0.1    0.1 -0.1   0.1   0.3 
Group        0.6    0.1  0.4   0.6   0.8 
sigma        0.3    0.0  0.3   0.3   0.4 

Fit Diagnostics:
           mean   sd   10%   50%   90%
mean_PPD 5.9    0.1  5.8   5.9   6.0  

The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).

MCMC diagnostics
              mcse Rhat n_eff
(Intercept)   0.0  1.0  3727 
log(CanVol)   0.0  1.0  1364 
log(CanAre)   0.0  1.0  1458 
log(CanSha)   0.0  1.0  2571 
log(TotHt)    0.0  1.0  1647 
log(Dens)     0.0  1.0  2702 
Group         0.0  1.0  2077 
sigma         0.0  1.0  2738 
mean_PPD      0.0  1.0  3458 
log-posterior 0.1  1.0  1517 

For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).

Rhat and n_eff values are good (see, e.g., RStan workflow), but the QR transformation usually makes sampling work even better (see The QR Decomposition For Regression Models).
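The idea behind the QR reparametrization is to write the centred predictor matrix as X = Q* R*, where Q* has orthogonal columns; the sampler then works with coefficients theta = R* beta on uncorrelated predictors, and beta is recovered as R*^{-1} theta. A base-R sketch on simulated, highly correlated predictors; the sqrt(n - 1) scaling follows the QR case study and is an assumption about rstanarm's internals:

```r
# Simulated, highly correlated predictors (hypothetical data)
set.seed(9)
n <- 46
x1 <- rnorm(n)
x2 <- x1 + rnorm(n, sd = 0.1)
y <- 1 + x1 + x2 + rnorm(n, sd = 0.5)
X <- scale(cbind(x1, x2), scale = FALSE)   # centre the columns

# Thin QR decomposition, scaled so that crossprod(Q) / (n - 1) is the identity
# (sqrt(n - 1) scaling as in the QR case study; an assumption about rstanarm)
qrX <- qr(X)
Q <- qr.Q(qrX) * sqrt(n - 1)
R <- qr.R(qrX) / sqrt(n - 1)

cor(X)[1, 2]                        # original predictors: correlation close to 1
round(crossprod(Q) / (n - 1), 10)   # new predictors: orthogonal

# Fit on the orthogonal predictors, then map back with beta = R^{-1} theta
theta <- coef(lm(y ~ Q))[-1]
beta <- backsolve(R, theta)
cbind(qr_route = beta, direct = coef(lm(y ~ X))[-1])
```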

We refit the model using the QR decomposition (QR=TRUE) and print the summary again for diagnostics.

fitg <- stan_glm(formula = log(LeafWt) ~ log(CanVol) + log(CanAre) + log(CanSha) + log(TotHt) + log(Dens) + Group, data = dat, QR=TRUE, refresh=0)
summary(fitg)

Using the QR decomposition improves sampling efficiency (here we actually get superefficient sampling, i.e., better than independent sampling), and we continue with this model.

Instead of looking at the tables, it's easier to look at plots.

mcmc_areas(as.matrix(fitg), prob = .5, prob_outer = .95)

All 95% posterior intervals except the one for Group overlap 0, which suggests a serious collinearity problem.
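A common way to quantify collinearity among the predictors themselves (not used in the original notebook) is the variance inflation factor VIF_j = 1 / (1 - R_j^2), where R_j^2 is the R-squared from regressing predictor j on all the others; large values mean the data carry little information for separating that coefficient from the rest. A hand-rolled sketch, where vif_manual is a hypothetical helper and the log-scale canopy variables are simulated:

```r
# Hypothetical helper: VIF for every column of a data frame of predictors
vif_manual <- function(X) {
  X <- as.data.frame(X)
  sapply(names(X), function(v) {
    others <- setdiff(names(X), v)
    r2 <- summary(lm(reformulate(others, response = v), data = X))$r.squared
    1 / (1 - r2)
  })
}

# Simulated log dimensions (hypothetical, not the mesquite data)
set.seed(4)
n <- 46
lD1 <- log(runif(n, 0.8, 5.2))
lD2 <- lD1 + log(runif(n, 0.5, 1))
lH  <- log(runif(n, 0.5, 2.5))
X <- data.frame(logVol = lD1 + lD2 + lH,   # log canopy volume
                logAre = lD1 + lD2,        # log canopy area
                logSha = lD1 - lD2)        # log canopy shape
v <- vif_manual(X)
round(v, 1)
```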

Looking at the pairwise posteriors, we can see high correlations, especially between log(CanVol) and log(CanAre).

mcmc_pairs(as.matrix(fitg),pars = c("log(CanVol)","log(CanAre)","log(CanSha)","log(TotHt)","log(Dens)"))

If we look more carefully at some of the subplots, we see that although the marginal posterior intervals overlap 0, some pairwise joint posteriors do not. Let's look more carefully at the joint posterior of log(CanVol) and log(CanAre).

mcmc_scatter(as.matrix(fitg), pars = c("log(CanVol)","log(CanAre)")) +
  geom_vline(xintercept=0) +
  geom_hline(yintercept=0)

From the joint posterior scatter plot, we can see that 0 is far away from the typical set.
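Why marginal intervals can overlap 0 while the joint posterior is far from it: when two coefficients are strongly negatively correlated in the posterior, each is uncertain on its own, but their sum is well determined. A sketch with hypothetical bivariate normal "posterior draws"; the means and scales mimic the log(CanVol) and log(CanAre) rows of the table above, and the correlation of -0.9 is made up:

```r
set.seed(5)
S <- 4000
rho <- -0.9
# Hypothetical posterior: b1, b2 normal with mean 0.4, sd 0.3, correlation rho
z1 <- rnorm(S)
z2 <- rnorm(S)
b1 <- 0.4 + 0.3 * z1
b2 <- 0.4 + 0.3 * (rho * z1 + sqrt(1 - rho^2) * z2)

marg <- rbind(b1 = quantile(b1, c(0.025, 0.975)),
              b2 = quantile(b2, c(0.025, 0.975)))
marg                                 # both marginal 95% intervals include 0
quantile(b1 + b2, c(0.025, 0.975))   # the combined effect is clearly positive
```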

With even more variables, some relevant and some irrelevant, it becomes difficult to analyse the joint posterior to see which variables are jointly relevant. We can easily test whether any of the covariates are useful by using cross-validation to compare the full model to a null model:

fitg0 <- update(fitg, formula = log(LeafWt) ~ 1, QR=FALSE)

We compute leave-one-out cross-validation elpd values using PSIS-LOO (Vehtari, Gelman and Gabry, 2017).

(loog <- loo(fitg))

Computed from 4000 by 46 log-likelihood matrix

         Estimate   SE
elpd_loo    -19.4  5.4
p_loo         7.6  1.6
looic        38.8 10.8
------
Monte Carlo SE of elpd_loo is NA.

Pareto k diagnostic values:
                         Count Pct.    Min. n_eff
(-Inf, 0.5]   (good)     43    93.5%   811       
 (0.5, 0.7]   (ok)        2     4.3%   321       
   (0.7, 1]   (bad)       1     2.2%   324       
   (1, Inf)   (very bad)  0     0.0%   <NA>      
See help('pareto-k-diagnostic') for details.
(loog0 <- loo(fitg0))

Computed from 4000 by 46 log-likelihood matrix

         Estimate   SE
elpd_loo    -62.6  5.0
p_loo         1.9  0.5
looic       125.1 10.0
------
Monte Carlo SE of elpd_loo is 0.0.

All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.

The model with all variables has one bad Pareto k value (k > 0.7; Vehtari, Gelman and Gabry, 2017). We can fix that by refitting the model without the corresponding observation and computing its leave-one-out predictive density exactly (k_threshold=0.7).

(loog <- loo(fitg, k_threshold=0.7))

Computed from 4000 by 46 log-likelihood matrix

         Estimate   SE
elpd_loo    -19.4  5.4
p_loo         7.6  1.6
looic        38.7 10.8
------
Monte Carlo SE of elpd_loo is 0.1.

Pareto k diagnostic values:
                         Count Pct.    Min. n_eff
(-Inf, 0.5]   (good)     43    95.6%   811       
 (0.5, 0.7]   (ok)        2     4.4%   321       
   (0.7, 1]   (bad)       0     0.0%   <NA>      
   (1, Inf)   (very bad)  0     0.0%   <NA>      

All Pareto k estimates are ok (k < 0.7).
See help('pareto-k-diagnostic') for details.

And then we can compare the models.

loo_compare(loog0, loog)
      elpd_diff se_diff
fitg    0.0       0.0  
fitg0 -43.2       7.0  

Based on cross-validation, the covariates together contain significant information for improving predictions: the elpd difference of 43.2 is more than six times its standard error (7.0).
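The comparison works with pointwise elpd values: the difference is the sum of the pointwise differences, and (as we understand the loo package) its standard error is estimated as sqrt(n) times their standard deviation. A sketch with hypothetical pointwise values; with the fitted objects above, the pointwise values are available as loog$pointwise[, "elpd_loo"] and loog0$pointwise[, "elpd_loo"]:

```r
set.seed(6)
n <- 46
elpd_full <- rnorm(n, mean = -0.4, sd = 0.5)          # hypothetical pointwise elpd, full model
elpd_null <- elpd_full - rnorm(n, mean = 0.9, sd = 1)  # hypothetical pointwise elpd, null model

d <- elpd_null - elpd_full
elpd_diff <- sum(d)           # negative: the null model predicts worse
se_diff <- sqrt(n) * sd(d)    # standard error of the summed difference
c(elpd_diff = elpd_diff, se_diff = se_diff, ratio = elpd_diff / se_diff)
```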

We might want to choose a subset of variables either 1) because we don't want to observe all the variables in the future (e.g., due to the measurement cost), or 2) because we want to find the most relevant variables, which we define here as a minimal set of variables that can provide predictions similar to those of the full model.

LOO can be used for model selection, but we don't recommend it for variable selection, as discussed by Piironen and Vehtari (2017). The reason is that the selection process uses the data twice, and with a large number of variable combinations the selection process overfits and can produce really bad models. Using the usual posterior inference given the selected variables ignores that the selected variables are conditional on the selection process, and simply setting some variables to 0 ignores the uncertainty related to their relevance.
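The overfitting of selection can be seen in a small simulation. This is a sketch under simplifying assumptions (exact LOO for least-squares fits via the hat-matrix shortcut, squared error instead of elpd, and y simulated independently of all predictors, so every selected predictor is pure noise); loo_mse and fit_subset are hypothetical helpers:

```r
set.seed(7)
n <- 46; p <- 10
X <- matrix(rnorm(n * p), n, p, dimnames = list(NULL, paste0("x", 1:p)))
y <- rnorm(n)   # y is unrelated to every predictor

# Exact LOO mean squared error of a least-squares fit (hat-matrix shortcut)
loo_mse <- function(fit) mean((residuals(fit) / (1 - hatvalues(fit)))^2)
fit_subset <- function(cols) lm(y ~ ., data = data.frame(y = y, X[, cols, drop = FALSE]))

# Search all subsets of size 1 and 2 and keep the one with the best LOO score
subsets <- c(as.list(seq_len(p)), combn(p, 2, simplify = FALSE))
scores <- vapply(subsets, function(cols) loo_mse(fit_subset(cols)), numeric(1))
best <- subsets[[which.min(scores)]]

# Fresh data from the same process shows how the selected model really predicts;
# the minimum LOO score over many candidates is typically optimistic
Xnew <- matrix(rnorm(1000 * p), 1000, p, dimnames = list(NULL, paste0("x", 1:p)))
ynew <- rnorm(1000)
test_mse <- mean((ynew - predict(fit_subset(best), newdata = as.data.frame(Xnew)))^2)

results <- c(null_loo = loo_mse(lm(y ~ 1)), selected_loo = min(scores), selected_test = test_mse)
round(results, 3)
```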

Piironen and Vehtari (2017) also show that a projection predictive approach can be used for model reduction, that is, for choosing a smaller model with some coefficients set to 0. The projection predictive approach solves the problem of how to do inference after the selection: the full model posterior is projected onto the restricted subspace. For more details, see Piironen, Paasiniemi and Vehtari (2020).
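For a Gaussian observation model, as we understand the projection described in the papers cited here, each full-model posterior draw has a closed-form projection: the submodel coefficients are the least-squares fit of the submodel to the full model's linear predictor, and the projected noise variance adds the mean squared discrepancy to sigma^2. A draw-by-draw base-R sketch; the "posterior draws" come from a normal approximation to an lm fit on simulated data (a hypothetical stand-in for stan_glm draws), and projpred itself does more, e.g. thinning or clustering of draws and the search over submodels:

```r
set.seed(8)
n <- 46
x1 <- rnorm(n)
x2 <- x1 + rnorm(n, sd = 0.1)   # nearly collinear with x1
x3 <- rnorm(n)
y <- 1 + x1 + x2 + 0.5 * x3 + rnorm(n, sd = 0.5)

# Hypothetical stand-in for full-model posterior draws:
# normal approximation around the least-squares estimates
fit <- lm(y ~ x1 + x2 + x3)
S <- 1000
beta_draws <- matrix(coef(fit), S, 4, byrow = TRUE) +
  matrix(rnorm(S * 4), S, 4) %*% chol(vcov(fit))
sigma_draws <- rep(summary(fit)$sigma, S)   # sigma held fixed for simplicity

X_full <- cbind(1, x1, x2, x3)
X_sub  <- cbind(Intercept = 1, x1 = x1, x3 = x3)   # submodel drops x2

# Project each draw: least-squares fit of the submodel to the full-model mean
proj <- t(sapply(seq_len(S), function(s) {
  mu_full <- drop(X_full %*% beta_draws[s, ])
  b_sub <- setNames(qr.solve(X_sub, mu_full), colnames(X_sub))
  mu_sub <- drop(X_sub %*% b_sub)
  c(b_sub, sigma = sqrt(sigma_draws[s]^2 + mean((mu_full - mu_sub)^2)))
}))
round(colMeans(proj), 2)   # x1 absorbs most of the effect of the dropped x2
```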

We perform projection predictive variable selection using the previous full model. A fast leave-one-out cross-validation approach (Vehtari, Gelman and Gabry, 2017) is used to choose the model size.

fitg_cv <- cv_varsel(fitg, method='forward', cv_method='LOO', nloo = nrow(dat))

We can now look at the estimated predictive performance of smaller models compared to the full model.

plot(fitg_cv, stats = c('elpd', 'rmse'))

And we get a LOO-CV-based recommendation for the model size to choose.

(nsel <- suggest_size(fitg_cv, alpha=0.1))
[1] 3
(vsel <- solution_terms(fitg_cv)[1:nsel])
[1] "log(CanVol)" "Group"       "log(CanSha)"

We see that 3 variables are enough to get predictive accuracy similar to that of the full model with all 6 variables.

Next we form the projected posterior for the chosen model.

projg <- project(fitg_cv, nterms = nsel, ndraws = 4000)
projdraws <- as.matrix(projg)
round(colMeans(projdraws),1)
  Intercept log(CanVol)       Group log(CanSha)       sigma 
        4.3         0.8         0.6        -0.4         0.4 
round(posterior_interval(projdraws),1)
              5% 95%
Intercept    3.9 4.7
log(CanVol)  0.7 0.9
Group        0.4 0.8
log(CanSha) -0.8 0.0
sigma        0.3 0.4
mcmc_areas(projdraws, pars=c("Intercept",vsel), prob_outer=0.99)


References

Piironen, J., Paasiniemi, M. and Vehtari, A. (2020) ‘Projective inference in high-dimensional problems: Prediction and feature selection’, Electronic Journal of Statistics, 14(1), pp. 2155–2197.
Piironen, J. and Vehtari, A. (2017) ‘Comparison of Bayesian predictive methods for model selection’, Statistics and Computing, 27(3), pp. 711–735. doi: 10.1007/s11222-016-9649-y.
Vehtari, A., Gelman, A. and Gabry, J. (2017) ‘Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC’, Statistics and Computing, 27(5), pp. 1413–1432. doi: 10.1007/s11222-016-9696-4.

Licenses

  • Code © 2018, Aki Vehtari, licensed under BSD-3.
  • Text © 2018, Aki Vehtari, licensed under CC-BY-NC 4.0.

Original Computing Environment

sessionInfo()
R version 4.1.2 (2021-11-01)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.3 LTS

Matrix products: default
BLAS/LAPACK: /opt/OpenBLAS/lib/libopenblas_haswellp-r0.3.17.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=fi_FI.UTF-8       
 [4] LC_COLLATE=en_US.UTF-8     LC_MONETARY=fi_FI.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=fi_FI.UTF-8       LC_NAME=C                  LC_ADDRESS=C              
[10] LC_TELEPHONE=C             LC_MEASUREMENT=fi_FI.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] splines   stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] arm_1.12-2            lme4_1.1-28           Matrix_1.4-0          reliabilitydiag_0.2.0
 [5] MASS_7.3-55           corrplot_0.92         caret_6.0-90          lattice_0.20-45      
 [9] GGally_2.1.2          fivethirtyeight_0.6.2 projpred_2.0.2        bayesplot_1.8.1      
[13] forcats_0.5.1         stringr_1.4.0         dplyr_1.0.8           purrr_0.3.4          
[17] readr_2.1.2           tidyr_1.2.0           tibble_3.1.6          ggplot2_3.3.5        
[21] tidyverse_1.3.1       loo_2.4.1             rstanarm_2.21.1       Rcpp_1.0.8           
[25] rmarkdown_2.11        knitr_1.37           

loaded via a namespace (and not attached):
  [1] readxl_1.3.1         backports_1.4.1      plyr_1.8.6           igraph_1.2.11       
  [5] listenv_0.8.0        crosstalk_1.2.0      usethis_2.1.5        rstantools_2.1.1    
  [9] inline_0.3.19        digest_0.6.29        foreach_1.5.2        htmltools_0.5.2     
 [13] rsconnect_0.8.25     fansi_1.0.2          magrittr_2.0.2       checkmate_2.0.0     
 [17] memoise_2.0.1        tzdb_0.2.0           remotes_2.4.2        globals_0.14.0      
 [21] recipes_0.1.17       gower_1.0.0          modelr_0.1.8         RcppParallel_5.1.5  
 [25] matrixStats_0.61.0   xts_0.12.1           prettyunits_1.1.1    colorspace_2.0-2    
 [29] rvest_1.0.2          haven_2.4.3          xfun_0.29            callr_3.7.0         
 [33] crayon_1.4.2         jsonlite_1.7.3       iterators_1.0.14     survival_3.2-13     
 [37] zoo_1.8-9            glue_1.6.1           gtable_0.3.0         ipred_0.9-12        
 [41] distributional_0.3.0 pkgbuild_1.3.1       rstan_2.21.3         future.apply_1.8.1  
 [45] abind_1.4-5          scales_1.1.1         DBI_1.1.2            miniUI_0.1.1.1      
 [49] progress_1.2.2       xtable_1.8-4         lava_1.6.10          prodlim_2019.11.13  
 [53] stats4_4.1.2         StanHeaders_2.21.0-7 DT_0.20              htmlwidgets_1.5.4   
 [57] httr_1.4.2           threejs_0.3.3        RColorBrewer_1.1-2   posterior_1.2.0     
 [61] ellipsis_0.3.2       reshape_0.8.8        pkgconfig_2.0.3      farver_2.1.0        
 [65] nnet_7.3-17          sass_0.4.0           dbplyr_2.1.1         utf8_1.2.2          
 [69] labeling_0.4.2       tidyselect_1.1.1     rlang_1.0.1          reshape2_1.4.4      
 [73] later_1.3.0          munsell_0.5.0        cellranger_1.1.0     tools_4.1.2         
 [77] cachem_1.0.6         cli_3.1.1            generics_0.1.2       devtools_2.4.3      
 [81] broom_0.7.12         ggridges_0.5.3       evaluate_0.14        fastmap_1.1.0       
 [85] yaml_2.2.2           ModelMetrics_1.2.2.2 processx_3.5.2       fs_1.5.2            
 [89] future_1.23.0        nlme_3.1-155         mime_0.12            xml2_1.3.3          
 [93] brio_1.1.3           compiler_4.1.2       shinythemes_1.2.0    rstudioapi_0.13     
 [97] gamm4_0.2-6          testthat_3.1.2       reprex_2.0.1         bslib_0.3.1         
[101] stringi_1.7.6        highr_0.9            ps_1.6.0             desc_1.4.0          
[105] nloptr_1.2.2.3       markdown_1.1         shinyjs_2.1.0        tensorA_0.36.2      
[109] vctrs_0.3.8          pillar_1.7.0         lifecycle_1.0.1      jquerylib_0.1.4     
[113] data.table_1.14.2    httpuv_1.6.5         R6_2.5.1             promises_1.2.0.1    
[117] gridExtra_2.3        parallelly_1.30.0    sessioninfo_1.2.2    codetools_0.2-18    
[121] boot_1.3-28          colourpicker_1.1.1   gtools_3.9.2         assertthat_0.2.1    
[125] pkgload_1.2.4        rprojroot_2.0.2      withr_2.4.3          shinystan_2.5.0     
[129] mgcv_1.8-38          parallel_4.1.2       hms_1.1.1            rpart_4.1.16        
[133] timeDate_3043.102    grid_4.1.2           coda_0.19-4          class_7.3-20        
[137] minqa_1.2.4          pROC_1.18.0          shiny_1.7.1          lubridate_1.8.0     
[141] base64enc_0.1-3      dygraphs_1.1.1.6