Workflow for iterative building of a time series model.

We analyse the relative number of births per day in the USA in 1969-1988 using a Gaussian process time series model with several components that explain the long-term, seasonal, weekly, day-of-year, and special floating-day variation.

The Stan model codes are available in the corresponding git repository.


Load packages

library("rprojroot")
root<-has_file(".Workflow-Examples-root")$make_fix_file()
library(tidyverse)
library(cmdstanr)
library(posterior)
options(pillar.neg = FALSE, pillar.subtle=FALSE, pillar.sigfig=2)
library(loo)
library(bayesplot)
theme_set(bayesplot::theme_default(base_family = "sans"))
library(patchwork)
set1 <- RColorBrewer::brewer.pal(7, "Set1")

Use English for names of weekdays and months

Sys.setlocale("LC_TIME", "en_GB.UTF-8")
[1] "en_GB.UTF-8"

Load and plot data

Load birthdays per day in USA 1969-1988:

data <- read_csv(root("Birthdays/data", "births_usa_1969.csv"))

Add date type column for plotting

data <- data %>%
  mutate(date = as.Date("1968-12-31") + id,
         births_relative100 = births/mean(births)*100)

Plot all births

We can see slow variation in the trend, a yearly pattern, and, especially in the later years, a spread to lower and higher values.

data %>%
  ggplot(aes(x=date, y=births)) + geom_point(color=set1[2]) +
  labs(x="Date", y="Number of births")

Plot all births as relative to mean

To make the interpretation easier, we switch to examining the relative change, with the mean level denoted by 100.

data %>%
  ggplot(aes(x=date, y=births_relative100)) + geom_point(color=set1[2]) +
  geom_hline(yintercept=100, color='gray') +
  labs(x="Date", y="Relative births per day")

Plot mean per day of year

We can see the generic yearly seasonal pattern simply by averaging over each day of year (day_of_year2 runs from 1 to 366 every year, with the leap day always being 60 and the 1st of March always being 61, also in non-leap years).
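The data already contains this index as the column day_of_year2; as an illustration, such an index could be constructed along these lines (the helper below is hypothetical, not part of the dataset code):

```r
# Hypothetical helper: a leap-year-aligned day-of-year index, so that
# Feb 29 is always 60 and Mar 1 is always 61, also in non-leap years.
day_of_year2 <- function(date) {
  doy  <- as.integer(format(date, "%j"))  # ordinary day of year, 1..365/366
  year <- as.integer(format(date, "%Y"))
  leap <- (year %% 4 == 0 & year %% 100 != 0) | year %% 400 == 0
  ifelse(!leap & doy >= 60, doy + 1L, doy)  # shift Mar 1 onwards in non-leap years
}
day_of_year2(as.Date(c("1970-03-01", "1972-02-29")))  # 61, 60
```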

data %>%
  group_by(day_of_year2) %>%
  summarise(meanbirths=mean(births_relative100)) %>%
  ggplot(aes(x=as.Date("1986-12-31")+day_of_year2, y=meanbirths)) +
  geom_point(color=set1[2]) +
  geom_hline(yintercept=100, color='gray') +
  scale_x_date(date_breaks = "1 month", date_labels = "%b") +
  labs(x="Day of year", y="Relative births per day of year")

Plot mean per day of week

We can see the generic weekly pattern simply by averaging over each day of week.

data %>%
  group_by(day_of_week) %>%
  summarise(meanbirths=mean(births_relative100)) %>%
  ggplot(aes(x=day_of_week, y=meanbirths)) +
  geom_point(color=set1[2], size=4) +
  geom_hline(yintercept=100, color='gray') +
  scale_x_continuous(breaks = 1:7, labels=c('Mon','Tue','Wed','Thu','Fri','Sat','Sun')) +
  labs(x="Day of week", y="Relative births per day of week")

Previous analyses

We have analysed the same data before in BDA3 and thus had an idea of what kind of model to use. For BDA3 we used GPstuff, Gaussian-process-specific software for Matlab and Octave. As Stan aims to be very generic, it can be slower than specialized software for some specific models such as Gaussian processes, but Stan provides more flexibility in the model definition.

Riutort-Mayol et al. (2020) demonstrate a Hilbert space basis function approximation of Gaussian processes, also on the same birthday data. In their experiments the inference was slower than expected, raising the suspicion of inefficient model code or a bad posterior shape due to a bad model specification.

Workflow for quick iterative model building

Even though we have a general idea for the model (slow trend, seasonal trend, weekday effect, etc.), adding all the components at once would make the model complex and difficult to debug, and would make the computational problems hard to solve. It is thus natural to build the model gradually and check that each addition works before adding the next component. During this iterative model building we want the inference to be fast, but it doesn't need to be very accurate as long as the new model is qualitatively reasonable. For quick testing and iterative model building we can use optimization and shorter MCMC chains than we would recommend for the final inference. Furthermore, in this specific example, the new additions are qualitatively such clear improvements that there is no need for a quantitative model comparison of whether the additions are "significant" (see also Navarro, 2019), and there is no danger of overfitting. There is, however, one part of the model where the data are weakly informative and the prior choices seem to matter; we'll get back to this and its consequences later. Overall we built tens of different models, but illustrate here only the main line.

Models for relative number of birthdays

As the relative number of births is positive, it's natural to model its logarithm. The generic form of the models is \[ y \sim \mbox{normal}(f(x), \sigma), \] where \(f\) is a different and gradually more complex function of \(x\), which includes the running day number, day of year, day of week, and eventually some special floating US bank holidays.

Model 1: Slow trend

Model 1 is just the slow trend over the years, using a Hilbert space basis function approximated Gaussian process \[ f = \mbox{intercept} + f_1\\ \mbox{intercept} \sim \mbox{normal}(0,1)\\ f_1 \sim \mbox{GP}(0,K_1) \] where the GP has an exponentiated quadratic covariance function.

In this phase the code from Riutort-Mayol et al. (2020) was cleaned up and rewritten to be more efficient, but only one GP component was included to make testing easier. Although the code was made more efficient, the aim wasn't to make it the fastest possible, as later model changes may have a bigger effect on performance (it's good to avoid premature optimization). We also use quite a small number of basis functions to make the code run faster, and only later examine more carefully whether the number of basis functions is sufficient given the posterior of the length scale (see Riutort-Mayol et al., 2020).
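As a reminder of how the approximation works, here is a minimal R sketch following Riutort-Mayol et al. (2020); the function names are illustrative, not taken from the Stan code:

```r
# Sketch of the Hilbert space basis function GP approximation.
# Eigenfunctions of the Laplacian on [-L, L]:
phi <- function(x, j, L) sqrt(1 / L) * sin(pi * j * (x + L) / (2 * L))
# Spectral density of the exponentiated quadratic covariance function:
spd_eq <- function(w, sigma_f, lengthscale)
  sigma_f^2 * sqrt(2 * pi) * lengthscale * exp(-0.5 * (lengthscale * w)^2)
# f is approximated as a weighted sum of M basis functions, with weights
# beta ~ normal(0, 1) scaled by the square root of the spectral density:
approx_f <- function(x, beta, sigma_f, lengthscale, c = 1.5) {
  M <- length(beta)
  L <- c * max(abs(x))                          # boundary factor c
  PHI <- sapply(1:M, function(j) phi(x, j, L))  # N x M basis matrix
  sqrt_spd <- sqrt(spd_eq(pi * (1:M) / (2 * L), sigma_f, lengthscale))
  as.vector(PHI %*% (sqrt_spd * beta))
}
```

Here c corresponds to c_f1 and M to M_f1 in the Stan data below.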

Compile the Stan model gpbf1.stan, which includes gpbasisfun_functions1.stan

model1 <- cmdstan_model(stan_file = root("Birthdays", "gpbf1.stan"),
                        include_paths = root("Birthdays"))

Data to be passed to Stan

standata1 <- list(x=data$id,
                  y=log(data$births_relative100),
                  N=length(data$id),
                  c_f1=1.5, # factor c of basis functions for GP for f1
                  M_f1=20)  # number of basis functions for GP for f1

As the basis function approximation and the priors restrict the complexity of the GP, we can safely use optimization to get a very quick initial result and check that the model code computes what we intended. As there are only 14 parameters and 7305 observations, it's likely that the posterior is close to normal (in the unconstrained space). In this case the optimization takes less than one second, while MCMC sampling with default options would take several minutes. Although this result can be useful in a quick workflow, it should not be used as the final result.

opt1 <- model1$optimize(data = standata1, init=0, algorithm='bfgs')

Check whether parameters have reasonable values

odraws1 <- opt1$draws()
subset(odraws1, variable=c('intercept','sigma_f1','lengthscale_f1','sigma'))
# A draws_matrix: 1 iterations, 1 chains, and 4 variables
    variable
draw intercept sigma_f1 lengthscale_f1 sigma
   1    -0.048      1.1           0.16  0.81

Compare the model fit to the data

oEf <- exp(as.numeric(subset(odraws1, variable='f')))
data %>%
  mutate(oEf = oEf) %>%
  ggplot(aes(x=date, y=births_relative100)) +
  geom_point(color=set1[2], alpha=0.2) +
  geom_line(aes(y=oEf), color=set1[1]) +
  geom_hline(yintercept=100, color='gray') +
  labs(x="Date", y="Relative number of births")

After we get the model working with optimization, we can compare the result to short MCMC chains, which also provide additional information on the speed of different code implementations of the same model. We intentionally use just 1/10th of the usual recommended chain length, as during iterative model building rough results are sufficient. When testing the code we initially used just one chain, but at this point running four chains on a four-core CPU doesn't add much to the wall clock time, gives more information on how easy it is to sample from the posterior, and can reveal whether there are multiple modes. Although the result from short chains can be useful in a quick workflow, it should not be used as the final result.

fit1 <- model1$sample(data=standata1, iter_warmup=100, iter_sampling=100,
                      chains=4, parallel_chains=4, seed=3891)

Depending on the random seed and luck, we sometimes observed that some of the chains got stuck in different modes. We can see this in high Rhat and low ESS diagnostic values.
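To see why stuck chains inflate Rhat, here is a sketch of the classic split-Rhat computation (a simplified version of the rank-normalized diagnostic that posterior actually uses; the function name is ours):

```r
# Sketch of split-Rhat: each chain is split in half, then we compare the
# between-chain variance of the half-chain means to the within-chain variance.
split_rhat <- function(draws) {          # draws: iterations x chains matrix
  n <- floor(nrow(draws) / 2)
  half <- cbind(draws[1:n, , drop = FALSE],
                draws[(nrow(draws) - n + 1):nrow(draws), , drop = FALSE])
  B <- n * var(colMeans(half))           # between-(half-)chain variance
  W <- mean(apply(half, 2, var))         # mean within-chain variance
  sqrt(((n - 1) / n * W + B / n) / W)    # ~1 when all chains agree
}
```

A chain stuck in a different mode makes B large relative to W, pushing Rhat well above 1.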

draws1 <- fit1$draws()
summarise_draws(subset(draws1, variable=c('intercept','sigma_f1','lengthscale_f1','sigma')))
# A tibble: 4 × 10
  variable        mean median    sd   mad    q5   q95  rhat ess_bulk ess_tail
  <chr>          <dbl>  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>
1 intercept      0.083  0.096 0.15  0.083 -0.20  0.28   1.7     64.      108.
2 sigma_f1       0.39   0.29  0.22  0.18   0.17  0.75   2.2      5.6      24.
3 lengthscale_f1 1.7    1.7   1.5   2.2    0.19  3.2    1.8      6.3      30.
4 sigma          0.91   0.90  0.093 0.14   0.80  1.0    1.9      6.2      55.

Examining the trace plots shows the multimodality clearly.

mcmc_trace(draws1, regex_pars=c('intercept','sigma_f1','lengthscale_f1','sigma'))

In this case it was easy to figure out that some of the chains got stuck in qualitatively much worse modes. In general we don't recommend starting from the mode, as the mode is usually not a representative point of a hierarchical or high-dimensional posterior, but we can use it here to speed up the iterative model building, as long as we check that the optimization result is sensible and later do more careful inference. Although the result from short chains can be useful in a quick workflow, it should not be used as the final result.

init1 <- sapply(c('intercept','sigma_f1','lengthscale_f1','beta_f1','sigma'),
                function(variable) {as.numeric(subset(odraws1, variable=variable))})
fit1 <- model1$sample(data=standata1, iter_warmup=100, iter_sampling=100,
                      chains=4, parallel_chains=4,
                      init=function() { init1 })

We now observe better Rhat and ESS diagnostic values, although due to the very short chains they are not yet perfect. We are also likely to observe Hamiltonian Monte Carlo divergences and treedepth exceedences in the dynamic building of the Hamiltonian trajectory, but there is no need to worry about those as long as the model results are qualitatively sensible, as these computational issues can also go away when the model itself is improved. In all the following short MCMC runs we get some or many divergences and usually a very large number of treedepth exceedences. Divergences indicate possible bias and should eventually be investigated carefully. Treedepth exceedences indicate strong posterior dependencies and slow mixing; sometimes the posterior can be much improved by changing the parameterization or the priors, but as treedepth exceedences don't indicate bias, there is no need for more careful analysis if the resulting ESS and MCSE values are good for the purpose at hand. We'll come back to a more careful analysis of the final models later.

draws1 <- fit1$draws()
summarise_draws(subset(draws1, variable=c('intercept','sigma_f1','lengthscale_f1','sigma')))
# A tibble: 4 × 10
  variable        mean median     sd    mad    q5   q95  rhat ess_bulk ess_tail
  <chr>          <dbl>  <dbl>  <dbl>  <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>
1 intercept      0.050  0.060 0.24   0.23   -0.37  0.38   1.0     296.     277.
2 sigma_f1       0.58   0.56  0.12   0.11    0.43  0.81   1.0     218.     336.
3 lengthscale_f1 0.23   0.23  0.039  0.037   0.17  0.29   1.0     224.     231.
4 sigma          0.81   0.81  0.0064 0.0068  0.80  0.82   1.0     400.     232.

The trace plot shows slow mixing but no multimodality.

mcmc_trace(draws1, regex_pars=c('intercept','sigma_f1','lengthscale_f1','sigma'))

The model result from short MCMC chains looks very similar to the optimization result.

draws1 <- as_draws_matrix(draws1)
Ef <- exp(apply(subset(draws1, variable='f'), 2, median))
data %>%
  mutate(Ef = Ef) %>%
  ggplot(aes(x=date, y=births_relative100)) + geom_point(color=set1[2], alpha=0.2) +
  geom_line(aes(y=Ef), color=set1[1]) +
  geom_hline(yintercept=100, color='gray') +
  labs(x="Date", y="Relative number of births")

If we compare the result from the short sampling to the optimization result, we don't see any practical difference in the predictions (although we will later see more differences between optimization and MCMC).

data %>%
  mutate(Ef = Ef,
         oEf = oEf) %>%
  ggplot(aes(x=Ef, y=oEf)) + geom_point(color=set1[2]) +
  geom_abline() +
  labs(x="Ef from short Markov chain", y="Ef from optimizing")

After the first version of this notebook, Nikolas Siccha examined the posterior correlations more carefully and noticed a strong correlation between the intercept and the first basis function. Stan's dynamic HMC is so efficient that the inference is successful anyway, but Nikolas suggested removing the intercept term. The explicit intercept is not necessarily needed, as the data has been centered. We test a model without it.

Compile the Stan model gpbf1b.stan

model1b <- cmdstan_model(stan_file = root("Birthdays", "gpbf1b.stan"),
                         include_paths = root("Birthdays"))

We sample using the default initialization.

fit1b <- model1b$sample(data=standata1, iter_warmup=100, iter_sampling=100,
                      chains=4, parallel_chains=4, seed=3891)

The sampling performs better, indicating that the strong posterior correlation in the first model was causing trouble for the adaptation during the short warmup, leading some chains to get stuck.

draws1b <- fit1b$draws()
summarise_draws(subset(draws1b, variable=c('sigma_f1','lengthscale_f1','sigma')))
# A tibble: 3 × 10
  variable        mean median     sd    mad    q5   q95  rhat ess_bulk ess_tail
  <chr>          <dbl>  <dbl>  <dbl>  <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>
1 sigma_f1        0.60   0.58 0.14   0.12    0.41  0.87   1.0     221.     215.
2 lengthscale_f1  0.23   0.23 0.042  0.036   0.15  0.29   1.0     198.     252.
3 sigma           0.81   0.81 0.0068 0.0067  0.80  0.82   1.0     401.     313.

The trace plots show no multimodality.

mcmc_trace(draws1b, regex_pars=c('sigma_f1','lengthscale_f1','sigma'))