library(rprojroot)
root <- has_file(".Bayesian-Workflow-root")$make_fix_file()
library(cmdstanr)
library(posterior)
options(pillar.neg = FALSE,
pillar.subtle = FALSE,
pillar.sigfig = 2)
options(width = 90)
library(tidyr)
library(dplyr)
library(ggplot2)
library(bayesplot)
library(RColorBrewer)
theme_set(bayesplot::theme_default(base_family = "sans", base_size = 14))
library(latex2exp)
library(patchwork)
set1 <- RColorBrewer::brewer.pal(7, "Set1")
SEED <- 48927 # set random seed for reproducibility
print_stan_file <- function(file) {
code <- readLines(file)
if (isTRUE(getOption("knitr.in.progress")) &&
identical(knitr::opts_current$get("results"), "asis")) {
# In render: emit as-is so Pandoc/Quarto does syntax highlighting
block <- paste0("```stan", "\n", paste(code, collapse = "\n"), "\n", "```")
knitr::asis_output(block)
} else {
writeLines(code)
}
}

Illustration of simple problematic posteriors
This notebook includes the code for the Bayesian Workflow book, Section 12.3, Failure modes and steps forward.
1 Introduction
This case study uses simple examples to demonstrate the most common failure modes in Markov chain Monte Carlo (MCMC) based Bayesian inference, how to recognize them using diagnostics, and how to fix the problems.
2 Improper posterior
An unbounded likelihood without a proper prior can lead to an improper posterior. We recommend always using proper priors (a proper distribution integrates to a finite value) to guarantee a proper posterior.
A commonly used model that can have an unbounded likelihood is logistic regression with complete separation in the data.
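To see the unboundedness concretely, here is a minimal self-contained sketch (the data are illustrative, not the notebook's): with completely separated classes, the Bernoulli log-likelihood increases monotonically toward 0 as the coefficient grows, so it has no maximum.

```r
# Illustrative separable data: negative x in class 0, positive x in class 1
x <- c(-2, -1, 1, 2)
y <- c(0, 0, 1, 1)
# Log-likelihood of a logistic model with slope beta (no intercept)
loglik <- function(beta) sum(dbinom(y, 1, plogis(beta * x), log = TRUE))
sapply(c(1, 10, 100), loglik)  # strictly increasing toward 0
```

Adding a proper prior on the coefficient caps this growth, which is why the fixed model later in this section behaves well.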
2.1 Data
We have a univariate continuous predictor x and a binary target y; the two classes are completely separable, which leads to an unbounded likelihood.
set.seed(SEED + 4)
M <- 1
N <- 10
x <- matrix(sort(rnorm(N)), ncol = M)
y <- rep(c(0, 1), each = N / 2)
data_logit <- list(M = M, N = N, x = x, y = y)

data.frame(data_logit) |>
ggplot(aes(x, y)) +
geom_point(size = 3, shape = 1, alpha = 0.6) +
scale_y_continuous(breaks = c(0, 1))

2.2 Model
We use the following Stan logistic regression model, where we have "forgotten" to include a prior for the coefficient beta.
code_logit <- root("problems", "logit_glm.stan")
print_stan_file(code_logit)

// logistic regression
data {
int<lower=0> N;
int<lower=0> M;
array[N] int<lower=0,upper=1> y;
matrix[N,M] x;
}
parameters {
real alpha;
vector[M] beta;
}
model {
y ~ bernoulli_logit_glm(x, alpha, beta);
}

Sample
mod_logit <- cmdstan_model(stan_file = code_logit)
fit_logit <- mod_logit$sample(data = data_logit, seed = SEED, refresh = 0)

2.3 Convergence diagnostics
When running Stan, we get warnings. We can also explicitly check the inference diagnostics:
fit_logit$diagnostic_summary()

$num_divergent
[1] 409 424 408 317
$num_max_treedepth
[1] 591 576 592 683
$ebfmi
[1] 1.819008 1.975546 1.972463 2.061800
We can also check the \widehat{R} and effective sample size (ESS) diagnostics (Vehtari et al. 2021):
draws <- as_draws_rvars(fit_logit$draws())
summarize_draws(draws)

# A tibble: 3 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -5.6e-308 -5.6e-308 0 0 -5.6e-308 -5.6e-308 NA NA NA
2 alpha 1.8e+ 45 2.8e+ 30 3.8e45 4.1e30 1.1e+ 25 1.1e+ 46 3.1 4.6 16.
3 beta 4.2e+ 45 6.8e+ 30 8.9e45 1.0e31 2.8e+ 25 2.5e+ 46 3.1 4.6 15.
We see that \widehat{R} for both alpha and beta is about 3 and Bulk-ESS is about 4, which indicates that the chains are not mixing at all.
The above diagnostics refer to documentation (https://mc-stan.org/misc/warnings) that mentions the possibility of adjusting the sampling algorithm options (e.g., increasing adapt_delta and max_treedepth), but it is better to first investigate the posterior.
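For reference, these sampler options are arguments of cmdstanr's $sample() method; the call below is only a sketch with illustrative values, and as noted it would not help with this model.

```r
# Sketch only: smaller step size (higher adapt_delta) and deeper trees.
# This slows sampling and cannot fix an improper posterior.
fit_logit_adj <- mod_logit$sample(data = data_logit, seed = SEED, refresh = 0,
                                  adapt_delta = 0.99, max_treedepth = 12)
```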
The following figure shows the posterior draws as marginal histograms and joint scatterplots. The range of the values is huge, which is typical for an improper posterior, but in any practical application the values of alpha and beta are likely to have much smaller magnitude. In this case, increasing adapt_delta and max_treedepth would not have solved the problem and would only have wasted modeler and computation time.
mcmc_pairs(as_draws_array(draws), pars = c("alpha", "beta"),
off_diag_args = list(alpha = 0.2))

2.4 Stan compiler pedantic check
The above diagnostics are applicable with any probabilistic programming framework. The Stan compiler can also recognize some common problems. By default the pedantic mode is not enabled, but we can use the option pedantic = TRUE at compilation time, or after compilation with the check_syntax method.
mod_logit$check_syntax(pedantic = TRUE)

The pedantic check correctly warns that alpha and beta don’t have priors.
2.5 A fixed model with proper priors
We add proper weak priors and rerun inference.
code_logit2 <- root("problems", "logit_glm2.stan")
print_stan_file(code_logit2)

// logistic regression
data {
int<lower=0> N;
int<lower=0> M;
array[N] int<lower=0,upper=1> y;
matrix[N,M] x;
}
parameters {
real alpha;
vector[M] beta;
}
model {
alpha ~ normal(0,10);
beta ~ normal(0,10);
y ~ bernoulli_logit_glm(x, alpha, beta);
}

Sample
mod_logit2 <- cmdstan_model(stan_file = code_logit2)
fit_logit2 <- mod_logit2$sample(data = data_logit, seed = SEED, refresh = 0)

2.6 Convergence diagnostics
There were no convergence warnings. We can also explicitly check the inference diagnostics:
fit_logit2$diagnostic_summary()

$num_divergent
[1] 0 0 0 0
$num_max_treedepth
[1] 0 0 0 0
$ebfmi
[1] 1.088640 1.205378 1.011549 1.004046
We check the \widehat{R} and ESS values, which in this case all look good.
draws <- as_draws_rvars(fit_logit2$draws())
summarize_draws(draws)

# A tibble: 3 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -2.6 -2.3 1.0 0.79 -4.7 -1.6 1.00 1231. 1462.
2 alpha 6.1 5.8 2.9 2.9 1.9 11. 1.0 862. 907.
3 beta 15. 14. 6.1 6.2 5.3 26. 1.0 839. 979.
The following figure shows the more reasonable marginal histograms and joint scatterplots of the posterior sample.
mcmc_pairs(as_draws_array(draws), pars = c("alpha", "beta"),
off_diag_args = list(alpha = 0.2))

3 A model with unused parameter
When writing and editing models, a common mistake is to declare a parameter but not use it in the model. If the parameter is not used at all, it doesn’t have a proper prior, and the likelihood doesn’t provide information about it, so the posterior along that parameter is improper. We use the previous logistic regression model with proper priors on alpha and beta, but include an extra parameter declaration, real gamma.
3.1 Model
code_logit3 <- root("problems", "logit_glm3.stan")
print_stan_file(code_logit3)

// logistic regression
data {
int<lower=0> N;
int<lower=0> M;
array[N] int<lower=0,upper=1> y;
matrix[N,M] x;
}
parameters {
real alpha;
vector[M] beta;
real gamma; // intentionally not used in the model
}
model {
alpha ~ normal(0,1);
beta ~ normal(0,1);
y ~ bernoulli_logit_glm(x, alpha, beta);
}

Sample
mod_logit3 <- cmdstan_model(stan_file = code_logit3)
fit_logit3 <- mod_logit3$sample(data = data_logit, seed = SEED, refresh = 0)

3.2 Convergence diagnostics
There is a sampler warning. We can also explicitly call the inference diagnostics:
fit_logit3$diagnostic_summary()

$num_divergent
[1] 0 0 0 0
$num_max_treedepth
[1] 432 409 391 435
$ebfmi
[1] 1.200907 1.383493 1.207598 1.316549
Instead of increasing max_treedepth, we check the other convergence diagnostics.
draws <- as_draws_rvars(fit_logit3$draws())
summarize_draws(draws)

# A tibble: 4 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -6.6e+ 0 -6.3e+ 0 1.0e+ 0 7.2e- 1 -8.6e+ 0 -5.6e 0 1.0 1951. 2393.
2 alpha 3.4e- 1 3.3e- 1 6.1e- 1 6.2e- 1 -6.4e- 1 1.3e 0 1.0 3532. 2768.
3 beta 1.3e+ 0 1.2e+ 0 7.6e- 1 7.7e- 1 3.7e- 2 2.6e 0 1.0 3213. 2971.
4 gamma -6.0e+19 1.4e+18 1.3e+20 4.9e+19 -3.5e+20 4.4e19 2.4 5.3 16.
\widehat{R}, Bulk-ESS, and Tail-ESS look good for alpha and beta, but really bad for gamma, clearly pointing to where to look for problems in the model code. The histogram of the gamma posterior draws shows values of huge magnitude (larger than 10^{20}), indicating an improper posterior.
mcmc_pairs(as_draws_array(draws), pars = c("alpha", "beta", "gamma"),
off_diag_args = list(alpha = 0.2))

Non-mixing is well diagnosed by \widehat{R} and ESS, but the following figure shows one of the rare cases where trace plots are useful for illustrating the type of non-mixing, here an improper uniform posterior for one of the parameters.
mcmc_trace(as_draws_array(draws), pars = c("gamma"))

3.3 Stan compiler pedantic check
The Stan compiler pedantic check also recognizes that the parameter gamma was declared but not used in the density calculation.
mod_logit3$check_syntax(pedantic = TRUE)

4 A posterior with two parameters competing
Sometimes models have two or more parameters with similar or exactly the same role. We illustrate this by adding an extra column to the previous data matrix. Sometimes the data matrix is augmented with a column of 1’s to represent the intercept. In this case that is redundant, as our model has the explicit intercept term alpha, and this redundancy will lead to problems.
4.1 Data
M <- 2
N <- 1000
x <- matrix(c(rep(1, N), sort(rnorm(N))), ncol = M)
y <- ((x[, 1] + rnorm(N) / 2) > 0) + 0
data_logit4 <- list(M = M, N = N, x = x, y = y)

4.2 Model
We use the previous logistic regression model with proper priors (and no extra gamma).
code_logit2 <- root("problems", "logit_glm2.stan")
print_stan_file(code_logit2)

// logistic regression
data {
int<lower=0> N;
int<lower=0> M;
array[N] int<lower=0,upper=1> y;
matrix[N,M] x;
}
parameters {
real alpha;
vector[M] beta;
}
model {
alpha ~ normal(0,10);
beta ~ normal(0,10);
y ~ bernoulli_logit_glm(x, alpha, beta);
}

Sample
mod_logit4 <- cmdstan_model(stan_file = code_logit2)
fit_logit4 <- mod_logit4$sample(data = data_logit4, seed = SEED, refresh = 0)

The Stan sampling time per chain with the original data matrix was less than 0.1s. Now the sampling time is several seconds per chain, which is suspicious. There are no automatic convergence diagnostic warnings, and checking the other diagnostics doesn’t show anything really bad.
4.3 Convergence diagnostics
fit_logit4$diagnostic_summary()

$num_divergent
[1] 0 0 0 0
$num_max_treedepth
[1] 0 0 0 0
$ebfmi
[1] 1.0913332 1.0802625 1.0307895 0.9725048
draws <- as_draws_rvars(fit_logit4$draws())
summarize_draws(draws)

# A tibble: 4 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -79. -79. 1.3 1.0 -82. -78. 1.0 1243. 1866.
2 alpha 2.3 2.3 7.4 7.5 -9.5 15. 1.0 1014. 867.
3 beta[1] 2.0 2.0 7.4 7.6 -10. 14. 1.0 1013. 883.
4 beta[2] 0.24 0.25 0.26 0.26 -0.17 0.65 1.0 1586. 1614.
ESS estimates are above the recommended diagnostic thresholds (Vehtari et al. 2021), but lower than we would in general expect from Stan for such a low-dimensional problem.
The following figure shows marginal histograms and joint scatterplots, and we can see that alpha and beta[1] are highly correlated.
mcmc_pairs(as_draws_array(draws), pars = c("alpha", "beta[1]", "beta[2]"),
off_diag_args = list(alpha = 0.2))

We can compute the correlation.
cor(as_draws_matrix(draws)[, c("alpha", "beta[1]")])[1, 2]

[1] -0.9992898
The numerical value of the correlation is -0.999. A correlation close to 1 can also arise for other reasons (see the next example), but one possibility is that two parameters have a similar role in the model. Here the reason is the constant column in x, which we put there for demonstration purposes. We may have a constant column, for example, if the predictor matrix is augmented with an intercept predictor, or if the observed data or the subdata used in a specific analysis happens to have only one unique value.
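A quick data-level check can catch such redundancy before sampling. A base-R sketch, using a small stand-in matrix since the check is generic:

```r
# Columns with a single unique value duplicate the intercept's role
x_demo <- cbind(const = rep(1, 5), z = c(-1.2, -0.3, 0.1, 0.8, 1.5))
n_unique <- apply(x_demo, 2, function(col) length(unique(col)))
names(which(n_unique == 1))  # flags "const"
```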
4.4 Stan compiler pedantic check
The Stan compiler pedantic check examines only the code and can’t recognize this issue, as the problem also depends on the data.
mod_logit4$check_syntax(pedantic = TRUE)

5 A posterior with very high correlation
In the previous example the two parameters had the same role in the model, leading to high posterior correlation. High posterior correlations are also common in linear models when the predictor values are far from 0. We illustrate this with a linear regression model for the summer temperature in Kilpisjärvi, Finland, 1952–2013. We use the year as the covariate x without centering it.
5.1 Data
The data are Kilpisjärvi summer month temperatures 1952–2013, measured by the Finnish Meteorological Institute.
data_kilpis <- read.delim(root("problems/data", "kilpisjarvi-summer-temp.csv"), sep = ";")
data_lin <- list(M = 1,
N = nrow(data_kilpis),
x = matrix(data_kilpis$year, ncol = 1),
y = data_kilpis[, 5])

data.frame(data_lin) |>
ggplot(aes(x, y)) +
geom_point(size = 1) +
labs(y = 'Summer temp. @Kilpisjärvi', x = "Year") +
guides(linetype = "none")

5.2 Model
We use the following Stan linear regression model
code_lin <- root("problems", "linear_glm_kilpis.stan")
print_stan_file(code_lin)

// linear regression
data {
int<lower=0> N;
int<lower=0> M;
vector[N] y;
matrix[N,M] x;
}
parameters {
real alpha;
vector[M] beta;
real<lower=0> sigma;
}
model {
alpha ~ normal(0, 100);
beta ~ normal(0, 100);
sigma ~ normal(0, 1);
y ~ normal_id_glm(x, alpha, beta, sigma);
}

mod_lin <- cmdstan_model(stan_file = code_lin)
fit_lin <- mod_lin$sample(data = data_lin, seed = SEED, refresh = 0)

5.3 Convergence diagnostics
Stan gives a warning: There were X transitions after warmup that exceeded the maximum treedepth. As in the previous example, there are no other warnings.
fit_lin$diagnostic_summary()

$num_divergent
[1] 0 0 0 0
$num_max_treedepth
[1] 0 0 15 0
$ebfmi
[1] 0.9888201 0.9097830 1.0498057 1.0725453
draws <- as_draws_rvars(fit_lin$draws())
summarize_draws(draws)

# A tibble: 4 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -38. -38. 1.3 0.98 -41. -37. 1.0 1127. 1156.
2 alpha -31. -31. 16. 15. -57. -5.3 1.0 1021. 1232.
3 beta 0.020 0.020 0.0080 0.0076 0.0074 0.034 1.0 1021. 1232.
4 sigma 1.1 1.1 0.10 0.100 0.97 1.3 1.0 1259. 1236.
ESS estimates are above the diagnostic threshold, but lower than we would expect for such a low dimensional model, unless there are strong posterior correlations. The following Figure shows the marginal histograms and joint scatterplot for alpha and beta[1], which shows they are very highly correlated.
mcmc_pairs(as_draws_array(draws), pars = c("alpha", "beta"),
off_diag_args = list(alpha = 0.2))

Here the reason is that the x values are in the range 1952–2013, and the intercept alpha denotes the temperature at year 0, which is very far from the range of the observed x. If the intercept alpha changes, the slope beta needs to change too. The high correlation makes the inference slower, and we can make it faster by centering x. Here we simply subtract 1982.5 from the predictor year, so that the mean of x is 0. We could also include the centering and back transformation in the Stan code.
5.4 Centered data
data_lin <- list(
M = 1,
N = nrow(data_kilpis),
x = matrix(data_kilpis$year - 1982.5, ncol = 1),
y = data_kilpis[, 5]
)

fit_lin <- mod_lin$sample(data = data_lin, seed = SEED, refresh = 0)

5.5 Convergence diagnostics
We check the diagnostics
fit_lin$diagnostic_summary()

$num_divergent
[1] 0 0 0 0
$num_max_treedepth
[1] 0 0 0 0
$ebfmi
[1] 1.045172 1.039249 1.048480 1.096556
draws <- as_draws_rvars(fit_lin$draws())
summarize_draws(draws)

# A tibble: 4 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -38. -38. 1.3 1.0 -41. -37. 1.00 1903. 2526.
2 alpha 9.3 9.3 0.15 0.14 9.1 9.6 1.0 3521. 2726.
3 beta 0.020 0.020 0.0081 0.0080 0.0070 0.034 1.00 3965. 2755.
4 sigma 1.1 1.1 0.10 0.100 0.97 1.3 1.0 3055. 2502.
The following figure shows the scatter plot.
mcmc_pairs(as_draws_array(draws), pars = c("alpha", "beta"),
off_diag_args = list(alpha = 0.2))

With this change there is no posterior correlation, the Bulk-ESS estimates are 3 times bigger, and the mean time per chain drops from 1.3s to less than 0.05s; that is, we get two orders of magnitude faster inference. In a bigger problem this could correspond to reducing the computation time from 24 hours to less than 20 minutes.
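If the intercept is needed on the original year-0 scale, it can be recovered from the centered fit. A small sketch plugging in the posterior means reported above (point estimates only, ignoring posterior uncertainty):

```r
# y = alpha_c + beta * (year - 1982.5)  implies  alpha = alpha_c - 1982.5 * beta
alpha_c <- 9.3    # posterior mean of alpha from the centered fit
beta <- 0.020     # posterior mean of beta
alpha_c - 1982.5 * beta  # about -30, consistent with the uncentered fit
```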
6 A bimodal posterior
Bimodal distributions can arise for many reasons, as in mixture models or models with non-log-concave likelihoods or priors (that is, distributions with thick tails). We illustrate the diagnostics revealing a multimodal posterior using a simple toy problem with a Student’s t model and data that come not from a t distribution but from a mixture of two normal distributions.
6.1 Data
Bimodally distributed data
N <- 20
y <- c(rnorm(N / 2, mean = -5, sd = 1), rnorm(N / 2, mean = 5, sd = 1))
data_tt <- list(N = N, y = y)

6.2 Model
Unimodal Student’s t model:
code_tt <- root("problems", "student.stan")
print_stan_file(code_tt)

// student-student model
data {
int<lower=0> N;
vector[N] y;
}
parameters {
real mu;
}
model {
mu ~ student_t(4, 0, 100);
y ~ student_t(4, mu, 1);
}

Sample
mod_tt <- cmdstan_model(stan_file = code_tt)
fit_tt <- mod_tt$sample(data = data_tt, seed = SEED, refresh = 0)

6.3 Convergence diagnostics
We check the diagnostics
fit_tt$diagnostic_summary()

$num_divergent
[1] 0 0 0 0
$num_max_treedepth
[1] 0 0 0 0
$ebfmi
[1] 1.0855996 1.2158209 0.9634865 1.0107784
draws <- as_draws_rvars(fit_tt$draws())
summarize_draws(draws)

# A tibble: 2 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -84. -83. 0.86 0.74 -85. -83. 1.2 12. 118.
2 mu -0.32 -0.20 4.2 6.2 -5.0 4.4 1.7 6.1 200.
The \widehat{R} values for mu are large and ESS values for mu are small indicating convergence problems. The following figure shows the histogram and trace plots of the posterior draws, clearly showing the bimodality and that chains are not mixing between the modes.
mcmc_hist(as_draws_array(draws), pars = c("mu"))

In this toy example, with random initialization each chain has a 50% probability of ending up in either mode. We used Stan’s default of 4 chains, and when random initialization is used, there is a 6% chance that a single Stan run would miss the multimodality. If the attraction areas within the random initialization range are not equal, the probability of missing one mode is even higher. There is a tradeoff between the default computation cost and the cost of a higher probability of finding multiple modes. If there is reason to suspect multimodality, it is useful to run more chains. Running more chains helps to diagnose the multimodality, but the probability of chains ending up in different modes can differ from the relative probability mass of each mode, and running more chains doesn’t fix this. Other means are needed to improve mixing between the modes (e.g., Yao et al., 2020) or to approximately weight the chains (e.g., Yao et al., 2022).
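Running more chains is a one-argument change in cmdstanr; the sketch below uses 8 chains as an illustrative choice.

```r
# More chains increase the chance that the sampler visits both modes
fit_tt_more <- mod_tt$sample(data = data_tt, seed = SEED, refresh = 0,
                             chains = 8, parallel_chains = 4)
```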
7 Easy bimodal posterior
If the modes in the bimodal distribution are not strongly separated, MCMC can jump from one mode to another and there are no convergence issues.
N <- 20
y <- c(rnorm(N / 2, mean = -3, sd = 1), rnorm(N / 2, mean = 3, sd = 1))
data_tt <- list(N = N, y = y)

fit_tt <- mod_tt$sample(data = data_tt, seed = SEED, refresh = 0)

7.1 Convergence diagnostics
We check the diagnostics
fit_tt$diagnostic_summary()

$num_divergent
[1] 0 0 0 0
$num_max_treedepth
[1] 0 0 0 0
$ebfmi
[1] 1.274083 1.338079 1.268666 1.310952
draws <- as_draws_rvars(fit_tt$draws())
summarize_draws(draws)

# A tibble: 2 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -60. -59. 0.60 0.52 -61. -59. 1.0 1013. 1023.
2 mu 1.0 1.2 1.1 1.1 -1.0 2.5 1.0 866. 682.
Two modes are visible.
mcmc_hist(as_draws_array(draws), pars = c("mu"))

The trace plot is not very useful. It shows the chains jumping between the modes, but it’s difficult to see whether the jumps happen often enough and the chains are mixing well.
mcmc_trace(as_draws_array(draws), pars = c("mu"))

The rank ECDF plot (Säilynoja, Bürkner, and Vehtari 2022) indicates good mixing, as all chains have their lines inside the envelope (the envelope assumes no autocorrelation, which is the reason to thin the draws here).
draws |> thin_draws(ndraws(draws) / ess_basic(draws$mu)) |>
mcmc_rank_ecdf(pars = c("mu"), plot_diff = TRUE)

8 Initial value issues
MCMC requires some initial values. By default, Stan generates them randomly from [-2, 2] in the unconstrained space (constraints on parameters are achieved by transformations). Sometimes these initial values can be bad and cause numerical issues. Computers (in general) use a finite number of bits to represent numbers, so very small or large numbers may not be representable, or there can be a significant loss of accuracy.
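For example, double-precision floating point overflows for quite moderate exponents, which matters whenever a linear predictor is exponentiated:

```r
# The largest double is about 1.8e308, so exp() overflows above ~709.8
exp(709)  # still finite
exp(710)  # Inf
```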
The data is generated from a Poisson regression model. The Poisson intensity parameter has to be positive and usually the latent linear predictor is exponentiated to be positive (the exponentiation can also be justified by multiplicative effects on Poisson intensity).
We intentionally generate the data so that there are initialization problems, but the same problem is common with real data when the scale of the predictors is large or small compared to the unit scale. The following figure shows the simulated data.
8.1 Data
set.seed(SEED)
M <- 1
N <- 20
x <- 1e3 * matrix(c(sort(rnorm(N))), ncol = M)
y <- rpois(N, exp(1e-3 * x[, 1]))
data_pois <- list(M = M, N = N, x = x, y = y)

data.frame(data_pois) |>
ggplot(aes(x, y)) +
geom_point(size = 3)

8.2 Model
We use a Poisson regression model with proper priors. The line poisson_log_glm(x, alpha, beta) corresponds to a distribution in which the log intensity of the Poisson distribution is modeled with alpha + beta * x but is implemented with better computational efficiency.
code_pois <- root("problems", "pois_glm.stan")
print_stan_file(code_pois)

// Poisson regression
data {
int<lower=0> N;
int<lower=0> M;
array[N] int<lower=0> y;
matrix[N,M] x;
}
parameters {
real alpha;
vector[M] beta;
}
model {
alpha ~ normal(0,10);
beta ~ normal(0,10);
y ~ poisson_log_glm(x, alpha, beta);
}

Sample
mod_pois <- cmdstan_model(stan_file = code_pois)
fit_pois <- mod_pois$sample(data = data_pois, seed = SEED, refresh = 0)

We get a lot of warnings:
Chain 4 Rejecting initial value:
Chain 4 Log probability evaluates to log(0), i.e. negative infinity.
Chain 4 Stan can't start sampling from this initial value.
8.3 Convergence diagnostics
We check the diagnostics:
fit_pois$diagnostic_summary()

$num_divergent
[1] 0 0 0 0
$num_max_treedepth
[1] 2 0 319 0
$ebfmi
[1] 0.016884 1.211112 1.166652 1.186390
draws <- as_draws_rvars(fit_pois$draws())
summarize_draws(draws)

# A tibble: 3 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -7.7e+41 8.9 2.2e+42 1.3 -1.9e+42 10.0 1.6 6.8 11.
2 alpha 2.7e- 1 0.068 6.5e- 1 0.45 -5.1e- 1 1.3 1.5 7.4 NA
3 beta 1.1e- 2 0.0011 1.7e- 2 0.00030 7.5e- 4 0.041 1.6 6.8 11.
\widehat{R} values are large and ESS values are small, indicating bad mixing. Marginal histograms and joint scatterplots of the posterior draws in the figure below clearly show that two chains have been stuck away from two others.
mcmc_pairs(as_draws_array(draws), pars = c("alpha", "beta"),
off_diag_args = list(alpha = 0.2))

The reason for the issue is that the initial values for beta are sampled from (-2, 2) and x has some large values. If the initial value for beta is higher than about 0.3 or lower than about -0.4, some of the values of exp(alpha + beta * x) will overflow to floating point infinity (Inf).
8.4 Scaled data
Sometimes an easy option is to change the initialization range. For example, in this case the sampling succeeds if the initial values are drawn from the range (-0.001, 0.001). Alternatively, we can scale x to be close to unit scale. After this scaling, the computation is fast and all convergence diagnostics look good.
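The initialization-range option mentioned above is the init argument of cmdstanr's $sample(): a single positive number x draws initial values uniformly from [-x, x] on the unconstrained scale. A sketch:

```r
# Shrink the random-initialization range instead of rescaling the data
fit_pois_init <- mod_pois$sample(data = data_pois, seed = SEED, refresh = 0,
                                 init = 0.001)
```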
data_pois <- list(M = M, N = N, x = x / 1e3, y = y)

data.frame(data_pois) |>
ggplot(aes(x, y)) +
geom_point(size = 3)

mod_pois <- cmdstan_model(stan_file = code_pois)
fit_pois <- mod_pois$sample(data = data_pois, seed = SEED, refresh = 0)

8.5 Convergence diagnostics
We check the diagnostics:
fit_pois$diagnostic_summary()

$num_divergent
[1] 0 0 0 0
$num_max_treedepth
[1] 0 0 0 0
$ebfmi
[1] 1.094493 1.040770 1.061937 1.146358
draws <- as_draws_rvars(fit_pois$draws())
summarize_draws(draws)

# A tibble: 3 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ 9.1 9.4 0.99 0.68 7.2 10. 1.0 1029. 1833.
2 alpha -0.036 -0.029 0.26 0.26 -0.48 0.38 1.0 1051. 1279.
3 beta 1.0 1.0 0.18 0.18 0.70 1.3 1.0 973. 1199.
If the initial value warning comes only once, it is possible that MCMC was able to escape the bad region and the rest of the inference is OK.
9 Thick tailed posterior
We return to the logistic regression example with separable data. Now we use proper but thick-tailed Cauchy priors.
9.1 Model
code_logit_glm4 <- root("problems", "logit_glm4.stan")
print_stan_file(code_logit_glm4)

// logistic regression
data {
int<lower=0> N;
int<lower=0> M;
array[N] int<lower=0,upper=1> y;
matrix[N,M] x;
}
parameters {
real alpha;
vector[M] beta;
}
model {
alpha ~ cauchy(0, 10);
beta ~ cauchy(0, 10);
y ~ bernoulli_logit_glm(x, alpha, beta);
}

Sample
mod_logit_glm4 <- cmdstan_model(stan_file = code_logit_glm4)
fit_logit_glm4 <- mod_logit_glm4$sample(data = data_logit, seed = SEED, refresh = 0)

9.2 Convergence diagnostics
We check diagnostics
fit_logit_glm4$diagnostic_summary()

$num_divergent
[1] 0 0 0 0
$num_max_treedepth
[1] 0 0 0 0
$ebfmi
[1] 0.1913679 0.1705309 0.7056304 0.7055436
draws <- as_draws_rvars(fit_logit_glm4$draws())
summarize_draws(draws)

# A tibble: 3 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -4.2 -3.3 2.6 1.6 -9.7 -1.9 1.0 173. 63.
2 alpha 21. 10. 46. 7.9 2.5 73. 1.0 162. 61.
3 beta 51. 24. 112. 19. 6.7 165. 1.0 160. 60.
The rounded \widehat{R} values look good, but the ESS values are low. The marginal histograms and joint scatterplots of the posterior draws in the following figure show a thick tail.
mcmc_pairs(as_draws_array(draws), pars = c("alpha", "beta"),
off_diag_args = list(alpha = 0.2))

The dynamic HMC algorithm used by Stan, like many other MCMC methods, has problems with such thick tails, and mixing is slow.
The rank ECDF plot indicates good mixing, as all chains have their lines inside the envelope (the envelope assumes no autocorrelation, which is the reason to thin the draws here).
draws |> thin_draws(ndraws(draws) / ess_bulk(draws$alpha)) |>
mcmc_rank_ecdf(pars = c("alpha"), plot_diff = TRUE)

More iterations confirm reasonable mixing.
fit_logit_glm4 <- mod_logit_glm4$sample(data = data_logit, seed = SEED, refresh = 0,
iter_sampling = 4000)
draws <- as_draws_rvars(fit_logit_glm4$draws())
summarize_draws(draws)

draws |> thin_draws(ndraws(draws) / ess_bulk(draws$alpha)) |>
mcmc_rank_ecdf(pars = c("alpha"), plot_diff = TRUE)

10 Funnel
A special case of varying curvature is known as the funnel, based on the shape of the typical set of the distribution. Consider a hierarchical model, y_i\sim\normal(\mu_{k[i]},\sigma) for i=1,\dots,N, with a group membership variable k[i] that takes on values from 1 through K. We assume the prior distribution \mu_k\sim\normal(\mu_0,\sigma_0), k=1,\dots,K. When plotted on the scale \mu_1,\dots,\mu_K,\log\sigma_0, this prior can be visualized as having the shape of a funnel.
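The funnel shape of this prior can be seen by sampling from it directly. A self-contained sketch with illustrative hyperpriors (zero means and a standard normal on \log\sigma_0, not the values used in the model below):

```r
# Prior draws: log(sigma0) ~ normal(0, 1), mu_1 | sigma0 ~ normal(0, sigma0)
set.seed(1)
log_sigma0 <- rnorm(4000, 0, 1)
mu1 <- rnorm(4000, 0, exp(log_sigma0))
# Small sigma0 pulls mu_1 toward 0 (the funnel's neck); large sigma0 lets it spread
plot(mu1, log_sigma0, pch = 20, cex = 0.3,
     xlab = expression(mu[1]), ylab = expression(log~sigma[0]))
```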
If the funnel-shaped prior is combined with a weak likelihood, the posterior is also funnel shaped. As a toy example, we use the Kilpisjärvi temperature data, with each group being one year, with three summer month temperatures per year. With only three observations per group, the likelihood is weak for each \mu_k and the prior is likely to dominate the posterior shape. The number of groups is 71, and this high dimensionality makes the funnel challenging.
data_kilpis <- read.delim(root("problems/data", "kilpisjarvi-summer-temp.csv"), sep = ";")
data_grpy <-list(N = length(data_kilpis$year)*ncol(data_kilpis[,2:4]),
K = length(data_kilpis$year),
x = rep(1:length(data_kilpis$year), ncol(data_kilpis[,2:4])),
y = c(t(t(data_kilpis[,2:4]))))

Here is a direct implementation of the hierarchical model in Stan. The parameterization used is also known as the centered parameterization.
code_hier_cp <- root("problems", "hier_cp.stan")
print_stan_file(code_hier_cp)

data {
int<lower=0> N; // number of observations
int<lower=0> K; // number of groups
array[N] int<lower=1, upper=K> x; // discrete group indicators
vector[N] y; // real valued observations
}
parameters {
real mu0; // prior mean
real<lower=0> sigma0; // prior sd
vector[K] mu; // group means
real<lower=0> sigma; // common sd
}
model {
mu0 ~ normal(10, 10); // weakly informative prior
sigma0 ~ normal(0, 10); // weakly informative prior
mu ~ normal(mu0, sigma0); // population prior with unknown parameters
sigma ~ lognormal(0, .5); // weakly informative prior
y ~ normal(mu[x], sigma); // observation model
}

We first try running Stan with its default settings.
SEED <- 48929
mod_hier_cp <- cmdstan_model(stan_file = code_hier_cp)
fit_hier_cp <- mod_hier_cp$sample(data = data_grpy, seed = SEED, refresh = 0)

We get a warning that some transitions ended with a divergence. The convergence diagnostics \widehat{R}, bulk-ESS, and tail-ESS reveal that the chains are not mixing well:
fit_hier_cp

 variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
lp__ -173.30 -181.28 42.76 48.98 -232.77 -101.57 1.17 18 85
mu0 9.32 9.32 0.16 0.18 9.06 9.57 1.03 141 441
sigma0 0.32 0.27 0.21 0.23 0.07 0.72 1.17 19 36
mu[1] 9.23 9.26 0.37 0.30 8.56 9.77 1.02 392 1131
mu[2] 9.46 9.41 0.40 0.29 8.91 10.17 1.02 271 805
mu[3] 9.33 9.33 0.39 0.30 8.69 9.96 1.02 591 1302
mu[4] 9.21 9.26 0.40 0.31 8.47 9.78 1.02 416 727
mu[5] 9.22 9.26 0.39 0.31 8.48 9.78 1.02 302 658
mu[6] 9.17 9.22 0.41 0.32 8.41 9.74 1.02 322 855
mu[7] 9.25 9.28 0.37 0.30 8.61 9.81 1.01 461 1212
# showing 10 of 66 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)
Plot scatter plot of \mu_1 vs \log\sigma_0 with divergences shown in red (bayesplot)
np <- fit_hier_cp$sampler_diagnostics(format="df")|>
mutate(Chain=.chain,Iteration=.iteration,Parameter="divergent__",Value=divergent__)|>
select(Chain,Iteration,Parameter,Value)
fit_hier_cp$draws(format="df")|>
mutate(log_sigma0=log(sigma0))|>
mcmc_scatter(pars=c("mu[1]","sigma0"),transform=list(sigma0="log"), alpha=0.1, shape=20,
np=np, np_style = scatter_style_np(div_shape = 18, div_size = 3)) +
labs(x=TeX(r"($\mu_1$)"), y=TeX(r"($\log\,\sigma_0$)"), title="a) Centered param.")

Plot scatter plot of \mu_1 vs \log\sigma_0 with divergences shown in red (ggplot)
drws <- bind_draws(fit_hier_cp$draws(format="df"), fit_hier_cp$sampler_diagnostics(format="df")) |>
mutate(log_sigma0=log(sigma0))
p1 <- ggplot(data = NULL, aes(x=`mu[1]`,`log_sigma0`)) +
geom_point(data = drws |> filter(divergent__==0), shape = 20, color = bayesplot:::get_color("m"), fill = bayesplot:::get_color("lh"), alpha = 0.1, size = 2) +
geom_point(data = drws |> filter(divergent__==1), shape = 23, fill = "red", color = "white", size = 2) +
labs(x=TeX(r"($\mu_1$)"), y=TeX(r"($\log\,\sigma_0$)"), title="c) Centered param.")
p1

We change the adapt_delta tuning parameter of the NUTS algorithm to force a smaller step size.
fit_hier_cp_999 <- mod_hier_cp$sample(data = data_grpy, seed = SEED, refresh=0, adapt_delta=0.999)

However, the convergence diagnostics still indicate serious mixing problems:
fit_hier_cp_999

 variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
lp__ -180.01 -186.80 40.75 44.39 -235.15 -103.47 1.13 28 128
mu0 9.30 9.30 0.16 0.16 9.04 9.57 1.03 168 630
sigma0 0.34 0.31 0.21 0.23 0.07 0.74 1.13 28 134
mu[1] 9.21 9.24 0.39 0.30 8.50 9.79 1.02 682 827
mu[2] 9.46 9.40 0.42 0.34 8.90 10.25 1.03 328 789
mu[3] 9.30 9.30 0.39 0.31 8.65 9.96 1.03 971 840
mu[4] 9.19 9.22 0.40 0.31 8.48 9.78 1.03 473 696
mu[5] 9.19 9.23 0.40 0.32 8.45 9.79 1.04 454 954
mu[6] 9.15 9.20 0.42 0.33 8.37 9.74 1.03 331 681
mu[7] 9.24 9.25 0.38 0.31 8.57 9.82 1.02 909 1284
# showing 10 of 66 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)
Plot scatter plot of \mu_1 vs \log\sigma_0 with divergences shown in red (bayesplot)
np <- fit_hier_cp_999$sampler_diagnostics(format="df")|>
mutate(Chain=.chain,Iteration=.iteration,Parameter="divergent__",Value=divergent__)|>
select(Chain,Iteration,Parameter,Value)
fit_hier_cp_999$draws(format="df")|>
mutate(log_sigma0=log(sigma0))|>
mcmc_scatter(pars=c("mu[1]","sigma0"),transform=list(sigma0="log"), alpha=0.1, shape=20,
np=np, np_style = scatter_style_np(div_shape = 18, div_size = 3)) +
labs(x=TeX(r"($\mu_1$)"), y=TeX(r"($\log\,\sigma_0$)"), title="a) Centered param. + higher adapt_delta")
Plot scatter plot of \mu_1 vs \log\sigma_0 with divergences shown in red (ggplot)
drws <- bind_draws(fit_hier_cp_999$draws(format="df"), fit_hier_cp_999$sampler_diagnostics(format="df")) |>
mutate(log_sigma0=log(sigma0))
p2 <- ggplot(data = NULL, aes(x=`mu[1]`,`log_sigma0`)) +
geom_point(data = drws |> filter(divergent__==0), shape = 20, color = bayesplot:::get_color("m"), fill = bayesplot:::get_color("lh"), alpha = 0.1, size = 2) +
geom_point(data = drws |> filter(divergent__==1), shape = 23, fill = "red", color = "white", size = 2) +
labs(x=TeX(r"($\mu_1$)"), y=TeX(r"($\log\,\sigma_0$)"), title="b) Centered param. + higher adapt_delta")
p2
The usual approach to resolve the funnel problem is to change how the model is parameterized. The so-called non-centered parameterization provides the same model, but the sampling happens in a transformed space that does not have the difficult funnel geometry.
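The idea behind the non-centered parameterization can be illustrated outside of Stan. The following is a minimal sketch in Python (the hyperparameter values are hypothetical, chosen to resemble the posterior means above): drawing a standard normal variable and then shifting and scaling it gives exactly the same distribution as drawing the group mean directly.

```python
import numpy as np

rng = np.random.default_rng(1)
mu0, sigma0 = 9.3, 0.3  # hypothetical hyperparameter values for illustration
n = 100_000

# Centered: draw the group mean directly, mu ~ normal(mu0, sigma0).
mu_centered = rng.normal(mu0, sigma0, size=n)

# Non-centered: draw z ~ normal(0, 1), then shift and scale, mu = mu0 + sigma0 * z.
z = rng.normal(0.0, 1.0, size=n)
mu_noncentered = mu0 + sigma0 * z

# The two constructions target the same distribution; only the space in
# which the sampler moves differs, which is what removes the funnel.
print(mu_centered.mean(), mu_noncentered.mean())
```

In the non-centered space the sampler explores z and sigma0, which are a priori independent, so the width of the z dimension no longer shrinks as sigma0 gets small.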
code_hier_ncp <- root("problems", "hier_ncp.stan")
print_stan_file(code_hier_ncp)
data {
int<lower=0> N; // number of observations
int<lower=0> K; // number of groups
array[N] int<lower=1, upper=K> x; // discrete group indicators
vector[N] y; // real valued observations
}
parameters {
real mu0; // prior mean
real<lower=0> sigma0; // prior sd
vector[K] z; // latent variable
real<lower=0> sigma; // common sd
}
transformed parameters {
vector[K] mu = mu0 + sigma0 * z; // group means
}
model {
mu0 ~ normal(10, 10); // weakly informative prior
sigma0 ~ normal(0, 10); // weakly informative prior
z ~ normal(0, 1); // unit normal
sigma ~ lognormal(0, .5); // weakly informative prior
y ~ normal(mu[x], sigma); // observation model
}
We run Stan with its default settings.
mod_hier_ncp <- cmdstan_model(stan_file = code_hier_ncp)
fit_hier_ncp <- mod_hier_ncp$sample(data = data_grpy, seed = SEED, refresh=0)
The convergence diagnostics \widehat{R}, bulk-ESS, and tail-ESS look good now.
fit_hier_ncp
 variable    mean  median   sd  mad      q5     q95 rhat ess_bulk ess_tail
lp__ -259.64 -259.58 6.98 6.85 -271.16 -248.13 1.00 984 1612
mu0 9.32 9.32 0.16 0.16 9.06 9.58 1.00 3970 2437
sigma0 0.33 0.30 0.22 0.24 0.03 0.73 1.00 1358 1880
z[1] -0.22 -0.22 0.97 0.97 -1.81 1.40 1.00 5465 2742
z[2] 0.36 0.37 1.00 1.00 -1.33 1.98 1.00 5659 2746
z[3] 0.02 0.02 0.98 0.99 -1.57 1.59 1.00 6581 2317
z[4] -0.26 -0.29 0.97 0.99 -1.84 1.35 1.00 6780 2817
z[5] -0.24 -0.24 0.95 0.97 -1.78 1.33 1.00 7142 2963
z[6] -0.33 -0.35 0.97 0.98 -1.88 1.28 1.00 5486 2815
z[7] -0.16 -0.15 0.96 0.97 -1.72 1.42 1.00 5771 2806
# showing 10 of 128 rows (change via 'max_rows' argument or 'cmdstanr_max_rows' option)
Plot scatter plot of \mu_1 vs \log\sigma_0 with divergences shown in red (bayesplot)
np <- fit_hier_ncp$sampler_diagnostics(format="df")|>
mutate(Chain=.chain,Iteration=.iteration,Parameter="divergent__",Value=divergent__)|>
select(Chain,Iteration,Parameter,Value)
fit_hier_ncp$draws(format="df")|>
mutate(log_sigma0=log(sigma0))|>
mcmc_scatter(pars=c("mu[1]","sigma0"),transform=list(sigma0="log"), alpha=0.1, shape=20,
np=np, np_style = scatter_style_np(div_shape = 18, div_size = 3)) +
labs(x=TeX(r"($\mu_1$)"), y=TeX(r"($\log\,\sigma_0$)"), title="a) Non-centered param.")
Plot scatter plot of \mu_1 vs \log\sigma_0 with divergences shown in red (ggplot)
drws <- bind_draws(fit_hier_ncp$draws(format="df"), fit_hier_ncp$sampler_diagnostics(format="df")) |>
mutate(log_sigma0=log(sigma0))
p3 <- ggplot(data = NULL, aes(x=`mu[1]`,`log_sigma0`)) +
geom_point(data = drws |> filter(divergent__==0), shape = 20, color = bayesplot:::get_color("m"), fill = bayesplot:::get_color("lh"), alpha = 0.1, size = 2) +
geom_point(data = drws |> filter(divergent__==1), shape = 23, fill = "red", color = "white", size = 2) +
labs(x=TeX(r"($\mu_1$)"), y=TeX(r"($\log\,\sigma_0$)"), title="c) Non-centered param.")
p3
If we compare the scatter plots side by side, we clearly see that increasing adapt_delta and getting rid of divergences did not solve the funnel problem, and the posterior estimates with the centered parameterization would be biased.
(p1 + p2 + p3) *
scale_y_continuous(lim=c(-8,0.3)) *
scale_x_continuous(lim=c(7.2,11.8)) *
theme(plot.title = element_text(size=16)) +
plot_layout(axis_titles="collect_y")
11 Variance parameter that is not constrained to be positive
A demonstration of what happens if we forget to constrain a parameter that has to be positive. In Stan, the constraint can be added when declaring the parameter: real<lower=0> sigma;
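Under the hood, a <lower=0> constraint works by letting the sampler move on an unconstrained scale and mapping back to the constrained scale with an exponential transform, plus a Jacobian adjustment to the log density. A minimal sketch of this documented transform (the helper names here are ours, not Stan's):

```python
import math

def constrain_lower(u, lower=0.0):
    # Stan samples on the unconstrained scale u and maps back with
    # x = lower + exp(u), so every proposal yields a valid x > lower.
    return lower + math.exp(u)

def log_jacobian_lower(u):
    # log |dx/du| = u; Stan adds this to the log density automatically.
    return u

# No matter how negative the unconstrained value, sigma stays positive.
sigmas = [constrain_lower(u) for u in (-10.0, -1.0, 0.0, 3.0)]
print(sigmas)
```

With the constraint in place, the sampler can never propose a non-positive scale, which is why the rejection warnings below disappear once the model is fixed.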
11.1 Data
We simulate x and y independently from normal(0,1) and normal(0,0.1), respectively. As N=8 is small, there will be a lot of uncertainty about the parameters, including the scale sigma.
M <- 1
N <- 8
set.seed(SEED)
x <- matrix(rnorm(N), ncol = M)
y <- rnorm(N) / 10
data_lin <- list(M = M, N = N, x = x, y = y)
11.2 Model
We use a linear regression model with proper priors.
code_lin <- root("problems", "linear_glm.stan")
print_stan_file(code_lin)
// linear regression
data {
int<lower=0> N;
int<lower=0> M;
vector[N] y;
matrix[N,M] x;
}
parameters {
real alpha;
vector[M] beta;
real sigma; // intentionally missing lower=0
}
model {
alpha ~ normal(0, 1);
beta ~ normal(0, 1);
sigma ~ normal(0, 1);
y ~ normal_id_glm(x, alpha, beta, sigma);
}
Sample
mod_lin <- cmdstan_model(stan_file = code_lin)
fit_lin <- mod_lin$sample(data = data_lin, seed = SEED, refresh = 0)
We repeatedly get warnings like the following:
Chain 4 Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
Chain 4 Exception: normal_id_glm_lpdf: Scale vector is -0.747476, but must be positive finite! (in '/tmp/RtmprEP4gg/model-7caa12ce8e405.stan', line 16, column 2 to column 43)
Chain 4 If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,
Chain 4 but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.
Sometimes these warnings appear in the early phase of sampling even if the model is correctly defined. Here we get too many of them, which indicates the sampler is trying to jump to infeasible values, in this case negative values of the scale parameter. Many rejections may lead to biased estimates.
There are also some divergences reported, which is a further indication that there might be a problem (since the divergence diagnostic uses an ad hoc threshold, there can also be false positive warnings). The other convergence diagnostics look good, but given the many rejection warnings, it is good to check the model code and the numerical accuracy of the computations.
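Why do infeasible proposals get rejected rather than crash the sampler? A proposal with non-positive scale has log density negative infinity, so the acceptance probability is zero. This can be illustrated with a toy random-walk Metropolis sampler for sigma alone (a sketch only, not Stan's NUTS algorithm; all names here are ours):

```python
import math
import random

def logpdf_normal(y, mu, sigma):
    # Return -inf for a non-positive scale, mimicking the rejection that
    # Stan reports as "Scale vector is ..., but must be positive finite!".
    if sigma <= 0.0:
        return float("-inf")
    return -0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

random.seed(2)
y_obs, mu = 0.1, 0.0
sigma, lp = 0.5, logpdf_normal(0.1, 0.0, 0.5)
n_reject = 0
for _ in range(2000):
    prop = sigma + random.gauss(0.0, 0.5)  # random-walk proposal on sigma
    lp_prop = logpdf_normal(y_obs, mu, prop)
    accept_prob = math.exp(min(0.0, lp_prop - lp))  # 0 when prop <= 0
    if random.random() < accept_prob:
        sigma, lp = prop, lp_prop  # accept
    else:
        n_reject += 1  # infeasible proposals are always rejected
print("rejected:", n_reject)
```

The chain stays in the feasible region, so the posterior can still be approximately correct, but wasted proposals reduce efficiency and, with gradient-based samplers, the hard boundary can also distort the trajectories.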
11.3 Convergence diagnostics
We check diagnostics
fit_lin$diagnostic_summary()
$num_divergent
[1] 0 3 0 2
$num_max_treedepth
[1] 0 0 0 0
$ebfmi
[1] 0.5917343 0.6731945 0.7169593 0.6044464
draws <- as_draws_rvars(fit_lin$draws())
summarize_draws(draws)
# A tibble: 4 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ 17. 17. 1.8 1.4 13. 19. 1.0 835. 833.
2 alpha -0.015 -0.015 0.032 0.027 -0.066 0.036 1.0 2344. 1770.
3 beta -0.0026 -0.0026 0.031 0.027 -0.052 0.047 1.0 2035. 1270.
4 sigma 0.086 0.077 0.036 0.025 0.049 0.16 1.0 1115. 940.
11.4 Stan compiler pedantic check
The Stan compiler's pedantic check can recognize the problem: "A normal_id_glm distribution is given parameter sigma as a scale parameter (argument 4), but sigma was not constrained to be strictly positive." The pedantic check also warns about very wide priors.
mod_lin$check_syntax(pedantic = TRUE)
After fixing the model with a proper parameter constraint, MCMC runs without warnings and the sampling efficiency is better.
The fixed model includes the <lower=0> constraint for sigma.
code_lin2 <- root("problems", "linear_glm2.stan")
print_stan_file(code_lin2)
// linear regression
data {
int<lower=0> N;
int<lower=0> M;
vector[N] y;
matrix[N,M] x;
}
parameters {
real alpha;
vector[M] beta;
real<lower=0> sigma;
}
model {
alpha ~ normal(0, 1);
beta ~ normal(0, 1);
sigma ~ normal(0, 1);
y ~ normal_id_glm(x, alpha, beta, sigma);
}
Sample
mod_lin2 <- cmdstan_model(stan_file = code_lin2)
fit_lin2 <- mod_lin2$sample(data = data_lin, seed = SEED, refresh = 0)
We check diagnostics
draws2 <- as_draws_rvars(fit_lin2$draws())
summarize_draws(draws2)
# A tibble: 4 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ 14. 15. 1.6 1.3 11. 16. 1.0 1017. 1571.
2 alpha -0.016 -0.015 0.034 0.029 -0.071 0.034 1.0 1873. 1482.
3 beta -0.0037 -0.0032 0.033 0.026 -0.053 0.047 1.0 2422. 1724.
4 sigma 0.087 0.078 0.036 0.025 0.049 0.15 1.0 1440. 1814.
In this specific case, the bias is negligible when running MCMC with the model code without the constraint, but it is difficult to diagnose without running the fixed model.
References
Licenses
- Code © 2021–2025, Aki Vehtari, licensed under BSD-3.
- Text © 2021–2025, Aki Vehtari, licensed under CC-BY-NC 4.0.