Setup

Load packages

library(tidyr)
library(rstanarm)
library(loo)
library(ggplot2)
theme_set(bayesplot::theme_default())
library(ggridges)
library(bridgesampling)

1 Introduction

This notebook demonstrates a simple model we trust (no model misspecification). In this case, cross-validation is not needed, and we can get better accuracy using the explicit model.

2 Comparison of two groups with Binomial

An experiment was performed to estimate the effect of beta-blockers on mortality of cardiac patients (the example is from Gelman et al., 2013, Ch. 3). Patients were randomly assigned to treatment and control groups:

  • out of 674 patients receiving the control, 39 died
  • out of 680 receiving the treatment, 22 died

The data, where grp2 is a dummy variable that captures the difference between the intercepts of the first and the second group:

d_bin2 <- data.frame(N = c(674, 680), y = c(39,22), grp2 = c(0,1))

2.1 Analysis of the observed data

To analyse whether the treatment is useful, we can use a Binomial model for both groups and compute the odds ratio.

fit_bin2 <- stan_glm(y/N ~ grp2, family = binomial(), data = d_bin2,
                     weights = N, refresh=0)
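Giving a proportion `y/N` as the response with `weights = N` is one of the ways R's binomial family accepts aggregated data. The other is a two-column `cbind(successes, failures)` response. The following check uses base R `glm()` (maximum likelihood, not the Bayesian `stan_glm()` fit above) and shows that both forms specify the same likelihood. The name `d_agg` is introduced here only for illustration.

```r
# Aggregated binomial data from the beta-blocker example (same values as d_bin2)
d_agg <- data.frame(N = c(674, 680), y = c(39, 22), grp2 = c(0, 1))

# Proportion response, with the number of trials as prior weights
glm_prop <- glm(y / N ~ grp2, family = binomial(), data = d_agg, weights = N)
# Two-column (deaths, survivors) response
glm_cbind <- glm(cbind(y, N - y) ~ grp2, family = binomial(), data = d_agg)

# Both specifications give the same maximum likelihood estimates
rbind(proportion = coef(glm_prop), counts = coef(glm_cbind))
```

The model has two groups and two parameters, so it is saturated. The estimated `grp2` coefficient is therefore the empirical log odds ratio.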

In general we recommend showing the full posterior of the quantity of interest, which in this case is the odds ratio.

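# Posterior draws: alpha is the intercept, beta the grp2 coefficient
samples_bin2 <- rstan::extract(fit_bin2$stanfit)
# Death probabilities in the control (theta1) and treatment (theta2) groups
theta1 <- plogis(samples_bin2$alpha)
theta2 <- plogis(samples_bin2$alpha + samples_bin2$beta)
# Odds ratio of treatment versus control, one value per posterior draw
oddsratio <- (theta2/(1-theta2))/(theta1/(1-theta1))
# Histogram of the posterior draws of the odds ratio
ggplot() + geom_histogram(aes(oddsratio), bins = 50, fill = 'grey', color = 'darkgrey') +
  labs(y = '') + scale_y_continuous(breaks = NULL)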
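On the logit scale the two groups differ by exactly the grp2 coefficient: logit(theta2) − logit(theta1) = (alpha + beta) − alpha = beta. The odds ratio is therefore exp(beta) draw by draw. The sketch below checks this identity. It uses arbitrary normal draws as hypothetical stand-ins for the posterior draws, and distinct variable names so the notebook's `oddsratio` is not overwritten.

```r
set.seed(1)
# Hypothetical stand-ins for posterior draws (arbitrary normal draws, not from fit_bin2)
alpha_sim <- rnorm(4000, mean = -2.8, sd = 0.17)
beta_sim <- rnorm(4000, mean = -0.6, sd = 0.27)

theta1_sim <- plogis(alpha_sim)             # control-group death probability
theta2_sim <- plogis(alpha_sim + beta_sim)  # treatment-group death probability
or_sim <- (theta2_sim / (1 - theta2_sim)) / (theta1_sim / (1 - theta1_sim))

# The log odds differ by exactly beta, so the odds ratio equals exp(beta)
summary(or_sim - exp(beta_sim))
```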

The posterior probability that the odds ratio is less than 1:

print(mean(oddsratio<1),2)
[1] 0.99

This posterior distribution of the odds ratio (or some transformation of it) is the simplest and most accurate way to analyse the effectiveness of the treatment. In this case, there is a high probability that the treatment is effective and that the effect is relatively big.
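As a rough cross-check that does not need Stan, this sketch uses a different model: independent uniform Beta(1, 1) priors on the two death probabilities, as in the conjugate Beta-Binomial analysis. This prior is an assumption; `stan_glm()` above uses its default priors on the logit scale. The posterior is then Beta in each group, and Pr(odds ratio < 1) can be estimated by Monte Carlo.

```r
set.seed(2013)
n_draws <- 100000
# Conjugate Beta(1 + deaths, 1 + survivors) posteriors under assumed Beta(1, 1) priors
theta1_b <- rbeta(n_draws, 1 + 39, 1 + 674 - 39)  # control group
theta2_b <- rbeta(n_draws, 1 + 22, 1 + 680 - 22)  # treatment group
or_b <- (theta2_b / (1 - theta2_b)) / (theta1_b / (1 - theta1_b))

# Monte Carlo estimate of Pr(odds ratio < 1 | data) under this alternative prior
mean(or_b < 1)
```

With this much data the prior has little influence, so the estimate should land close to the value printed above. It is not expected to match it exactly.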

2.2 Simulation experiment

Although we recommend showing the full posterior, the probability that the odds ratio is less than 1 can be a useful summary. The simulation experiment binom_odds_comparison.R runs 100 simulations with simulated data with varying true odds ratios (0.1,…,1.0) and computes, for each run, the probability that oddsratio<1. The following figure shows the variation in the results.
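The script binom_odds_comparison.R is not shown in this notebook. The sketch below illustrates only the data-generating step, and how it is done is assumed. It holds the control death probability at the observed 39/674, converts the true odds ratio into a treatment-group probability, and draws binomial counts for the same group sizes. The helper `simulate_bin2()` and the grid step of 0.1 are both assumptions.

```r
# Hypothetical helper: simulate one aggregated data set for a given true odds ratio
simulate_bin2 <- function(oddsratio, N = c(674, 680), theta1 = 39 / 674) {
  odds2 <- oddsratio * theta1 / (1 - theta1)  # treatment-group odds
  theta2 <- odds2 / (1 + odds2)               # convert odds back to a probability
  data.frame(N = N,
             y = rbinom(2, size = N, prob = c(theta1, theta2)),
             grp2 = c(0, 1))
}

set.seed(42)
# One simulated data set per true odds ratio on an assumed grid 0.1, 0.2, ..., 1.0
sims <- lapply(seq(0.1, 1, by = 0.1), simulate_bin2)
sims[[1]]
```

Each simulated data frame has the same shape as d_bin2, so it could be passed to the same `stan_glm()` call.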

Variation in probability that oddsratio<1 when true oddsratio is varied.

load(file="binom_test_densities.RData")
ggplot(betaprobs_densities, aes(x = values, y = ind, height = scaled)) + 
  geom_density_ridges(stat = "identity", scale=0.6)
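With `stat = "identity"`, `geom_density_ridges()` draws the supplied `height` values instead of estimating densities itself. The structure of the saved data frames is not shown here. This sketch assumes, hypothetically, that each one holds a grid of `values`, a group label `ind` (the true odds ratio) and densities `scaled` to a maximum of 1 per group. It builds such a data frame from made-up per-run results.

```r
set.seed(3)
# Hypothetical per-run results: 100 probabilities for each of three true odds ratios
runs_sim <- data.frame(
  ind = factor(rep(c("0.2", "0.5", "1.0"), each = 100)),
  prob = plogis(rnorm(300, mean = rep(c(3, 1, 0), each = 100)))
)
# One kernel density per group on [0, 1], rescaled so each ridge peaks at height 1
dens_sim <- do.call(rbind, lapply(levels(runs_sim$ind), function(g) {
  d <- density(runs_sim$prob[runs_sim$ind == g], from = 0, to = 1)
  data.frame(values = d$x, ind = g, scaled = d$y / max(d$y))
}))
dens_sim$ind <- factor(dens_sim$ind, levels = levels(runs_sim$ind))
head(dens_sim)
```

A data frame of this shape can be passed to the same `ggplot(...) + geom_density_ridges(stat = "identity", scale = 0.6)` call.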

3 Cross-validation

Sometimes it is better to focus on the observable space (we can’t observe \(\theta\) or the odds ratio directly, but we can observe \(y\)). In leave-one-out cross-validation, the model is fitted \(n\) times. Each time one observation is left out of the fit and then used to evaluate predictive performance. This corresponds to using the already seen observations as pseudo Monte Carlo samples from the future data distribution, with the leave-one-out trick used to avoid double use of data. With the commonly used log score we get \[\mathrm{LOO} = \frac{1}{n} \sum_{i=1}^n \log p(y_i|x_i,D_{-i},M_k),\] where \(D_{-i}\) denotes the data without the \(i\)th observation and \(M_k\) is the model being evaluated. Cross-validation is useful when we don’t trust any model (the models might include good enough models, but we just don’t know if that is the case).
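With a conjugate Beta prior on a Bernoulli probability, each of the \(n\) refits has a closed-form posterior. Exact LOO can then be computed both by brute force and in closed form. The sketch below uses hypothetical toy data and an assumed Beta(1, 1) prior, not the rstanarm model. It also shows the difference between the average in the LOO formula above and the sum that `loo()` reports as elpd_loo.

```r
# Hypothetical toy binary data and an assumed Beta(a, b) prior
y_toy <- c(1, 0, 0, 1, 0, 0, 0, 1, 0, 0)
a_prior <- 1
b_prior <- 1

# Brute force: one "refit" per left-out observation; the fit without
# observation i is the Beta(a + sum(y[-i]), b + sum(1 - y[-i])) posterior
loo_brute <- sapply(seq_along(y_toy), function(i) {
  a_post <- a_prior + sum(y_toy[-i])
  b_post <- b_prior + sum(1 - y_toy[-i])
  p1 <- a_post / (a_post + b_post)  # posterior predictive Pr(y_i = 1)
  dbinom(y_toy[i], size = 1, prob = p1, log = TRUE)
})

# The same quantity in closed form, without the loop
n_toy <- length(y_toy)
p1_closed <- (a_prior + sum(y_toy) - y_toy) / (a_prior + b_prior + n_toy - 1)
loo_closed <- dbinom(y_toy, size = 1, prob = p1_closed, log = TRUE)

mean(loo_brute)  # LOO as in the formula above (average over observations)
sum(loo_brute)   # the sum, which is what loo() reports as elpd_loo
```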

Next we demonstrate one of the weaknesses of cross-validation (the same holds for WAIC etc.).

3.1 Analysis of the observed data

To use leave-one-out where “one” refers to an individual patient, we need to change the model formulation a bit. In the above model formulation, the individual observations have been aggregated into group observations, and running loo(fit_bin2) would try to leave out a whole group at a time. With more groups this could be what we want, but with just two groups it is unlikely. Thus, in the following we switch to a Bernoulli model with each individual as its own observation.

# One row per patient: y = 1 if the patient died; grp2 = 0 for control, 1 for treatment
d_bin2b <- data.frame(y = c(rep(1, 39), rep(0, 674 - 39), rep(1, 22), rep(0, 680 - 22)),
                      grp2 = c(rep(0, 674), rep(1, 680)))
fit_bin2b <- stan_glm(y ~ grp2, family = binomial(), data = d_bin2b, seed = 180202538, refresh = 0)
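Expanding the counts to one row per patient changes only what `loo()` treats as a single observation. The likelihood stays the same. A base R `glm()` check (maximum likelihood, not the Bayesian fits used here) shows that the Bernoulli and aggregated binomial formulations give the same estimates. Their log-likelihoods differ only by the constant sum of log binomial coefficients. The names `d_indiv` and `d_agg` are illustrative.

```r
# Individual-level (Bernoulli) data, same construction as d_bin2b
d_indiv <- data.frame(y = c(rep(1, 39), rep(0, 674 - 39), rep(1, 22), rep(0, 680 - 22)),
                      grp2 = c(rep(0, 674), rep(1, 680)))
# Aggregated (binomial) data, same values as d_bin2
d_agg <- data.frame(N = c(674, 680), y = c(39, 22), grp2 = c(0, 1))

glm_bern <- glm(y ~ grp2, family = binomial(), data = d_indiv)
glm_binom <- glm(cbind(y, N - y) ~ grp2, family = binomial(), data = d_agg)

# Same estimates; only the unit that leave-one-out would remove differs
rbind(bernoulli = coef(glm_bern), binomial = coef(glm_binom))
# Log-likelihoods differ only by the constant sum(log choose(N, y))
logLik(glm_bern)[1] - (logLik(glm_binom)[1] - sum(lchoose(d_agg$N, d_agg$y)))
```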

We also fit a “null” model, which doesn’t use the group variable and thus has a common parameter for both groups.

fit_bin2bnull <- stan_glm(y ~ 1, family = binomial(), data = d_bin2b, seed=180202538, refresh=0)

We can then use cross-validation to assess whether adding the treatment variable improves predictive performance. We use fast Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO; Vehtari, Gelman and Gabry, 2017).

(loo_bin2 <- loo(fit_bin2b))

Computed from 4000 by 1354 log-likelihood matrix

         Estimate   SE
elpd_loo   -248.1 23.3
p_loo         2.0  0.2
looic       496.1 46.6
------
Monte Carlo SE of elpd_loo is 0.0.

All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.

(loo_bin2null <- loo(fit_bin2bnull))

Computed from 4000 by 1354 log-likelihood matrix

         Estimate   SE
elpd_loo   -249.7 23.4
p_loo         1.0  0.1
looic       499.4 46.7
------
Monte Carlo SE of elpd_loo is 0.0.

All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.

All Pareto \(k\) estimates are below 0.5, so we can trust the PSIS-LOO results (Vehtari, Gelman and Gabry, 2017; Vehtari et al., 2019).

Let’s make a pairwise comparison.

loo_compare(loo_bin2null, loo_bin2)
              elpd_diff se_diff
fit_bin2b      0.0       0.0   
fit_bin2bnull -1.6       2.3   

elpd_diff is small compared to its standard error se_diff. Cross-validation is therefore uncertain whether estimating the treatment effect improves predictive performance. To put this in perspective, we have \(N_1=674\) and \(N_2=680\) patients, with 5.8% and 3.2% deaths. This information is too weak for cross-validation to be decisive.
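The sketch below reproduces this comparison without Stan to show where the two numbers come from. It uses the closed-form exact LOO for binary data under an assumed Beta(1, 1) prior, a conjugate stand-in for the rstanarm models. The elpd difference is the sum of the pointwise differences between the separate-groups model and the null model. Its standard error is \(\sqrt{n}\) times the standard deviation of the pointwise differences, the form used for se_diff in `loo_compare()`. The helper `loo_pointwise()` is illustrative.

```r
# Exact pointwise LOO log predictive densities for binary data under an
# assumed Beta(a, b) prior (a stand-in for the rstanarm priors)
loo_pointwise <- function(y, a = 1, b = 1) {
  n <- length(y)
  p1 <- (a + sum(y) - y) / (a + b + n - 1)  # Pr(y_i = 1 | data without i)
  dbinom(y, size = 1, prob = p1, log = TRUE)
}

y_indiv <- c(rep(1, 39), rep(0, 674 - 39), rep(1, 22), rep(0, 680 - 22))
grp_indiv <- c(rep(0, 674), rep(1, 680))

# Separate death probability per group vs one common probability ("null")
elpd_sep <- ave(y_indiv, grp_indiv, FUN = loo_pointwise)
elpd_null <- loo_pointwise(y_indiv)

diffs <- elpd_sep - elpd_null
elpd_diff <- sum(diffs)
se_diff <- sqrt(length(diffs)) * sd(diffs)
c(elpd_diff = elpd_diff, se_diff = se_diff, ratio = elpd_diff / se_diff)
```

With this much data the prior has little influence, so these values should come out close to the `loo_compare()` output above. The difference is smaller than its standard error.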

3.2 Simulation experiment

The simulation experiment binom_odds_comparison.R runs 100 simulations with simulated data with varying true odds ratios (0.1,…,1.0) and computes the LOO comparison for each run.

Variation in LOO comparison when true oddsratio is varied.

ggplot(looprobs_densities, aes(x = values, y = ind, height = scaled)) + 
  geom_density_ridges(stat = "identity", scale=0.6)

We see that using the posterior distribution from the model is more efficient at detecting the effect, but cross-validation eventually detects it too. The difference arises because cross-validation doesn’t trust the model. Instead, it compares the model predictions to the “future data” using a very weak assumption about the future. This weak assumption is also the strength of cross-validation, as we’ll see in another notebook.

4 Reference predictive approach

We can also estimate predictive performance using a stronger assumption about the future. A reference predictive estimate with the log score can be computed as \[\mathrm{elpd}_{\mathrm{ref}} = \int p(\tilde{y}|D,M_*) \log p(\tilde{y}|D,M_k)\, d\tilde{y},\] where \(M_*\) is a reference model we trust. Using a reference model to assess the other models corresponds to the \(M\)-completed case (Vehtari and Ojanen, 2012), where the true model is replaced with a model we trust to be close enough to the true model.
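For a binary outcome the integral reduces to a two-term sum over \(\tilde{y} \in \{0, 1\}\). The sketch below illustrates this under assumptions; it is not the computation used in the simulation study. It takes posterior predictive death probabilities from conjugate Beta(1, 1) models as stand-ins for the logistic regressions and uses the separate-groups model as \(M_*\). It then sums \(\mathrm{elpd}_{\mathrm{ref}}\) over the patients at each observed group value. The reference model's advantage over the null model is then an \(N\)-weighted Kullback–Leibler divergence, which is never negative. The helper `elpd_ref_binary()` is illustrative.

```r
n_pat <- c(674, 680)  # patients per group
deaths <- c(39, 22)   # deaths per group

# Posterior predictive Pr(death) under assumed Beta(1, 1) priors
p_ref <- (deaths + 1) / (n_pat + 2)                         # M_*: separate groups
p_null <- rep((sum(deaths) + 1) / (sum(n_pat) + 2), 2)      # M_k: common probability

# elpd_ref for binary data: sum over y-tilde in {0, 1}, weighted by the
# reference predictive, summed over the n_pat patients in each group
elpd_ref_binary <- function(p_k, p_star = p_ref, n = n_pat) {
  sum(n * (p_star * log(p_k) + (1 - p_star) * log(1 - p_k)))
}

elpd_ref_binary(p_ref) - elpd_ref_binary(p_null)
```

Unlike the LOO estimate, this difference does not depend on which observations happen to be left out. All of its information about the future comes from the reference model.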

4.1 Simulation experiment

The next figure shows the results from the same simulation study using a reference predictive approach with the fit_bin2 model used as the reference.

ggplot(refprobs_densities, aes(x = values, y = ind, height = scaled)) + 
  geom_density_ridges(stat = "identity", scale=0.6)

We can see that the reference predictive approach is more accurate than cross-validation. A similar improvement in model selection performance is observed in projection predictive variable selection (Piironen and Vehtari, 2017; Piironen, Paasiniemi and Vehtari, 2020), implemented in the projpred package.

5 Marginal likelihood

For comparison, we include a marginal likelihood based approach to compute the posterior probabilities for the null model (treatment effect is zero) and the model with an unknown treatment effect. As the data and models are very simple, we may assume that the model is well specified. Marginal likelihoods and relative posterior probabilities can be sensitive to the selected prior on the more complex model. Here we simply use the same prior as in the above examples. Marginal likelihoods are computed using the default bridge sampling approach implemented in the bridgesampling package.

5.1 Analysis of the observed data

# rerun models with diagnostic file required by bridge_sampler
fit_bin2 <- stan_glm(y/N ~ grp2, family = binomial(), data = d_bin2,
                     weights = N, refresh=0,
                     diagnostic_file = file.path(tempdir(), "df.csv"))
(ml_bin2 <- bridge_sampler(fit_bin2, silent=TRUE))
Bridge sampling estimate of the log marginal likelihood: -11.47253
Estimate obtained in 4 iteration(s) via method "normal".
fit_bin2null <- stan_glm(y/N ~ 1, family = binomial(), data = d_bin2,
                         weights = N, refresh=0,
                         diagnostic_file = file.path(tempdir(), "df.csv"))
(ml_bin2null <- bridge_sampler(fit_bin2null, silent=TRUE))
Bridge sampling estimate of the log marginal likelihood: -11.46279
Estimate obtained in 4 iteration(s) via method "normal".
print(post_prob(ml_bin2, ml_bin2null), digits=2)
    ml_bin2 ml_bin2null 
        0.5         0.5 

The posterior probabilities computed from the marginal likelihoods are indecisive: both models get probability 0.5.
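For these two simple models the marginal likelihood also has a closed form if the logistic regression priors are replaced by uniform Beta(1, 1) priors on the death probabilities. That prior is an assumption made only for this sketch. Because it differs from the prior used by `stan_glm()` above, the posterior model probabilities need not match the bridge sampling result. This is the prior sensitivity noted at the start of this section. Equal prior probabilities for the two models are also assumed.

```r
n_pat <- c(674, 680)  # patients per group
deaths <- c(39, 22)   # deaths per group

# log p(y | M): separate death probability per group, each with a Beta(1, 1) prior
log_ml_sep <- sum(lchoose(n_pat, deaths) + lbeta(deaths + 1, n_pat - deaths + 1))
# log p(y | M_null): one common death probability with a Beta(1, 1) prior
log_ml_null <- sum(lchoose(n_pat, deaths)) + lbeta(sum(deaths) + 1, sum(n_pat - deaths) + 1)

# Posterior model probabilities, normalised on the log scale for stability
log_ml <- c(sep = log_ml_sep, null = log_ml_null)
post_ml <- exp(log_ml - max(log_ml)) / sum(exp(log_ml - max(log_ml)))
round(post_ml, 2)
```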

5.2 Simulation experiment

We repeat the simulation with the marginal likelihood approach.

ggplot(bfprobs_densities, aes(x = values, y = ind, height = scaled)) + 
  geom_density_ridges(stat = "identity", scale=0.6)

We can see that the marginal likelihood based approach favors the null model more strongly for smaller treatment effects. It requires a bigger effect than the other approaches before it stops favoring the null model. Given a big enough effect, however, it is more decisive in favor of the non-null model than cross-validation.


References

Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A. and Rubin, D. B. (2013) Bayesian data analysis, third edition. CRC Press.
Piironen, J., Paasiniemi, M. and Vehtari, A. (2020) ‘Projective inference in high-dimensional problems: Prediction and feature selection’, Electronic Journal of Statistics, 14(1), pp. 2155–2197.
Piironen, J. and Vehtari, A. (2017) ‘Comparison of Bayesian predictive methods for model selection’, Statistics and Computing, 27(3), pp. 711–735. doi: 10.1007/s11222-016-9649-y.
Vehtari, A., Gelman, A. and Gabry, J. (2017) ‘Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC’, Statistics and Computing, 27(5), pp. 1413–1432. doi: 10.1007/s11222-016-9696-4.
Vehtari, A. and Ojanen, J. (2012) ‘A survey of Bayesian predictive methods for model assessment, selection and comparison’, Statistics Surveys, 6, pp. 142–228. doi: 10.1214/12-SS102.
Vehtari, A., Simpson, D., Gelman, A., Yao, Y. and Gabry, J. (2019) ‘Pareto smoothed importance sampling’, arXiv preprint arXiv:1507.02646. Available at: https://arxiv.org/abs/1507.02646v6.

Licenses

  • Code © 2018-2019, Aki Vehtari, licensed under BSD-3.
  • Text © 2018-2019, Aki Vehtari, licensed under CC-BY-NC 4.0.
  • Part of the code copied from rstanarm_demo.Rmd written by Aki Vehtari and Markus Paasiniemi

Original Computing Environment

sessionInfo()
R version 4.1.2 (2021-11-01)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.3 LTS

Matrix products: default
BLAS/LAPACK: /opt/OpenBLAS/lib/libopenblas_haswellp-r0.3.17.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=fi_FI.UTF-8       
 [4] LC_COLLATE=en_US.UTF-8     LC_MONETARY=fi_FI.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=fi_FI.UTF-8       LC_NAME=C                  LC_ADDRESS=C              
[10] LC_TELEPHONE=C             LC_MEASUREMENT=fi_FI.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] splines   stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] bridgesampling_1.1-2  ggridges_0.5.3        rstan_2.21.3          StanHeaders_2.21.0-7 
 [5] arm_1.12-2            lme4_1.1-28           Matrix_1.4-0          reliabilitydiag_0.2.0
 [9] MASS_7.3-55           corrplot_0.92         caret_6.0-90          lattice_0.20-45      
[13] GGally_2.1.2          fivethirtyeight_0.6.2 projpred_2.0.2        bayesplot_1.8.1      
[17] forcats_0.5.1         stringr_1.4.0         dplyr_1.0.8           purrr_0.3.4          
[21] readr_2.1.2           tidyr_1.2.0           tibble_3.1.6          ggplot2_3.3.5        
[25] tidyverse_1.3.1       loo_2.4.1             rstanarm_2.21.1       Rcpp_1.0.8           
[29] rmarkdown_2.11        knitr_1.37           

loaded via a namespace (and not attached):
  [1] utf8_1.2.2           tidyselect_1.1.1     htmlwidgets_1.5.4    grid_4.1.2          
  [5] pROC_1.18.0          devtools_2.4.3       munsell_0.5.0        codetools_0.2-18    
  [9] DT_0.20              future_1.23.0        miniUI_0.1.1.1       withr_2.4.3         
 [13] Brobdingnag_1.2-7    colorspace_2.0-2     highr_0.9            rstudioapi_0.13     
 [17] stats4_4.1.2         listenv_0.8.0        labeling_0.4.2       farver_2.1.0        
 [21] rprojroot_2.0.2      coda_0.19-4          parallelly_1.30.0    vctrs_0.3.8         
 [25] generics_0.1.2       ipred_0.9-12         xfun_0.29            R6_2.5.1            
 [29] markdown_1.1         gamm4_0.2-6          cachem_1.0.6         reshape_0.8.8       
 [33] assertthat_0.2.1     promises_1.2.0.1     scales_1.1.1         nnet_7.3-17         
 [37] gtable_0.3.0         globals_0.14.0       processx_3.5.2       timeDate_3043.102   
 [41] rlang_1.0.1          ModelMetrics_1.2.2.2 broom_0.7.12         checkmate_2.0.0     
 [45] inline_0.3.19        yaml_2.2.2           reshape2_1.4.4       abind_1.4-5         
 [49] modelr_0.1.8         threejs_0.3.3        crosstalk_1.2.0      backports_1.4.1     
 [53] httpuv_1.6.5         rsconnect_0.8.25     tensorA_0.36.2       tools_4.1.2         
 [57] lava_1.6.10          usethis_2.1.5        ellipsis_0.3.2       jquerylib_0.1.4     
 [61] posterior_1.2.0      RColorBrewer_1.1-2   sessioninfo_1.2.2    plyr_1.8.6          
 [65] base64enc_0.1-3      progress_1.2.2       ps_1.6.0             prettyunits_1.1.1   
 [69] rpart_4.1.16         zoo_1.8-9            haven_2.4.3          fs_1.5.2            
 [73] magrittr_2.0.2       data.table_1.14.2    colourpicker_1.1.1   reprex_2.0.1        
 [77] mvtnorm_1.1-3        matrixStats_0.61.0   pkgload_1.2.4        hms_1.1.1           
 [81] shinyjs_2.1.0        mime_0.12            evaluate_0.14        xtable_1.8-4        
 [85] shinystan_2.5.0      readxl_1.3.1         gridExtra_2.3        rstantools_2.1.1    
 [89] testthat_3.1.2       compiler_4.1.2       crayon_1.4.2         minqa_1.2.4         
 [93] htmltools_0.5.2      mgcv_1.8-38          later_1.3.0          tzdb_0.2.0          
 [97] RcppParallel_5.1.5   lubridate_1.8.0      DBI_1.1.2            dbplyr_2.1.1        
[101] boot_1.3-28          brio_1.1.3           cli_3.1.1            parallel_4.1.2      
[105] gower_1.0.0          igraph_1.2.11        pkgconfig_2.0.3      recipes_0.1.17      
[109] xml2_1.3.3           foreach_1.5.2        dygraphs_1.1.1.6     bslib_0.3.1         
[113] prodlim_2019.11.13   rvest_1.0.2          distributional_0.3.0 callr_3.7.0         
[117] digest_0.6.29        cellranger_1.1.0     shiny_1.7.1          gtools_3.9.2        
[121] nloptr_1.2.2.3       lifecycle_1.0.1      nlme_3.1-155         jsonlite_1.7.3      
[125] desc_1.4.0           fansi_1.0.2          pillar_1.7.0         fastmap_1.1.0       
[129] httr_1.4.2           pkgbuild_1.3.1       survival_3.2-13      glue_1.6.1          
[133] xts_0.12.1           remotes_2.4.2        shinythemes_1.2.0    iterators_1.0.14    
[137] class_7.3-20         stringi_1.7.6        sass_0.4.0           memoise_2.0.1       
[141] future.apply_1.8.1