Load packages - Requires projpred 2.0.0+
library(rstanarm)
library(arm)
options(mc.cores = parallel::detectCores())
library(loo)
library(ggplot2)
library(bayesplot)
theme_set(bayesplot::theme_default(base_family = "sans"))
library(GGally)
library(projpred)
library(dplyr)  # provides %>% and mutate_if() used below
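This notebook demonstrates collinearity in multipredictor regression. The example of predicting the yields of mesquite bushes comes from Gelman and Hill (2007). The outcome variable (LeafWt) is the total weight (in grams) of photosynthetic material as derived from actual harvesting of the bush. The predictor variables are:
- Diam1: diameter of the canopy (the leafy area of the bush) in meters, measured along the longer axis of the bush
- Diam2: canopy diameter measured along the shorter axis
- CanHt: height of the canopy
- TotHt: total height of the bush
- Dens: plant unit density (number of primary stems per plant unit)
- Group: group of measurements (stored as text in the data file; the code below converts it to the numeric codes 1 and 2)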
dat <- read.table("mesquite.dat", header=TRUE) %>%
mutate_if(is.character, as.factor) %>%
mutate_if(is.factor, as.numeric)
summary(dat)
Obs Group Diam1 Diam2 TotHt
Min. : 1.00 Min. :1.000 Min. :0.800 Min. :0.400 Min. :0.650
1st Qu.:12.25 1st Qu.:1.000 1st Qu.:1.400 1st Qu.:1.000 1st Qu.:1.200
Median :23.50 Median :2.000 Median :1.950 Median :1.525 Median :1.500
Mean :23.50 Mean :1.565 Mean :2.099 Mean :1.572 Mean :1.482
3rd Qu.:34.75 3rd Qu.:2.000 3rd Qu.:2.475 3rd Qu.:1.900 3rd Qu.:1.700
Max. :46.00 Max. :2.000 Max. :5.200 Max. :4.000 Max. :3.000
CanHt Dens LeafWt
Min. :0.5000 Min. :1.000 Min. : 60.2
1st Qu.:0.8625 1st Qu.:1.000 1st Qu.: 219.6
Median :1.1000 Median :1.000 Median : 361.9
Mean :1.1107 Mean :1.674 Mean : 559.7
3rd Qu.:1.3000 3rd Qu.:2.000 3rd Qu.: 688.7
Max. :2.5000 Max. :9.000 Max. :4052.0
Plot data
ggpairs(dat, diag=list(continuous="barDiag"))
Additional transformed variables
dat$CanVol <- dat$Diam1 * dat$Diam2 * dat$CanHt
dat$CanAre <- dat$Diam1 * dat$Diam2
dat$CanSha <- dat$Diam1 / dat$Diam2
It may be reasonable to fit on the logarithmic scale, so that effects are multiplicative rather than additive (we’ll return to checking this assumption in another notebook).
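To make the "multiplicative" interpretation concrete, here is a small base-R sketch on simulated data (an assumption for illustration, not the mesquite data). If y = a * x^b * noise, then log(y) is linear in log(x) with slope b, and multiplying x by a factor c multiplies y by about c^b.

```r
# Simulated data (for illustration only, not the mesquite data):
# y = 2 * x^0.5 * multiplicative log-normal noise
set.seed(1)
x <- runif(200, min = 1, max = 10)
y <- 2 * x^0.5 * exp(rnorm(200, sd = 0.1))

# On the log scale the model is additive: log(y) = log(2) + 0.5 * log(x) + error
fit_log <- lm(log(y) ~ log(x))
b <- unname(coef(fit_log)["log(x)"])

# A slope b on log(x) means that multiplying x by 2 multiplies y by about 2^b
doubling_factor <- 2^b
print(round(c(slope = b, doubling_factor = doubling_factor), 2))
```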
We first illustrate the problem with maximum likelihood estimates.
lm1 <- lm(formula = log(LeafWt) ~ log(CanVol) + log(CanAre) + log(CanSha) + log(TotHt) + log(Dens) + Group, data = dat)
display(lm1)
lm(formula = log(LeafWt) ~ log(CanVol) + log(CanAre) + log(CanSha) +
log(TotHt) + log(Dens) + Group, data = dat)
coef.est coef.se
(Intercept) 4.18 0.23
log(CanVol) 0.37 0.28
log(CanAre) 0.40 0.29
log(CanSha) -0.38 0.23
log(TotHt) 0.39 0.31
log(Dens) 0.11 0.12
Group 0.58 0.13
---
n = 46, k = 7
residual sd = 0.33, R-Squared = 0.89
Group seems to be the only variable whose coefficient is far away from zero relative to its standard error. Let’s try a model with just the group variable.
lm2 <- lm(formula = log(LeafWt) ~ Group, data = dat)
display(lm2)
lm(formula = log(LeafWt) ~ Group, data = dat)
coef.est coef.se
(Intercept) 6.49 0.44
Group -0.36 0.27
---
n = 46, k = 2
residual sd = 0.91, R-Squared = 0.04
Hmmm… R-squared dropped a lot (from 0.89 to 0.04), so it seems that the other variables are useful even though their estimated effects and standard errors indicate that they are not relevant. There are approaches for investigating this with maximum likelihood estimated models, but we’ll now switch to Bayesian inference using rstanarm.
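One common maximum-likelihood-side diagnostic for this situation (our choice; the text above does not name a specific method) is the variance inflation factor, VIF_j = 1 / (1 - R_j^2), where R_j^2 is the R-squared from regressing predictor j on the other predictors. A minimal base-R sketch on simulated collinear data, with a hypothetical helper vif_manual():

```r
# Variance inflation factors computed by hand (base R only).
# X: data frame or matrix of numeric predictors, without an intercept column.
vif_manual <- function(X) {
  X <- as.data.frame(X)
  vapply(names(X), function(j) {
    others <- as.matrix(X[setdiff(names(X), j)])
    # R^2 from regressing predictor j on all the other predictors
    r2 <- summary(lm(X[[j]] ~ others))$r.squared
    1 / (1 - r2)
  }, numeric(1))
}

# Simulated example: x2 is almost a copy of x1, x3 is unrelated
set.seed(2)
n <- 100
x1 <- rnorm(n)
x2 <- x1 + rnorm(n, sd = 0.1)
x3 <- rnorm(n)
vifs <- vif_manual(data.frame(x1, x2, x3))
print(round(vifs, 1))
```

Large values for x1 and x2 flag that their individual coefficients are poorly determined even if the pair is jointly informative. With the mesquite data, a call such as vif_manual(log(dat[c("CanVol", "CanAre", "CanSha", "TotHt", "Dens")])) would apply the same idea to the logged predictors of lm1 (not run here, so no values are shown).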
The corresponding rstanarm model is fit using stan_glm:
fitg <- stan_glm(formula = log(LeafWt) ~ log(CanVol) + log(CanAre) + log(CanSha) + log(TotHt) + log(Dens) + Group, data = dat, refresh=0)
Print summary for some diagnostics.
summary(fitg)
Model Info:
function: stan_glm
family: gaussian [identity]
formula: log(LeafWt) ~ log(CanVol) + log(CanAre) + log(CanSha) + log(TotHt) +
log(Dens) + Group
algorithm: sampling
sample: 4000 (posterior sample size)
priors: see help('prior_summary')
observations: 46
predictors: 7
Estimates:
mean sd 10% 50% 90%
(Intercept) 4.2 0.2 3.9 4.2 4.5
log(CanVol) 0.4 0.3 0.0 0.4 0.7
log(CanAre) 0.4 0.3 0.0 0.4 0.8
log(CanSha) -0.4 0.2 -0.7 -0.4 -0.1
log(TotHt) 0.4 0.3 0.0 0.4 0.8
log(Dens) 0.1 0.1 -0.1 0.1 0.3
Group 0.6 0.1 0.4 0.6 0.8
sigma 0.3 0.0 0.3 0.3 0.4
Fit Diagnostics:
mean sd 10% 50% 90%
mean_PPD 5.9 0.1 5.8 5.9 6.0
The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).
MCMC diagnostics
mcse Rhat n_eff
(Intercept) 0.0 1.0 3727
log(CanVol) 0.0 1.0 1364
log(CanAre) 0.0 1.0 1458
log(CanSha) 0.0 1.0 2571
log(TotHt) 0.0 1.0 1647
log(Dens) 0.0 1.0 2702
Group 0.0 1.0 2077
sigma 0.0 1.0 2738
mean_PPD 0.0 1.0 3458
log-posterior 0.1 1.0 1517
For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
Rhats and n_effs are good (see, e.g., RStan workflow), but the QR transformation usually makes sampling work even better (see The QR Decomposition For Regression Models).
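The idea behind the QR reparameterization (requested in stan_glm with QR=TRUE) can be sketched in base R. Following the scaling used in The QR Decomposition For Regression Models, we write the predictor matrix (centered here, an assumption that matches having a separate intercept) as X = Q* R* with Q* = Q sqrt(n-1) and R* = R / sqrt(n-1), fit the regression on the orthogonal columns of Q*, and map the coefficients back with beta = R*^{-1} theta. The data are simulated and we use lm() instead of MCMC, so this only illustrates the change of variables, not its effect on sampling efficiency.

```r
# Scaled QR reparameterization of a linear regression (base R sketch)
set.seed(3)
n <- 50
x1 <- rnorm(n)
x2 <- x1 + rnorm(n, sd = 0.2)                 # correlated with x1
y  <- 1 + 0.5 * x1 + 0.5 * x2 + rnorm(n, sd = 0.3)

X  <- scale(cbind(x1, x2), center = TRUE, scale = FALSE)  # centered predictors
qx <- qr(X)
Q_ast <- qr.Q(qx) * sqrt(n - 1)               # orthogonal columns, rescaled
R_ast <- qr.R(qx) / sqrt(n - 1)               # so that X = Q_ast %*% R_ast

# Fit on the orthogonalized predictors, then map coefficients back
fit_q <- lm(y ~ Q_ast)
theta <- coef(fit_q)[-1]
beta_from_qr <- drop(backsolve(R_ast, theta)) # solves R_ast %*% beta = theta

# Same slopes as fitting on X directly
beta_direct <- coef(lm(y ~ X))[-1]
print(rbind(beta_from_qr, beta_direct))
```

Because the columns of Q* are orthogonal, the posterior of theta is typically much less correlated than that of beta, which is what helps the sampler.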
We refit the model using the QR decomposition and print the summary again for diagnostics.
fitg <- stan_glm(formula = log(LeafWt) ~ log(CanVol) + log(CanAre) + log(CanSha) + log(TotHt) + log(Dens) + Group, data = dat, QR=TRUE, refresh=0)
summary(fitg)
Using the QR decomposition improved sampling efficiency (we actually get superefficient sampling, i.e., better than independent sampling), and we continue with this model.
Instead of looking at the tables, it’s easier to look at plots.
mcmc_areas(as.matrix(fitg), prob = .5, prob_outer = .95)
All 95% posterior intervals except the one for Group overlap 0, which suggests a serious collinearity problem.
Looking at the pairwise posteriors we can see high correlations especially between log(CanVol) and log(CanAre).
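Part of this correlation is built in by construction: taking logs of the definitions of CanVol, CanAre and CanSha above gives exact linear relations among the log-scale predictors,
\[
\begin{aligned}
\log(\mathrm{CanAre}) &= \log(\mathrm{Diam1}) + \log(\mathrm{Diam2}),\\
\log(\mathrm{CanVol}) &= \log(\mathrm{CanAre}) + \log(\mathrm{CanHt}),\\
\log(\mathrm{CanSha}) &= \log(\mathrm{Diam1}) - \log(\mathrm{Diam2}),
\end{aligned}
\]
so log(CanVol) and log(CanAre) differ only by log(CanHt), and the data contain little information for separating their individual coefficients.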
mcmc_pairs(as.matrix(fitg),pars = c("log(CanVol)","log(CanAre)","log(CanSha)","log(TotHt)","log(Dens)"))
If we look more carefully at the subplots, we see that although the marginal posterior intervals overlap 0, some pairwise joint posteriors do not. Let’s look more carefully at the joint posterior of log(CanVol) and log(CanAre).
mcmc_scatter(as.matrix(fitg), pars = c("log(CanVol)","log(CanAre)")) +
geom_vline(xintercept=0) +
geom_hline(yintercept=0)
From the joint posterior scatter plot, we can see that 0 is far away from the typical set.
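A small base-R sketch of why this happens: with strongly negatively correlated draws, each marginal interval can overlap 0 while a combination such as the sum of the two coefficients is well away from 0. The draws below are simulated from a bivariate normal whose means, standard deviations and correlation we picked by hand (loosely inspired by the table above); they are not the actual posterior, and the column names are hypothetical.

```r
# Simulated (hypothetical) posterior draws for two collinear coefficients
set.seed(4)
S   <- 4000
mu  <- c(0.4, 0.4)
sds <- c(0.3, 0.3)
rho <- -0.9
Sigma <- diag(sds) %*% matrix(c(1, rho, rho, 1), 2) %*% diag(sds)
L <- chol(Sigma)                            # upper triangular: t(L) %*% L == Sigma
draws <- matrix(rnorm(S * 2), S, 2) %*% L + matrix(mu, S, 2, byrow = TRUE)
colnames(draws) <- c("b_vol", "b_are")

# 95% marginal intervals: both include 0
marg <- apply(draws, 2, quantile, probs = c(0.025, 0.975))
print(round(marg, 2))

# 95% interval of the sum: excludes 0
sum_int <- quantile(draws[, 1] + draws[, 2], probs = c(0.025, 0.975))
print(round(sum_int, 2))

# With the real model (not run here), e.g.:
# d <- as.matrix(fitg)
# quantile(d[, "log(CanVol)"] + d[, "log(CanAre)"], c(0.025, 0.975))
```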
With even more variables, some relevant and some irrelevant, it becomes difficult to analyse the joint posterior to see which variables are jointly relevant. We can, however, easily test whether any of the covariates are useful by using cross-validation to compare the full model to a null model:
fitg0 <- update(fitg, formula = log(LeafWt) ~ 1, QR=FALSE)
We compute leave-one-out cross-validation elpd values using PSIS-LOO (Vehtari, Gelman and Gabry, 2017).
(loog <- loo(fitg))
Computed from 4000 by 46 log-likelihood matrix
Estimate SE
elpd_loo -19.4 5.4
p_loo 7.6 1.6
looic 38.8 10.8
------
Monte Carlo SE of elpd_loo is NA.
Pareto k diagnostic values:
Count Pct. Min. n_eff
(-Inf, 0.5] (good) 43 93.5% 811
(0.5, 0.7] (ok) 2 4.3% 321
(0.7, 1] (bad) 1 2.2% 324
(1, Inf) (very bad) 0 0.0% <NA>
See help('pareto-k-diagnostic') for details.
(loog0 <- loo(fitg0))
Computed from 4000 by 46 log-likelihood matrix
Estimate SE
elpd_loo -62.6 5.0
p_loo 1.9 0.5
looic 125.1 10.0
------
Monte Carlo SE of elpd_loo is 0.0.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
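The model with the covariates has one bad Pareto \(k\) value (Vehtari, Gelman and Gabry, 2017). We can fix that by computing the corresponding leave-one-out posterior exactly, that is, by refitting the model without that observation; with k_threshold=0.7, loo() does this for each observation with \(k > 0.7\).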
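What the exact refit computes for a flagged observation i can be sketched as follows, assuming a Gaussian model as here: after refitting without observation i, the log predictive density log p(y_i | y_{-i}) is estimated by the log of the average of N(y_i | mu_s, sigma_s) over the S posterior draws of the refitted model. The helper loo_lpd_gaussian() is our own illustration, not a loo or rstanarm function; it uses the log-sum-exp trick for numerical stability.

```r
# Hypothetical helper (not a loo or rstanarm function): the Monte Carlo estimate
# log( (1/S) * sum_s dnorm(y_i | mu_s, sigma_s) ), computed on the log scale.
loo_lpd_gaussian <- function(y_i, mu_draws, sigma_draws) {
  ll <- dnorm(y_i, mean = mu_draws, sd = sigma_draws, log = TRUE)
  m <- max(ll)
  m + log(mean(exp(ll - m)))  # log-sum-exp trick; mean() divides by S
}

# Toy check: identical draws with mu = 0 and sigma = 1 give log(dnorm(0)) = -0.919
loo_lpd_gaussian(0, mu_draws = rep(0, 10), sigma_draws = rep(1, 10))

# Sketch for the mesquite model (not run here): refit without observation i,
# then evaluate the held-out point on the log(LeafWt) scale
# fit_i <- update(fitg, data = dat[-i, ])
# mu_i  <- posterior_linpred(fit_i, newdata = dat[i, , drop = FALSE])
# loo_lpd_gaussian(log(dat$LeafWt[i]), mu_i, as.matrix(fit_i)[, "sigma"])
```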
(loog <- loo(fitg, k_threshold=0.7))
Computed from 4000 by 46 log-likelihood matrix
Estimate SE
elpd_loo -19.4 5.4
p_loo 7.6 1.6
looic 38.7 10.8
------
Monte Carlo SE of elpd_loo is 0.1.
Pareto k diagnostic values:
Count Pct. Min. n_eff
(-Inf, 0.5] (good) 43 95.6% 811
(0.5, 0.7] (ok) 2 4.4% 321
(0.7, 1] (bad) 0 0.0% <NA>
(1, Inf) (very bad) 0 0.0% <NA>
All Pareto k estimates are ok (k < 0.7).
See help('pareto-k-diagnostic') for details.
And then we can compare the models.
loo_compare(loog0, loog)
elpd_diff se_diff
fitg 0.0 0.0
fitg0 -43.2 7.0
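Based on cross-validation, the covariates together contain significant information for improving predictions: the elpd_loo of the full model is 43.2 higher than that of the null model, with a standard error of 7.0 for the difference.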
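The two columns of the loo_compare() output can be reproduced from the pointwise elpd values: elpd_diff is the sum of the pointwise differences, and se_diff is sqrt(n) times their standard deviation. Here the difference is about six standard errors (43.2 / 7.0 ≈ 6.2). A minimal sketch with a hypothetical helper elpd_diff_se():

```r
# Hypothetical helper: difference in elpd between two models and its standard
# error, from pointwise elpd values computed on the same observations.
elpd_diff_se <- function(pw_a, pw_b) {
  d <- pw_a - pw_b
  c(elpd_diff = sum(d), se_diff = sqrt(length(d) * var(d)))
}

# Toy input: differences -1, -2, -3 give elpd_diff = -6 and se_diff = sqrt(3)
elpd_diff_se(c(-1, -2, -3), c(0, 0, 0))

# With the loo objects above (not run here):
# elpd_diff_se(loog0$pointwise[, "elpd_loo"], loog$pointwise[, "elpd_loo"])
```

Computing the standard error from the paired pointwise differences, rather than from the two separate elpd standard errors, accounts for the fact that both models are evaluated on the same observations.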
We might want to choose some variables 1) because we don’t want to observe all the variables in the future (e.g., due to measurement cost), or 2) because we want to find the most relevant variables, which we define here as a minimal set of variables that can provide predictions similar to those of the full model.
LOO can be used for model selection, but we don’t recommend it for variable selection, as discussed by Piironen and Vehtari (2017). The reason is that the selection process uses the data twice, and with a large number of variable combinations the selection process overfits and can produce really bad models. Using the usual posterior inference given the selected variables ignores that the selected variables are conditional on the selection process, and simply setting some variables to 0 ignores the uncertainty related to their relevance.
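A small base-R simulation of this selection-induced overfitting (our own illustration, not taken from Piironen and Vehtari): with 100 candidate predictors that are all pure noise, picking the single-predictor model with the best LOO error makes the chosen model look better than the intercept-only model, even though no predictor is related to the outcome. For least-squares linear regression the exact LOO residuals are e_i / (1 - h_ii), where h_ii are the hat values, so no refitting is needed.

```r
# Exact LOO mean squared error for an lm fit, via the hat-value identity
loo_mse <- function(fit) mean((residuals(fit) / (1 - hatvalues(fit)))^2)

set.seed(5)
n <- 30
p <- 100
X <- matrix(rnorm(n * p), n, p)    # candidate predictors: pure noise
y <- rnorm(n)                      # outcome unrelated to every predictor

null_mse <- loo_mse(lm(y ~ 1))
cand_mse <- apply(X, 2, function(x) loo_mse(lm(y ~ x)))
best <- which.min(cand_mse)        # "selected" by LOO

# Fresh data from the same process: the selected predictor has no real signal
X_new <- matrix(rnorm(n * p), n, p)
y_new <- rnorm(n)
b <- coef(lm(y ~ X[, best]))
test_mse <- mean((y_new - (b[1] + b[2] * X_new[, best]))^2)

print(round(c(null_loo = null_mse, selected_loo = cand_mse[best],
              selected_test = test_mse), 2))
```

The LOO estimate of the selected model is optimistic because the same LOO values were used both to pick the model and to judge it; on fresh data the selected model typically does no better than the null model.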
Piironen and Vehtari (2017) also show that a projection predictive approach can be used for model reduction, that is, for choosing a smaller model with some coefficients set to 0. The projection predictive approach solves the problem of how to do inference after the selection: the solution is to project the full model posterior onto the restricted subspace. See Piironen, Paasiniemi and Vehtari (2020) for more details.
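For a Gaussian model, the projection of a single posterior draw has a closed form: the submodel coefficients are the least-squares fit of the reference model's linear predictor \(\mu_s\) on the submodel's columns, and the projected noise variance adds the mean squared discrepancy, \(\sigma_\perp^2 = \sigma_s^2 + \frac{1}{n}\lVert \mu_s - \mu_\perp \rVert^2\). projpred's own implementation differs in details (for example, it can thin or cluster the draws); the hypothetical function project_draw() below is only a base-R sketch of the idea, applied to simulated draws.

```r
# Project one draw of a Gaussian reference model onto a submodel.
# X_full: n x p design (with intercept column); beta_s, sigma_s: one draw;
# keep: column indices of X_full kept in the submodel.
project_draw <- function(X_full, beta_s, sigma_s, keep) {
  mu_s  <- drop(X_full %*% beta_s)             # reference model fit for this draw
  X_sub <- X_full[, keep, drop = FALSE]
  w     <- qr.solve(X_sub, mu_s)               # least-squares fit of mu_s
  mu_p  <- drop(X_sub %*% w)
  sigma_p <- sqrt(sigma_s^2 + mean((mu_s - mu_p)^2))
  list(coef = w, sigma = sigma_p)
}

set.seed(6)
n <- 40
X_full <- cbind(1, matrix(rnorm(n * 3), n, 3))  # intercept + 3 predictors
# Hypothetical draws standing in for a posterior sample
S <- 100
beta_draws  <- matrix(rnorm(S * 4, mean = c(1, 0.8, 0.05, -0.4), sd = 0.1),
                      S, 4, byrow = TRUE)
sigma_draws <- abs(rnorm(S, 0.3, 0.02))

# Submodel keeping the intercept and predictors 1 and 3 (columns 1, 2, 4)
proj <- lapply(seq_len(S), function(s)
  project_draw(X_full, beta_draws[s, ], sigma_draws[s], keep = c(1, 2, 4)))
proj_coef  <- t(sapply(proj, `[[`, "coef"))
proj_sigma <- sapply(proj, `[[`, "sigma")
print(round(colMeans(proj_coef), 2))
print(round(mean(proj_sigma), 2))
```

Each projected draw is a deterministic function of a full-model draw, so the uncertainty of the reference model carries over to the submodel instead of being conditioned on the selection.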
We perform projection predictive variable selection using the previous full model. A fast leave-one-out cross-validation approach (Vehtari, Gelman and Gabry, 2017) is used to choose the model size.
fitg_cv <- cv_varsel(fitg, method='forward', cv_method='LOO', nloo = nrow(dat))
We can now look at the estimated predictive performance of smaller models compared to the full model.
plot(fitg_cv, stats = c('elpd', 'rmse'))
We then get a LOO-CV-based recommendation for the model size to choose:
(nsel <- suggest_size(fitg_cv, alpha=0.1))
[1] 3
(vsel <- solution_terms(fitg_cv)[1:nsel])
[1] "log(CanVol)" "Group" "log(CanSha)"
We see that 3 variables are enough to get predictive performance similar to that of the full model with all 6 variables.
Next we form the projected posterior for the chosen model.
projg <- project(fitg_cv, nterms = nsel, ndraws = 4000)
projdraws <- as.matrix(projg)
round(colMeans(projdraws),1)
Intercept log(CanVol) Group log(CanSha) sigma
4.3 0.8 0.6 -0.4 0.4
round(posterior_interval(projdraws),1)
5% 95%
Intercept 3.9 4.7
log(CanVol) 0.7 0.9
Group 0.4 0.8
log(CanSha) -0.8 0.0
sigma 0.3 0.4
mcmc_areas(projdraws, pars=c("Intercept",vsel), prob_outer=0.99)
sessionInfo()
R version 4.1.2 (2021-11-01)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.3 LTS
Matrix products: default
BLAS/LAPACK: /opt/OpenBLAS/lib/libopenblas_haswellp-r0.3.17.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C LC_TIME=fi_FI.UTF-8
[4] LC_COLLATE=en_US.UTF-8 LC_MONETARY=fi_FI.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=fi_FI.UTF-8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=fi_FI.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] splines stats graphics grDevices utils datasets methods base
other attached packages:
[1] arm_1.12-2 lme4_1.1-28 Matrix_1.4-0 reliabilitydiag_0.2.0
[5] MASS_7.3-55 corrplot_0.92 caret_6.0-90 lattice_0.20-45
[9] GGally_2.1.2 fivethirtyeight_0.6.2 projpred_2.0.2 bayesplot_1.8.1
[13] forcats_0.5.1 stringr_1.4.0 dplyr_1.0.8 purrr_0.3.4
[17] readr_2.1.2 tidyr_1.2.0 tibble_3.1.6 ggplot2_3.3.5
[21] tidyverse_1.3.1 loo_2.4.1 rstanarm_2.21.1 Rcpp_1.0.8
[25] rmarkdown_2.11 knitr_1.37
loaded via a namespace (and not attached):
[1] readxl_1.3.1 backports_1.4.1 plyr_1.8.6 igraph_1.2.11
[5] listenv_0.8.0 crosstalk_1.2.0 usethis_2.1.5 rstantools_2.1.1
[9] inline_0.3.19 digest_0.6.29 foreach_1.5.2 htmltools_0.5.2
[13] rsconnect_0.8.25 fansi_1.0.2 magrittr_2.0.2 checkmate_2.0.0
[17] memoise_2.0.1 tzdb_0.2.0 remotes_2.4.2 globals_0.14.0
[21] recipes_0.1.17 gower_1.0.0 modelr_0.1.8 RcppParallel_5.1.5
[25] matrixStats_0.61.0 xts_0.12.1 prettyunits_1.1.1 colorspace_2.0-2
[29] rvest_1.0.2 haven_2.4.3 xfun_0.29 callr_3.7.0
[33] crayon_1.4.2 jsonlite_1.7.3 iterators_1.0.14 survival_3.2-13
[37] zoo_1.8-9 glue_1.6.1 gtable_0.3.0 ipred_0.9-12
[41] distributional_0.3.0 pkgbuild_1.3.1 rstan_2.21.3 future.apply_1.8.1
[45] abind_1.4-5 scales_1.1.1 DBI_1.1.2 miniUI_0.1.1.1
[49] progress_1.2.2 xtable_1.8-4 lava_1.6.10 prodlim_2019.11.13
[53] stats4_4.1.2 StanHeaders_2.21.0-7 DT_0.20 htmlwidgets_1.5.4
[57] httr_1.4.2 threejs_0.3.3 RColorBrewer_1.1-2 posterior_1.2.0
[61] ellipsis_0.3.2 reshape_0.8.8 pkgconfig_2.0.3 farver_2.1.0
[65] nnet_7.3-17 sass_0.4.0 dbplyr_2.1.1 utf8_1.2.2
[69] labeling_0.4.2 tidyselect_1.1.1 rlang_1.0.1 reshape2_1.4.4
[73] later_1.3.0 munsell_0.5.0 cellranger_1.1.0 tools_4.1.2
[77] cachem_1.0.6 cli_3.1.1 generics_0.1.2 devtools_2.4.3
[81] broom_0.7.12 ggridges_0.5.3 evaluate_0.14 fastmap_1.1.0
[85] yaml_2.2.2 ModelMetrics_1.2.2.2 processx_3.5.2 fs_1.5.2
[89] future_1.23.0 nlme_3.1-155 mime_0.12 xml2_1.3.3
[93] brio_1.1.3 compiler_4.1.2 shinythemes_1.2.0 rstudioapi_0.13
[97] gamm4_0.2-6 testthat_3.1.2 reprex_2.0.1 bslib_0.3.1
[101] stringi_1.7.6 highr_0.9 ps_1.6.0 desc_1.4.0
[105] nloptr_1.2.2.3 markdown_1.1 shinyjs_2.1.0 tensorA_0.36.2
[109] vctrs_0.3.8 pillar_1.7.0 lifecycle_1.0.1 jquerylib_0.1.4
[113] data.table_1.14.2 httpuv_1.6.5 R6_2.5.1 promises_1.2.0.1
[117] gridExtra_2.3 parallelly_1.30.0 sessioninfo_1.2.2 codetools_0.2-18
[121] boot_1.3-28 colourpicker_1.1.1 gtools_3.9.2 assertthat_0.2.1
[125] pkgload_1.2.4 rprojroot_2.0.2 withr_2.4.3 shinystan_2.5.0
[129] mgcv_1.8-38 parallel_4.1.2 hms_1.1.1 rpart_4.1.16
[133] timeDate_3043.102 grid_4.1.2 coda_0.19-4 class_7.3-20
[137] minqa_1.2.4 pROC_1.18.0 shiny_1.7.1 lubridate_1.8.0
[141] base64enc_0.1-3 dygraphs_1.1.1.6