Demonstration of covariance matrix and basis function implementation of Gaussian process model in Stan.
The covariance matrix approach is based on Chapter 10 of the Stan User’s Guide, Version 2.26, by the Stan Development Team (2021). https://mc-stan.org/docs/stan-users-guide/
The Hilbert space basis function approximation is based on Riutort-Mayol, Bürkner, Andersen, Solin, and Vehtari (2020). Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. https://arxiv.org/abs/2004.11408
Data are measurements of head acceleration in a simulated motorcycle accident, used to test crash helmets.
Data are modelled with a normal distribution that has Gaussian process priors on the mean and on the log standard deviation: \[ y \sim \mbox{normal}(\mu(x), \exp(\eta(x)))\\ \mu \sim GP(0, K_1)\\ \eta \sim GP(0, K_2) \] \(K_1\) and \(K_2\) are exponentiated quadratic covariance functions.
# Stan interface and tools for working with posterior draws
library(cmdstanr)
library(posterior)
# pillar options controlling how the summary tibbles are printed
options(pillar.neg = FALSE, pillar.subtle=FALSE, pillar.sigfig=2)
# Data wrangling and plotting
library(tidyr)
library(dplyr)
library(ggplot2)
library(bayesplot)
# Use the bayesplot default theme for all ggplot2 figures
theme_set(bayesplot::theme_default(base_family = "sans", base_size=16))
# Colour palette used in the figures (set1[1], set1[2])
set1 <- RColorBrewer::brewer.pal(7, "Set1")
# Random seed for reproducibility; it only takes effect when it is passed
# on to the sampler (e.g. `seed = SEED`)
SEED <- 48927 # set random seed for reproducibility
Load data
# Load the motorcycle data set from the MASS package and show its first rows
data("mcycle", package = "MASS")
mcycle %>%
  head()
times accel
1 2.4 0.0
2 2.6 -1.3
3 3.2 -2.7
4 3.6 0.0
5 4.0 -2.7
6 6.2 -2.7
Plot data
# Scatter plot of the raw measurements
ggplot(mcycle, aes(x = times, y = accel)) +
  geom_point() +
  labs(x = "Time (ms)", y = "Acceleration (g)")
Model code
# Stan program of the basis function model; echo its source
file1 <- "gpbf1.stan"
file1 %>%
  readLines() %>%
  writeLines()
functions {
// Helper functions (PHI_EQ and diagSPD_EQ are used below) live in a separate
// file; the directory containing it must be on the include path at compile time.
#include gpbasisfun_functions.stan
}
data {
// Observations plus the settings of the two basis function approximations:
// suffix _f refers to f and suffix _g to g in the blocks below.
int<lower=1> N; // number of observations
vector[N] x; // univariate covariate
vector[N] y; // target variable
real<lower=0> c_f; // factor c to determine the boundary value L for f
int<lower=1> M_f; // number of basis functions for f
real<lower=0> c_g; // factor c to determine the boundary value L for g
int<lower=1> M_g; // number of basis functions for g
}
transformed data {
  // Normalize data to zero mean and unit standard deviation
  real xmean = mean(x);
  real ymean = mean(y);
  real xsd = sd(x);
  real ysd = sd(y);
  vector[N] xn = (x - xmean)/xsd;
  vector[N] yn = (y - ymean)/ysd;
  // Largest absolute normalized input. Mean-centred inputs need not be
  // symmetric around zero, so the boundary has to be derived from the larger
  // of the two extremes: PHI_EQ builds its sine basis on [-L, L] and every
  // xn has to lie inside that interval.
  real S = fmax(max(xn), -min(xn));
  // Basis functions for f
  real L_f = c_f*S;
  matrix[N,M_f] PHI_f = PHI_EQ(N, M_f, L_f, xn);
  // Basis functions for g
  real L_g = c_g*S;
  matrix[N,M_g] PHI_g = PHI_EQ(N, M_g, L_g, xn);
}
parameters {
real intercept; // constant added to the basis function expansion of f
vector[M_f] beta_f; // the basis functions coefficients of f
vector[M_g] beta_g; // the basis functions coefficients of g
real<lower=0> lengthscale_f; // lengthscale of f
real<lower=0> sigma_f; // scale of f
real<lower=0> lengthscale_g; // lengthscale of g
real<lower=0> sigma_g; // scale of g
}
model {
  // weights of the basis functions, as returned by diagSPD_EQ
  vector[M_f] weights_f = diagSPD_EQ(sigma_f, lengthscale_f, L_f, M_f);
  vector[M_g] weights_g = diagSPD_EQ(sigma_g, lengthscale_g, L_g, M_g);
  // mean and standard deviation of the observation model (normalized scale)
  vector[N] mu = intercept + PHI_f * (weights_f .* beta_f);
  vector[N] noise_sd = exp(PHI_g * (weights_g .* beta_g));

  // priors
  intercept ~ std_normal();
  beta_f ~ std_normal();
  beta_g ~ std_normal();
  lengthscale_f ~ std_normal();
  lengthscale_g ~ std_normal();
  sigma_f ~ normal(0, .5);
  sigma_g ~ normal(0, .5);

  // likelihood
  yn ~ normal(mu, noise_sd);
}
generated quantities {
  // mean function, scaled back to the original scale of y
  vector[N] f = (intercept
                 + PHI_f * (diagSPD_EQ(sigma_f, lengthscale_f, L_f, M_f) .* beta_f)) * ysd
                + ymean;
  // standard deviation function, scaled back to the original scale of y
  vector[N] sigma = exp(PHI_g * (diagSPD_EQ(sigma_g, lengthscale_g, L_g, M_g) .* beta_g)) * ysd;
}
The model code includes Hilbert space basis function helpers
# Echo the source of the included helper file
"gpbasisfun_functions.stan" %>%
  readLines() %>%
  writeLines()
// Square root of the exponentiated quadratic spectral density expression,
// evaluated for the basis function indices 1..M on the boundary L.
// alpha is the scale and rho the lengthscale of the covariance function.
vector diagSPD_EQ(real alpha, real rho, real L, int M) {
  vector[M] idx = linspaced_vector(M, 1, M);
  real freq_scale = rho*pi()/2/L;
  return sqrt((alpha^2) * sqrt(2*pi()) * rho * exp(-0.5*freq_scale^2 * idx^2));
}
/* real spd_Matt(real alpha, real rho, real w) { */
/* real S = 4*alpha^2 * (sqrt(3)/rho)^3 * 1/((sqrt(3)/rho)^2 + w^2)^2; */
/* return sqrt(S); */
/* } */
// Weights of the basis functions of a periodic covariance function with
// scale alpha and lengthscale rho. Returns a vector of length 2*M.
vector diagSPD_periodic(real alpha, real rho, int M) {
  real a = 1/rho^2;
  // Orders 1..M of the modified Bessel function. Declared with the
  // `array[...]` syntax: the older `int one_to_M[M]` form has been removed
  // from Stan and no longer compiles.
  array[M] int one_to_M;
  for (m in 1:M) one_to_M[m] = m;
  vector[M] q = sqrt(alpha^2 * 2 / exp(a) * to_vector(modified_bessel_first_kind(one_to_M, a)));
  // q is duplicated to match the 2*M columns (cosines followed by sines)
  // produced by PHI_periodic
  return append_row(q,q);
}
// Sine basis functions on the boundary L evaluated at the inputs x:
// one row per element of x and one column per basis function 1..M.
matrix PHI_EQ(int N, int M, real L, vector x) {
  vector[M] idx = linspaced_vector(M, 1, M);
  matrix[rows(x), M] phase = diag_post_multiply(rep_matrix(pi()/(2*L) * (x+L), M), idx);
  return sin(phase)/sqrt(L);
}
// Cosine and sine basis functions with base frequency w0 evaluated at x:
// the first M columns hold cos(m*w0*x), the last M columns sin(m*w0*x).
matrix PHI_periodic(int N, int M, real w0, vector x) {
  vector[M] harmonics = linspaced_vector(M, 1, M);
  matrix[N,M] phase = diag_post_multiply(rep_matrix(w0*x, M), harmonics);
  return append_col(cos(phase), sin(phase));
}
Compile Stan model
# Compile; "." is searched for the file pulled in by #include
model1 <- cmdstan_model(file1, include_paths = ".")
Data to be passed to Stan
standata1 <- list(
  x = mcycle$times,
  y = mcycle$accel,
  N = nrow(mcycle),
  c_f = 1.5, # factor c of the basis functions for the GP on f
  M_f = 40,  # number of basis functions for the GP on f
  c_g = 1.5, # factor c of the basis functions for the GP on g
  M_g = 40   # number of basis functions for the GP on g
)
Sample using dynamic HMC
# Pass the seed defined at the top of the script; without it every run
# draws a fresh seed and the results below cannot be reproduced.
fit1 <- model1$sample(data = standata1, seed = SEED,
                      iter_warmup = 500, iter_sampling = 500,
                      chains = 4, parallel_chains = 2, adapt_delta = 0.9)
Check whether parameters have reasonable values
# Extract the draws and summarise the intercept, scales and lengthscales
draws1 <- fit1$draws()
draws1 %>%
  subset_draws(variable = c("intercept", "sigma_", "lengthscale_"),
               regex = TRUE) %>%
  summarise_draws()
# A tibble: 5 x 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 intercept 0.27 0.26 0.32 0.30 -0.26 0.81 1.0 1717. 1173.
2 sigma_f 0.81 0.79 0.17 0.16 0.58 1.1 1.0 1355. 1520.
3 sigma_g 1.3 1.3 0.22 0.22 0.95 1.7 1.0 1388. 1306.
4 lengthscale_f 0.34 0.34 0.050 0.052 0.26 0.43 1.0 945. 1627.
5 lengthscale_g 0.52 0.52 0.095 0.092 0.37 0.68 1.0 1214. 1352.
Compare the model to the data
# Posterior means of the mean function f and of the standard deviation sigma
draws1m <- as_draws_matrix(draws1)
Ef <- draws1m %>%
  subset_draws(variable = "f") %>%
  colMeans()
sigma <- draws1m %>%
  subset_draws(variable = "sigma") %>%
  colMeans()
pred <- data.frame(Ef = Ef, sigma = sigma)
# Data with the posterior mean (solid) and mean +/- 2 sigma (dashed)
ggplot(cbind(mcycle, pred), aes(x = times, y = accel)) +
  geom_point() +
  labs(x = "Time (ms)", y = "Acceleration (g)") +
  geom_line(aes(y = Ef), color = set1[1]) +
  geom_line(aes(y = Ef - 2 * sigma), color = set1[1], linetype = "dashed") +
  geom_line(aes(y = Ef + 2 * sigma), color = set1[1], linetype = "dashed")
Plot posterior draws and posterior mean of the mean function
# Every fifth posterior draw of f as a thin line, with the data and the
# posterior mean drawn on top
draws1 %>%
  subset_draws(variable = "f") %>%
  thin_draws(thin = 5) %>%
  as_draws_df() %>%
  pivot_longer(
    cols = !starts_with("."),
    names_to = "ind",
    names_transform = list(ind = readr::parse_number),
    values_to = "mu"
  ) %>%
  mutate(time = mcycle$times[ind]) %>%
  ggplot(aes(x = time, y = mu, group = .draw)) +
  geom_line(color = set1[2], alpha = 0.1) +
  geom_point(
    data = mcycle,
    mapping = aes(x = times, y = accel),
    inherit.aes = FALSE
  ) +
  geom_line(
    data = cbind(mcycle, pred),
    mapping = aes(x = times, y = Ef),
    inherit.aes = FALSE,
    color = set1[1],
    size = 1
  ) +
  labs(x = "Time (ms)", y = "Acceleration (g)")
Model code
# Stan program of the covariance matrix model; echo its source
file2 <- "gpcov.stan"
file2 %>%
  readLines() %>%
  writeLines()
data {
// Observations only; the inputs are normalized in transformed data
int<lower=1> N; // number of observations
vector[N] x; // univariate covariate
vector[N] y; // target variable
}
transformed data {
  // Normalize data to zero mean and unit standard deviation
  real xmean = mean(x);
  real ymean = mean(y);
  real xsd = sd(x);
  real ysd = sd(y);
  // The normalized inputs are stored as an array because they are passed to
  // gp_exp_quad_cov. Declared with the `array[...]` syntax: the older
  // `real xn[N]` form has been removed from Stan and no longer compiles.
  array[N] real xn = to_array_1d((x - xmean)/xsd);
  vector[N] yn = (y - ymean)/ysd;
  // constant added to every element of both covariance matrices
  real sigma_intercept = 0.1;
  // small value added to the diagonals before the Cholesky decompositions
  vector[N] jitter = rep_vector(1e-9, N);
}
parameters {
real<lower=0> lengthscale_f; // lengthscale of f
real<lower=0> sigma_f; // scale of f
real<lower=0> lengthscale_g; // lengthscale of g
real<lower=0> sigma_g; // scale of g
vector[N] z_f; // standard normal latent values, multiplied by the Cholesky factor of f
vector[N] z_g; // standard normal latent values, multiplied by the Cholesky factor of g
}
model {
  // covariance matrices (with the constant term) and their Cholesky factors
  matrix[N, N] cov_f = gp_exp_quad_cov(xn, sigma_f, lengthscale_f) + sigma_intercept;
  matrix[N, N] chol_f = cholesky_decompose(add_diag(cov_f, jitter));
  matrix[N, N] cov_g = gp_exp_quad_cov(xn, sigma_g, lengthscale_g) + sigma_intercept;
  matrix[N, N] chol_g = cholesky_decompose(add_diag(cov_g, jitter));
  // latent mean and log standard deviation on the normalized scale
  vector[N] mu = chol_f * z_f;
  vector[N] log_sd = chol_g * z_g;

  // priors
  z_f ~ std_normal();
  z_g ~ std_normal();
  lengthscale_f ~ lognormal(log(.3), .2);
  lengthscale_g ~ lognormal(log(.5), .2);
  sigma_f ~ normal(0, .5);
  sigma_g ~ normal(0, .5);

  // likelihood
  yn ~ normal(mu, exp(log_sd));
}
generated quantities {
  vector[N] f;
  vector[N] sigma;
  {
    // local scope, so that the N x N matrices are not written to the output
    matrix[N, N] chol_f = cholesky_decompose(
      add_diag(gp_exp_quad_cov(xn, sigma_f, lengthscale_f) + sigma_intercept, jitter));
    matrix[N, N] chol_g = cholesky_decompose(
      add_diag(gp_exp_quad_cov(xn, sigma_g, lengthscale_g) + sigma_intercept, jitter));
    // mean and standard deviation functions on the original scale of y
    f = (chol_f * z_f)*ysd + ymean;
    sigma = exp(chol_g * z_g)*ysd;
  }
}
Compile Stan model
# Compile the covariance matrix model (no include path needed)
model2 <- cmdstan_model(file2)
Data to be passed to Stan
standata2 <- list(
  x = mcycle$times,
  y = mcycle$accel,
  N = nrow(mcycle)
)
Sample using dynamic HMC
# Pass the seed defined at the top of the script; without it every run
# draws a fresh seed and the results below cannot be reproduced.
fit2 <- model2$sample(data = standata2, seed = SEED,
                      iter_warmup = 100, iter_sampling = 100,
                      chains = 4, parallel_chains = 2, refresh = 10)
Check whether parameters have reasonable values
# Extract the draws and summarise the scales and lengthscales
draws2 <- fit2$draws()
draws2 %>%
  subset_draws(variable = c("sigma_", "lengthscale_"), regex = TRUE) %>%
  summarise_draws()
# A tibble: 4 x 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 sigma_f 0.79 0.77 0.15 0.14 0.56 1.1 1.0 406. 362.
2 sigma_g 1.2 1.1 0.23 0.22 0.83 1.6 1.0 428. 259.
3 lengthscale_f 0.32 0.32 0.037 0.036 0.26 0.38 1.0 186. 238.
4 lengthscale_g 0.48 0.48 0.070 0.075 0.37 0.60 1.0 185. 225.
Compare the model to the data
# Posterior means of the mean function f and of the standard deviation sigma
draws2m <- as_draws_matrix(draws2)
Ef <- draws2m %>%
  subset_draws(variable = "f") %>%
  colMeans()
sigma <- draws2m %>%
  subset_draws(variable = "sigma") %>%
  colMeans()
pred <- data.frame(Ef = Ef, sigma = sigma)
# Data with the posterior mean (solid) and mean +/- 2 sigma (dashed)
ggplot(cbind(mcycle, pred), aes(x = times, y = accel)) +
  geom_point() +
  labs(x = "Time (ms)", y = "Acceleration (g)") +
  geom_line(aes(y = Ef), color = set1[1]) +
  geom_line(aes(y = Ef - 2 * sigma), color = set1[1], linetype = "dashed") +
  geom_line(aes(y = Ef + 2 * sigma), color = set1[1], linetype = "dashed")
Plot posterior draws and posterior mean of the mean function
# All posterior draws of f as thin lines, with the data and the posterior
# mean drawn on top
draws2 %>%
  subset_draws(variable = "f") %>%
  as_draws_df() %>%
  pivot_longer(
    cols = !starts_with("."),
    names_to = "ind",
    names_transform = list(ind = readr::parse_number),
    values_to = "mu"
  ) %>%
  mutate(time = mcycle$times[ind]) %>%
  ggplot(aes(x = time, y = mu, group = .draw)) +
  geom_line(color = set1[2], alpha = 0.1) +
  geom_point(
    data = mcycle,
    mapping = aes(x = times, y = accel),
    inherit.aes = FALSE
  ) +
  geom_line(
    data = cbind(mcycle, pred),
    mapping = aes(x = times, y = Ef),
    inherit.aes = FALSE,
    color = set1[1],
    size = 1
  ) +
  labs(x = "Time (ms)", y = "Acceleration (g)")