Load packages
library(tidyr)
library(rstanarm)
library(loo)
library(ggplot2)
theme_set(bayesplot::theme_default())
library(ggridges)
library(bridgesampling)
This notebook demonstrates a simple model we trust (no model misspecification). In this case, cross-validation is not needed, and we can get better accuracy using the explicit model.
An experiment was performed to estimate the effect of beta-blockers on mortality of cardiac patients (the example is from Gelman et al., 2013, Ch 3). Patients were randomly assigned to treatment and control groups:
Data, where grp2 is a dummy variable so that its coefficient captures the difference between the intercepts of the first and the second group.
d_bin2 <- data.frame(N = c(674, 680), y = c(39,22), grp2 = c(0,1))
To analyse whether the treatment is useful, we can use a binomial model for both groups and compute the odds ratio.
fit_bin2 <- stan_glm(y/N ~ grp2, family = binomial(), data = d_bin2,
weights = N, refresh=0)
In general we recommend showing the full posterior of the quantity of interest, which in this case is the odds ratio.
samples_bin2 <- rstan::extract(fit_bin2$stanfit)
theta1 <- plogis(samples_bin2$alpha)
theta2 <- plogis(samples_bin2$alpha + samples_bin2$beta)
oddsratio <- (theta2/(1-theta2))/(theta1/(1-theta1))
ggplot() + geom_histogram(aes(oddsratio), bins = 50, fill = 'grey', color = 'darkgrey') +
labs(y = '') + scale_y_continuous(breaks = NULL)
The probability that the odds ratio is less than 1:
print(mean(oddsratio<1),2)
[1] 0.99
This posterior distribution of the odds ratio (or some transformation of it) is the simplest and the most accurate way to analyse the effectiveness of the treatment. In this case, there is a high probability that the treatment is effective and that the effect is relatively large.
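If a compact numeric summary is needed alongside the histogram, posterior quantiles of the same draws can be printed (using the oddsratio draws computed above):
# Posterior median and 95% central interval of the odds ratio
quantile(oddsratio, probs = c(0.025, 0.5, 0.975))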
Although we recommend showing the full posterior, the probability that the odds ratio is less than 1 can be a useful summary. The simulation experiment binom_odds_comparison.R runs 100 simulations with data generated using varying true odds ratios (0.1,…,1.0), and computes for each run the probability that the odds ratio is less than 1. The following figures show the variation in the results.
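The script itself is not shown here; as a rough sketch, one replication could look like the following (the helper function, the baseline death rate p1 = 0.06, and the direct use of exp(beta) as the odds ratio are illustrative assumptions, not the actual contents of binom_odds_comparison.R):
# Illustrative sketch of one simulation replication (hypothetical helper,
# not the actual binom_odds_comparison.R script)
sim_prob_or_lt1 <- function(true_or, N = c(674, 680), p1 = 0.06) {
  # convert the control-group rate and the true odds ratio to a treatment rate
  odds2 <- true_or * p1 / (1 - p1)
  p2 <- odds2 / (1 + odds2)
  d <- data.frame(N = N, y = rbinom(2, N, c(p1, p2)), grp2 = c(0, 1))
  fit <- stan_glm(y/N ~ grp2, family = binomial(), data = d,
                  weights = N, refresh = 0)
  # in the logit model the odds ratio equals exp(beta)
  or <- exp(rstan::extract(fit$stanfit)$beta)
  mean(or < 1)
}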
Variation in probability that oddsratio<1 when true oddsratio is varied.
load(file="binom_test_densities.RData")
ggplot(betaprobs_densities, aes(x = values, y = ind, height = scaled)) +
geom_density_ridges(stat = "identity", scale=0.6)
Sometimes it is better to focus on the observable space (we can’t observe \(\theta\) or the odds ratio directly, but we can observe \(y\)). In leave-one-out cross-validation, the model is fitted \(n\) times, each time leaving one observation out of the fitting and using it to evaluate the predictive performance. This corresponds to using the already seen observations as pseudo Monte Carlo samples from the future data distribution, with the leave-one-out trick used to avoid double use of data. With the commonly used log score we get \[\mathrm{LOO} = \frac{1}{n} \sum_{i=1}^n \log {p(y_i|x_i,D_{-i},M_k)}.\] Cross-validation is useful when we don’t trust any model (the models might include good enough models, but we just don’t know if that is the case).
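For illustration, the formula can be implemented directly by brute force; the helper below is a sketch (it refits the model \(n\) times, which is exactly the cost the PSIS-LOO approximation used later avoids):
# Brute-force LOO: refit with observation i left out and evaluate the
# log predictive density of the held-out observation (sketch only)
brute_force_loo <- function(formula, data) {
  n <- nrow(data)
  lpd <- numeric(n)
  for (i in seq_len(n)) {
    fit_i <- stan_glm(formula, family = binomial(), data = data[-i, ],
                      refresh = 0)
    # log p(y_i | x_i, D_-i), averaging the likelihood over posterior draws
    ll_i <- log_lik(fit_i, newdata = data[i, , drop = FALSE])
    lpd[i] <- log(mean(exp(ll_i)))
  }
  mean(lpd)
}
Note that loo() used below reports elpd_loo as the sum over observations rather than the mean.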
Next we demonstrate one of the weaknesses of cross-validation (same holds for WAIC etc.).
To use leave-one-out where “one” refers to an individual patient, we need to change the model formulation a bit. In the above model formulation, the individual observations have been aggregated into group observations, and running loo(fit_bin2) would try to leave out one whole group at a time. If we had more groups, this could be what we want, but with just two groups it is unlikely. Thus, in the following we switch to a Bernoulli model with each individual as its own observation.
d_bin2b <- data.frame(y = c(rep(1,39), rep(0,674-39), rep(1,22), rep(0,680-22)), grp2 = c(rep(0, 674), rep(1, 680)))
fit_bin2b <- stan_glm(y ~ grp2, family = binomial(), data = d_bin2b, seed=180202538, refresh=0)
We also fit a “null” model which doesn’t use the group variable and thus has a common parameter for both groups.
fit_bin2bnull <- stan_glm(y ~ 1, family = binomial(), data = d_bin2b, seed=180202538, refresh=0)
We can then use cross-validation to compare whether adding the treatment variable improves the predictive performance. We use the fast Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO; Vehtari, Gelman and Gabry, 2017).
(loo_bin2 <- loo(fit_bin2b))
Computed from 4000 by 1354 log-likelihood matrix
Estimate SE
elpd_loo -248.1 23.3
p_loo 2.0 0.2
looic 496.1 46.6
------
Monte Carlo SE of elpd_loo is 0.0.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
(loo_bin2null <- loo(fit_bin2bnull))
Computed from 4000 by 1354 log-likelihood matrix
Estimate SE
elpd_loo -249.7 23.4
p_loo 1.0 0.1
looic 499.4 46.7
------
Monte Carlo SE of elpd_loo is 0.0.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
All Pareto \(k<0.5\), so we can trust the PSIS-LOO result (Vehtari, Gelman and Gabry, 2017; Vehtari et al., 2019).
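The \(k\) diagnostics can also be inspected directly:
# Tabulate the Pareto k diagnostics and look at their range
pareto_k_table(loo_bin2)
range(pareto_k_values(loo_bin2))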
Let’s make a pairwise comparison.
loo_compare(loo_bin2null, loo_bin2)
elpd_diff se_diff
fit_bin2b 0.0 0.0
fit_bin2bnull -1.6 2.3
elpd_diff is small compared to se_diff, and thus cross-validation is uncertain whether estimating the treatment effect improves the predictive performance. To put this in perspective, we have \(N_1=674\) and \(N_2=680\) with death rates of 5.8% and 3.2%, and this information is too weak for cross-validation to clearly detect the difference.
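As a crude summary (assuming the normal approximation for the elpd difference is adequate, which is questionable when the difference is this small), we can translate the comparison into a probability:
# Normal-approximation probability that the null model actually has better
# predictive performance, given elpd_diff = -1.6 and se_diff = 2.3
pnorm(0, mean = -1.6, sd = 2.3, lower.tail = FALSE)
This gives roughly 0.24, consistent with the comparison being inconclusive.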
The simulation experiment binom_odds_comparison.R runs 100 simulations with data generated using varying true odds ratios (0.1,…,1.0), and computes the LOO comparison for each run.
Variation in LOO comparison when true oddsratio is varied.
ggplot(looprobs_densities, aes(x = values, y = ind, height = scaled)) +
geom_density_ridges(stat = "identity", scale=0.6)
We see that using the posterior distribution from the model is more efficient for detecting the effect, but cross-validation will eventually detect it too. The difference comes from the fact that cross-validation doesn’t trust the model: it compares the model predictions to the “future data” using only a very weak assumption about the future. That weak assumption about the future is also cross-validation’s strength, as we’ll see in another notebook.
We can also compute predictive performance estimates using a stronger assumption about the future. A reference predictive estimate with log score can be computed as \[\mathrm{elpd}_{\mathrm{ref}} = \int p(\tilde{y}|D,M_*) \log p(\tilde{y}|D,M_k) d\tilde{y},\] where \(M_*\) is a reference model we trust. Using a reference model to assess the other models corresponds to the \(M\)-completed case (Vehtari and Ojanen, 2012), where the true model is replaced with a model we trust to be close enough to the true model.
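For binary data the integral over \(\tilde{y}\) reduces to a sum over the two outcomes, so the reference predictive estimate is straightforward to compute. The sketch below uses fit_bin2b as the reference \(M_*\) and fit_bin2bnull as the candidate \(M_k\), summing over the observed covariate values (the actual simulation script may differ in details):
# Reference predictive estimate for binary data (sketch): the integral
# over future y reduces to a sum over y in {0, 1} at each covariate value
p_ref  <- colMeans(posterior_epred(fit_bin2b))      # P(y = 1 | D, M*)
p_cand <- colMeans(posterior_epred(fit_bin2bnull))  # P(y = 1 | D, M_k)
(elpd_ref <- sum(p_ref * log(p_cand) + (1 - p_ref) * log(1 - p_cand)))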
The next figure shows the results from the same simulation study using a reference predictive approach with the fit_bin2 model used as the reference.
ggplot(refprobs_densities, aes(x = values, y = ind, height = scaled)) +
geom_density_ridges(stat = "identity", scale=0.6)
We can see better accuracy than for cross-validation. A similar improvement in model selection performance is observed in projection predictive variable selection (Piironen and Vehtari, 2017; Piironen, Paasiniemi and Vehtari, 2020), implemented in the projpred package.
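With only one candidate predictor the selection problem here is trivial, but as a minimal illustration the projpred workflow could look like this (a sketch, with arguments left at their defaults):
# Minimal projpred sketch: cross-validated variable selection
library(projpred)
vs <- cv_varsel(fit_bin2b)
plot(vs, stats = "elpd")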
As a comparison, we include a marginal likelihood based approach to compute the posterior probabilities of the null model (treatment effect is zero) and the model with an unknown treatment effect. As the data and models are very simple, we may assume that the models are well specified. Marginal likelihoods and the relative posterior probabilities can be sensitive to the prior chosen for the more complex model; here we simply use the same priors as in the examples above. Marginal likelihoods are computed using the default bridge sampling approach implemented in the bridgesampling package.
# rerun models with diagnostic file required by bridge_sampler
fit_bin2 <- stan_glm(y/N ~ grp2, family = binomial(), data = d_bin2,
weights = N, refresh=0,
diagnostic_file = file.path(tempdir(), "df.csv"))
(ml_bin2 <- bridge_sampler(fit_bin2, silent=TRUE))
Bridge sampling estimate of the log marginal likelihood: -11.47253
Estimate obtained in 4 iteration(s) via method "normal".
fit_bin2null <- stan_glm(y/N ~ 1, family = binomial(), data = d_bin2,
weights = N, refresh=0,
diagnostic_file = file.path(tempdir(), "df.csv"))
(ml_bin2null <- bridge_sampler(fit_bin2null, silent=TRUE))
Bridge sampling estimate of the log marginal likelihood: -11.46279
Estimate obtained in 4 iteration(s) via method "normal".
print(post_prob(ml_bin2, ml_bin2null), digits=2)
ml_bin2 ml_bin2null
0.5 0.5
The posterior probability computed from the marginal likelihoods is indecisive.
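The corresponding Bayes factor can be computed with the bf() function from the bridgesampling package:
# Bayes factor comparing the treatment-effect model to the null model
bf(ml_bin2, ml_bin2null)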
We repeat the simulation with the marginal likelihood approach.
ggplot(bfprobs_densities, aes(x = values, y = ind, height = scaled)) +
geom_density_ridges(stat = "identity", scale=0.6)
We can see that the marginal likelihood based approach favors the null model more strongly for smaller treatment effects, and requires a bigger effect than the other approaches before it stops favoring the null model, but given a big enough effect it is more decisive in favor of the non-null model than cross-validation.
sessionInfo()
R version 4.1.2 (2021-11-01)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.3 LTS
Matrix products: default
BLAS/LAPACK: /opt/OpenBLAS/lib/libopenblas_haswellp-r0.3.17.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C LC_TIME=fi_FI.UTF-8
[4] LC_COLLATE=en_US.UTF-8 LC_MONETARY=fi_FI.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=fi_FI.UTF-8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=fi_FI.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] splines stats graphics grDevices utils datasets methods base
other attached packages:
[1] bridgesampling_1.1-2 ggridges_0.5.3 rstan_2.21.3 StanHeaders_2.21.0-7
[5] arm_1.12-2 lme4_1.1-28 Matrix_1.4-0 reliabilitydiag_0.2.0
[9] MASS_7.3-55 corrplot_0.92 caret_6.0-90 lattice_0.20-45
[13] GGally_2.1.2 fivethirtyeight_0.6.2 projpred_2.0.2 bayesplot_1.8.1
[17] forcats_0.5.1 stringr_1.4.0 dplyr_1.0.8 purrr_0.3.4
[21] readr_2.1.2 tidyr_1.2.0 tibble_3.1.6 ggplot2_3.3.5
[25] tidyverse_1.3.1 loo_2.4.1 rstanarm_2.21.1 Rcpp_1.0.8
[29] rmarkdown_2.11 knitr_1.37
loaded via a namespace (and not attached):
[1] utf8_1.2.2 tidyselect_1.1.1 htmlwidgets_1.5.4 grid_4.1.2
[5] pROC_1.18.0 devtools_2.4.3 munsell_0.5.0 codetools_0.2-18
[9] DT_0.20 future_1.23.0 miniUI_0.1.1.1 withr_2.4.3
[13] Brobdingnag_1.2-7 colorspace_2.0-2 highr_0.9 rstudioapi_0.13
[17] stats4_4.1.2 listenv_0.8.0 labeling_0.4.2 farver_2.1.0
[21] rprojroot_2.0.2 coda_0.19-4 parallelly_1.30.0 vctrs_0.3.8
[25] generics_0.1.2 ipred_0.9-12 xfun_0.29 R6_2.5.1
[29] markdown_1.1 gamm4_0.2-6 cachem_1.0.6 reshape_0.8.8
[33] assertthat_0.2.1 promises_1.2.0.1 scales_1.1.1 nnet_7.3-17
[37] gtable_0.3.0 globals_0.14.0 processx_3.5.2 timeDate_3043.102
[41] rlang_1.0.1 ModelMetrics_1.2.2.2 broom_0.7.12 checkmate_2.0.0
[45] inline_0.3.19 yaml_2.2.2 reshape2_1.4.4 abind_1.4-5
[49] modelr_0.1.8 threejs_0.3.3 crosstalk_1.2.0 backports_1.4.1
[53] httpuv_1.6.5 rsconnect_0.8.25 tensorA_0.36.2 tools_4.1.2
[57] lava_1.6.10 usethis_2.1.5 ellipsis_0.3.2 jquerylib_0.1.4
[61] posterior_1.2.0 RColorBrewer_1.1-2 sessioninfo_1.2.2 plyr_1.8.6
[65] base64enc_0.1-3 progress_1.2.2 ps_1.6.0 prettyunits_1.1.1
[69] rpart_4.1.16 zoo_1.8-9 haven_2.4.3 fs_1.5.2
[73] magrittr_2.0.2 data.table_1.14.2 colourpicker_1.1.1 reprex_2.0.1
[77] mvtnorm_1.1-3 matrixStats_0.61.0 pkgload_1.2.4 hms_1.1.1
[81] shinyjs_2.1.0 mime_0.12 evaluate_0.14 xtable_1.8-4
[85] shinystan_2.5.0 readxl_1.3.1 gridExtra_2.3 rstantools_2.1.1
[89] testthat_3.1.2 compiler_4.1.2 crayon_1.4.2 minqa_1.2.4
[93] htmltools_0.5.2 mgcv_1.8-38 later_1.3.0 tzdb_0.2.0
[97] RcppParallel_5.1.5 lubridate_1.8.0 DBI_1.1.2 dbplyr_2.1.1
[101] boot_1.3-28 brio_1.1.3 cli_3.1.1 parallel_4.1.2
[105] gower_1.0.0 igraph_1.2.11 pkgconfig_2.0.3 recipes_0.1.17
[109] xml2_1.3.3 foreach_1.5.2 dygraphs_1.1.1.6 bslib_0.3.1
[113] prodlim_2019.11.13 rvest_1.0.2 distributional_0.3.0 callr_3.7.0
[117] digest_0.6.29 cellranger_1.1.0 shiny_1.7.1 gtools_3.9.2
[121] nloptr_1.2.2.3 lifecycle_1.0.1 nlme_3.1-155 jsonlite_1.7.3
[125] desc_1.4.0 fansi_1.0.2 pillar_1.7.0 fastmap_1.1.0
[129] httr_1.4.2 pkgbuild_1.3.1 survival_3.2-13 glue_1.6.1
[133] xts_0.12.1 remotes_2.4.2 shinythemes_1.2.0 iterators_1.0.14
[137] class_7.3-20 stringi_1.7.6 sass_0.4.0 memoise_2.0.1
[141] future.apply_1.8.1