Load packages
library(ggplot2)
library(tidyr)
library(gridExtra)
library(rstanarm)
library(brms)
library(bayesplot)
theme_set(bayesplot::theme_default(base_family = "sans", base_size = 16))
library(patchwork)
library(loo)
library(rprojroot)
root<-has_file(".BDA_R_demos_root")$make_fix_file()
This notebook demonstrates time series analysis of traffic deaths per year in Finland. Currently, when the number of traffic deaths during the previous year is reported, the press release claims that traffic safety in Finland has improved or worsened depending on whether the number is smaller or larger than the year before. Time series analysis can be used to separate random fluctuation from the slowly changing level of traffic safety.
Read the data (there would be data for earlier years, too, but this is sufficient for the demonstration).
# file preview shows a header row
deaths <- read.csv(root("demos_rstan", "trafficdeaths.csv"), header = TRUE)
head(deaths)
## year deaths
## 1 1993 434
## 2 1994 423
## 3 1995 411
## 4 1996 355
## 5 1997 391
## 6 1998 367
First plot just the data.
deaths |>
ggplot(aes(x=year, y=deaths)) +
geom_point() +
labs(y = 'Traffic deaths', x= "Year") +
guides(linetype = "none")
The number of deaths is count data, so we use a Poisson observation model. We first fit a log-linear model for the Poisson intensity, which corresponds to assuming a constant proportional change in the rate from year to year.
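Written out, the log-linear Poisson model looks as follows. The notation is assumed here rather than taken from the notebook: $y_t$ is the number of deaths in year $t$, $\lambda_t$ is the Poisson intensity, and $\alpha$ and $\beta$ are the intercept and slope.

```latex
\begin{align*}
y_t &\sim \mathrm{Poisson}(\lambda_t), \\
\log \lambda_t &= \alpha + \beta\, t, \\
\frac{\lambda_{t+1}}{\lambda_t} &= e^{\beta}.
\end{align*}
```

The last line is what "constant proportional change" means: the expected count is multiplied by the same factor $e^{\beta}$ every year, so a negative $\beta$ corresponds to a constant percentage decrease.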
# refresh was given twice in the original call; keep only refresh=0
fit_lin <- stan_glm(deaths ~ year, data=deaths, family=poisson,
iter=1000, chains=4, seed=583829, refresh=0)
ESS’s and Rhat’s are ok (see, e.g., RStan workflow). Let’s look at the posterior predictive distribution (median with 5% and 95% quantiles).
x_predict <- seq(1993,2030)
N_predict <- length(x_predict)
y_predict_lin <- posterior_predict(fit_lin, newdata=data.frame(year=x_predict))
mu <- apply(t(y_predict_lin), 1, quantile, c(0.05, 0.5, 0.95)) %>%
t() %>% data.frame(x = x_predict, .) %>% gather(pct, y, -x)
pfit <- ggplot() +
geom_point(aes(year, deaths), data = deaths, size = 1) +
geom_line(aes(x, y, linetype = pct), data = mu, color = 'red') +
scale_linetype_manual(values = c(2,1,2)) +
labs(x = 'Year', y = 'Traffic deaths') +
guides(linetype = "none")
(pfit)
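The quantile summary used above can be tried out on its own. The following is a minimal sketch in base R that uses a simulated matrix of draws in place of the `posterior_predict()` result; the Poisson means and the names `y_draws` and `summary_df` are made up for illustration.

```r
# Simulated stand-in for a posterior_predict() result: one row per draw,
# one column per prediction year. The Poisson means are hypothetical.
set.seed(1)
years <- 1993:1998
n_draws <- 1000
lambda <- seq(430, 370, length.out = length(years))
y_draws <- sapply(lambda, function(l) rpois(n_draws, l))

# Column-wise quantiles: apply() returns one row per quantile,
# so transpose to get one row per year.
q <- t(apply(y_draws, 2, quantile, probs = c(0.05, 0.5, 0.95)))
summary_df <- data.frame(year = years, q05 = q[, 1], q50 = q[, 2], q95 = q[, 3])
print(summary_df)
```

Applying `quantile` over columns and transposing gives the same matrix as the `apply(t(...), 1, ...)` followed by `t()` used in the notebook; the wide layout here simply skips the reshaping step needed for the `linetype` aesthetic.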
Next we fit a non-linear spline model with stan_gamm4.
# refresh was given twice and chains was misspelled as chain in the original call
fit_gam <- stan_gamm4(deaths ~ year + s(year), data=deaths,
family=poisson, adapt_delta=0.999,
iter=2000, chains=4, seed=583829, refresh=0)
## Warning: There were 6 divergent transitions after warmup. See
## https://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## to find out why this is a problem and how to eliminate them.
## Warning: Examine the pairs() plot to diagnose sampling problems
ESS is clearly smaller than for the linear model, but Rhat’s are ok.
Let’s look at the posterior predictive distribution.
x_predict <- seq(1993,2030)
N_predict <- length(x_predict)
y_predict_gam <- posterior_predict(fit_gam, newdata=data.frame(year=x_predict))
mu <- apply(t(y_predict_gam), 1, quantile, c(0.05, 0.5, 0.95)) %>%
t() %>% data.frame(x = x_predict, .) %>% gather(pct, y, -x)
pfit <- ggplot() +
geom_point(aes(year, deaths), data = deaths, size = 1) +
geom_line(aes(x, y, linetype = pct), data = mu, color = 'red') +
scale_linetype_manual(values = c(2,1,2)) +
labs(x = 'Year', y = 'Traffic deaths') +
guides(linetype = "none")
(pfit)
The predictive median is clearly nonlinear. The predictive mean for future years stays at the same level as the most recent observations, but uncertainty increases quickly.
Next we fit a Gaussian process centered on a linear model. We use brms for this:
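# adapt_delta is passed via the control list, as documented for brm();
# the duplicated refresh argument is dropped and chain is spelled out as chains
fit_gp <- brm(deaths ~ year + gp(year), data=deaths,
family=poisson, control=list(adapt_delta=0.95),
iter=2000, chains=4, seed=583829, refresh=0,
backend='cmdstanr')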
## Running MCMC with 4 sequential chains...
## Chain 1 finished in 5.9 seconds.
## Chain 2 finished in 5.5 seconds.
## Chain 3 finished in 6.0 seconds.
## Chain 4 finished in 5.0 seconds.
## All 4 chains finished successfully.
## Mean chain execution time: 5.6 seconds.
## Total execution time: 22.7 seconds.
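In symbols, the model fitted above can be sketched as a linear trend on the log scale plus a smooth deviation $f$ with a Gaussian process prior. The notation is assumed here and does not come from the notebook: $y_t$ is the number of deaths in year $t$, $\alpha$ and $\beta$ are the intercept and slope, and $\sigma$ and $\ell$ label the magnitude and length scale. The kernel shown is the exponentiated quadratic, which is the default for `gp()` in brms as far as I know.

```latex
\begin{align*}
y_t &\sim \mathrm{Poisson}(\lambda_t), \\
\log \lambda_t &= \alpha + \beta\, t + f(t), \\
f &\sim \mathcal{GP}\bigl(0,\, k\bigr), \qquad
k(t, t') = \sigma^2 \exp\!\left(-\frac{(t - t')^2}{2\ell^2}\right).
\end{align*}
```

Because the GP has zero mean, predictions far from the data revert toward the linear trend $\alpha + \beta\, t$, which is what "centered on a linear model" refers to. Note that brms by default rescales the predictor before applying the kernel, so the fitted length scale is not directly in units of years.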
x_predict <- seq(1993,2030)
N_predict <- length(x_predict)
y_predict_gp <- posterior_predict(fit_gp, newdata=data.frame(year=x_predict))
mu <- apply(t(y_predict_gp), 1, quantile, c(0.05, 0.5, 0.95)) %>%
t() %>% data.frame(x = x_predict, .) %>% gather(pct, y, -x)
pfit <- ggplot() +
geom_point(aes(year, deaths), data = deaths, size = 1) +
geom_line(aes(x, y, linetype = pct), data = mu, color = 'red') +
scale_linetype_manual(values = c(2,1,2)) +
labs(x = 'Year', y = 'Traffic deaths') +
guides(linetype = "none")
(pfit)
Finally we compare models using PSIS-LOO predictive performance estimates.
(loo_lin<-loo(fit_lin, save_psis=TRUE))
##
## Computed from 2000 by 29 log-likelihood matrix
##
## Estimate SE
## elpd_loo -134.0 5.6
## p_loo 2.4 0.6
## looic 267.9 11.3
## ------
## Monte Carlo SE of elpd_loo is 0.1.
##
## All Pareto k estimates are good (k < 0.5).
## See help('pareto-k-diagnostic') for details.
(loo_gam<-loo(fit_gam, save_psis=TRUE))
##
## Computed from 4000 by 29 log-likelihood matrix
##
## Estimate SE
## elpd_loo -132.5 4.3
## p_loo 6.8 1.3
## looic 265.1 8.5
## ------
## Monte Carlo SE of elpd_loo is 0.1.
##
## Pareto k diagnostic values:
## Count Pct. Min. n_eff
## (-Inf, 0.5] (good) 28 96.6% 624
## (0.5, 0.7] (ok) 1 3.4% 1705
## (0.7, 1] (bad) 0 0.0% <NA>
## (1, Inf) (very bad) 0 0.0% <NA>
##
## All Pareto k estimates are ok (k < 0.7).
## See help('pareto-k-diagnostic') for details.
(loo_gp<-loo(fit_gp, save_psis=TRUE))
##
## Computed from 4000 by 29 log-likelihood matrix
##
## Estimate SE
## elpd_loo -132.0 3.5
## p_loo 8.6 1.5
## looic 264.0 7.1
## ------
## Monte Carlo SE of elpd_loo is 0.1.
##
## Pareto k diagnostic values:
## Count Pct. Min. n_eff
## (-Inf, 0.5] (good) 22 75.9% 265
## (0.5, 0.7] (ok) 7 24.1% 1085
## (0.7, 1] (bad) 0 0.0% <NA>
## (1, Inf) (very bad) 0 0.0% <NA>
##
## All Pareto k estimates are ok (k < 0.7).
## See help('pareto-k-diagnostic') for details.
loo_compare(loo_lin, loo_gam, loo_gp)
## elpd_diff se_diff
## fit_gp 0.0 0.0
## fit_gam -0.5 1.0
## fit_lin -2.0 2.7
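One rough way to read the comparison table is to set each elpd difference against its standard error. The sketch below copies the numbers from the output above and forms an interval of plus or minus two standard errors; this normal approximation is an assumption made here for illustration, and it is only a crude guide with 29 observations and differences this small.

```r
# elpd differences and their standard errors, copied from the loo_compare()
# output above (each row is compared against the best model, fit_gp).
cmp <- data.frame(
  model     = c("fit_gam", "fit_lin"),
  elpd_diff = c(-0.5, -2.0),
  se_diff   = c(1.0, 2.7)
)

# Rough normal-approximation interval: difference +/- 2 standard errors.
cmp$lower <- cmp$elpd_diff - 2 * cmp$se_diff
cmp$upper <- cmp$elpd_diff + 2 * cmp$se_diff

# TRUE when the interval includes zero, i.e. no clear difference.
cmp$covers_zero <- cmp$lower < 0 & cmp$upper > 0
print(cmp)
```

Both intervals include zero, so the table alone gives no clear ranking of the three models.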
There are no practical differences in predictive performance: the elpd differences are small relative to their standard errors, which is partly due to the small number of observations. Based on the posterior predictive distributions, however, there are clear differences in the future predictions. We can also look at the calibration of the leave-one-out predictive distributions.
ppc_loo_intervals(deaths$deaths,
y_predict_lin[,1:29],
psis_object=loo_lin$psis_object)+
labs(title='PPC-LOO linear model')
ppc_loo_intervals(deaths$deaths,
y_predict_gp[,1:29],
psis_object=loo_gp$psis_object)+
labs(title='PPC-LOO GP model')
There is a small difference in favor of the GP model.
sessionInfo()
## R version 4.2.2 (2022-10-31)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Ubuntu 20.04.5 LTS
##
## Matrix products: default
## BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/liblapack.so.3
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=fi_FI.UTF-8 LC_COLLATE=en_US.UTF-8
## [5] LC_MONETARY=fi_FI.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=fi_FI.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=fi_FI.UTF-8 LC_IDENTIFICATION=C
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] rprojroot_2.0.3 loo_2.5.1 patchwork_1.1.2 bayesplot_1.10.0
## [5] brms_2.18.1 rstanarm_2.21.3 Rcpp_1.0.9 gridExtra_2.3
## [9] tidyr_1.2.1 ggplot2_3.4.0
##
## loaded via a namespace (and not attached):
## [1] minqa_1.2.5 colorspace_2.0-3 ellipsis_0.3.2
## [4] markdown_1.4 base64enc_0.1-3 farver_2.1.1
## [7] rstan_2.30.1.9000 DT_0.26 fansi_1.0.3
## [10] mvtnorm_1.1-3 bridgesampling_1.1-2 codetools_0.2-18
## [13] splines_4.2.2 cachem_1.0.6 knitr_1.41
## [16] shinythemes_1.2.0 projpred_2.2.2 jsonlite_1.8.3
## [19] nloptr_2.0.3 shiny_1.7.3 compiler_4.2.2
## [22] backports_1.4.1 assertthat_0.2.1 Matrix_1.5-1
## [25] fastmap_1.1.0 cli_3.4.1 later_1.3.0
## [28] htmltools_0.5.3 prettyunits_1.1.1 tools_4.2.2
## [31] igraph_1.3.5 coda_0.19-4 gtable_0.3.1
## [34] glue_1.6.2 reshape2_1.4.4 dplyr_1.0.10
## [37] posterior_1.3.1 V8_4.2.2 jquerylib_0.1.4
## [40] vctrs_0.5.1 nlme_3.1-160 crosstalk_1.2.0
## [43] tensorA_0.36.2 xfun_0.35 stringr_1.4.1
## [46] ps_1.7.2 lme4_1.1-31 mime_0.12
## [49] miniUI_0.1.1.1 lifecycle_1.0.3 gtools_3.9.3
## [52] nleqslv_3.3.3 MASS_7.3-58 zoo_1.8-11
## [55] scales_1.2.1 colourpicker_1.2.0 promises_1.2.0.1
## [58] Brobdingnag_1.2-9 parallel_4.2.2 inline_0.3.19
## [61] shinystan_2.6.0 gamm4_0.2-6 yaml_2.3.6
## [64] curl_4.3.3 StanHeaders_2.30.1.9000 sass_0.4.3
## [67] stringi_1.7.8 highr_0.9 dygraphs_1.1.1.6
## [70] checkmate_2.1.0 boot_1.3-28 pkgbuild_1.3.1
## [73] cmdstanr_0.5.3 rlang_1.0.6 pkgconfig_2.0.3
## [76] matrixStats_0.63.0 distributional_0.3.1 evaluate_0.18
## [79] lattice_0.20-45 purrr_0.3.5 labeling_0.4.2
## [82] rstantools_2.2.0 htmlwidgets_1.5.4 processx_3.8.0
## [85] tidyselect_1.2.0 plyr_1.8.8 magrittr_2.0.3
## [88] R6_2.5.1 generics_0.1.3 DBI_1.1.3
## [91] mgcv_1.8-41 pillar_1.8.1 withr_2.5.0
## [94] xts_0.12.2 survival_3.4-0 abind_1.4-5
## [97] tibble_3.1.8 crayon_1.5.2 utf8_1.2.2
## [100] rmarkdown_2.18 grid_4.2.2 data.table_1.14.6
## [103] callr_3.7.3 threejs_0.3.3 digest_0.6.30
## [106] xtable_1.8-4 httpuv_1.6.6 RcppParallel_5.1.5
## [109] stats4_4.2.2 munsell_0.5.0 bslib_0.4.1
## [112] shinyjs_2.1.0