Assignment 7

Hierarchical model in Stan

Author

anonymous

Warning

Currently, rendering on GitHub is broken, so the rendered template at https://avehtari.github.io/BDA_course_Aalto/assignments/template7.html does not display correctly. Rendering should still work on Aalto’s JupyterLab, and we will try to fix rendering on GitHub as soon as possible.

1 General information

This is the template for assignment 7. You can download the separate model with bad priors and the qmd-file or copy the code from this rendered document after clicking on </> Code in the top right corner.

Please replace the instructions in this template by your own text, explaining what you are doing in each exercise.

Setup

This block will only be visible in your HTML output, but will be hidden when rendering to PDF with quarto for the submission. Make sure that this does not get displayed in the PDF!

The following loads several needed packages:

library(aaltobda)
library(bayesplot)
This is bayesplot version 1.10.0
- Online documentation and vignettes at mc-stan.org/bayesplot
- bayesplot theme set to bayesplot::theme_default()
   * Does _not_ affect other ggplot2 plots
   * See ?bayesplot_theme_set for details on theme setting
library(cmdstanr)
This is cmdstanr version 0.5.3
- CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
- CmdStan path: /root/.cmdstan/cmdstan-2.31.0
- CmdStan version: 2.31.0

A newer version of CmdStan is available. See ?install_cmdstan() to install it.
To disable this check set option or environment variable CMDSTANR_NO_VER_CHECK=TRUE.
library(dplyr)

Attaching package: 'dplyr'
The following objects are masked from 'package:stats':

    filter, lag
The following objects are masked from 'package:base':

    intersect, setdiff, setequal, union
library(ggplot2)
library(ggdist) # for stat_dotsinterval
library(posterior)
This is posterior version 1.4.0

Attaching package: 'posterior'
The following object is masked from 'package:bayesplot':

    rhat
The following object is masked from 'package:aaltobda':

    mcse_quantile
The following objects are masked from 'package:stats':

    mad, sd, var
The following objects are masked from 'package:base':

    %in%, match
if(!require(brms)){
    install.packages("brms")
    library(brms)
}
Loading required package: brms
Loading required package: Rcpp
Loading 'brms' package (version 2.19.0). Useful instructions
can be found by typing help('brms'). A more detailed introduction
to the package is available through vignette('brms_overview').

Attaching package: 'brms'
The following objects are masked from 'package:ggdist':

    dstudent_t, pstudent_t, qstudent_t, rstudent_t
The following object is masked from 'package:bayesplot':

    rhat
The following object is masked from 'package:stats':

    ar
# Set more readable themes with bigger font for plotting packages.
ggplot2::theme_set(theme_minimal(base_size = 14))
bayesplot::bayesplot_theme_set(theme_minimal(base_size = 14))

# This registers CmdStan as the backend for compiling cmdstan-chunks.
check_cmdstan_toolchain(fix = TRUE, quiet = TRUE)
register_knitr_engine(override = FALSE)

2 Hierarchical Model: Chicken Data with Stan (6p)

2.1 Choosing a weakly informative prior by intuition

A word of caution on eliciting the priors below

Please note that below we intend to set a prior on \(\mu\) (the mean chick weight), but the intuition we elicit is based on the weight of individual chicks. We do so to help build intuition about what the mean could be; however, it would be theoretically more accurate to elicit priors about mean chick weights.
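One common way to turn an intuitive range into a normal prior can be sketched as follows. This is only an illustration: the range endpoints and the variable names (`plausible_low`, `plausible_high`, `prior_mean`, `prior_sd`) are hypothetical and not part of the assignment, and a weakly informative prior may deliberately be made wider than this.

```r
# Hypothetical plausible range for chick weights in grams (illustrative only).
plausible_low <- 50
plausible_high <- 250

# Centre the prior on the midpoint of the range.
prior_mean <- (plausible_low + plausible_high) / 2

# Choose the sd so that the range is the central 95% interval of the normal.
prior_sd <- (plausible_high - plausible_low) / (2 * qnorm(0.975))

# Check: the prior mass inside the range should be 0.95.
prior_mass <- pnorm(plausible_high, prior_mean, prior_sd) -
  pnorm(plausible_low, prior_mean, prior_sd)
```

Because the range describes individual chicks rather than the mean, a prior built this way is typically wider than a prior elicited directly for the mean would be.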

Important

We have made changes to the assignment text and some of the rubrics to make it clearer.

2.2 (a)

2.3 (b)

2.4 (c)

2.5 (d)

2.6 (e)

2.7 Choosing a weakly informative prior using external references

2.8 (f)

2.9 (g)

2.10 (h)

2.11 (i)

2.12 Non-normal priors

2.13 (j)

2.14 Modeling diet effects on chicken weight

data("ChickWeight")

Chick12 <- ChickWeight |> filter(Time == 12)

head(Chick12)
Grouped Data: weight ~ Time | Chick
  weight Time Chick Diet
1    106   12     1    1
2    122   12     2    1
3    115   12     3    1
4    102   12     4    1
5    141   12     5    1
6    141   12     6    1
Sample from the posterior

This block will only be visible in your HTML output, but will be hidden when rendering to PDF with quarto for the submission. Make sure that this does not get displayed in the PDF!

To sample from the posterior using Stan, use:

stan_data <- list(
  N_observations = nrow(Chick12),
  N_diets = length(unique(Chick12$Diet)),
  diet_idx = Chick12$Diet,
  weight = Chick12$weight
)

model_separate <- cmdstan_model(stan_file = "additional_files/assignment7/chickens_separate.stan")

# Sampling from the posterior distribution happens here:
fit_separate <- model_separate$sample(data = stan_data, refresh=0,
                                      show_messages=FALSE,
                                      show_exceptions=FALSE)
Warning: No chains finished successfully. Unable to retrieve the fit.

By default, fit objects returned by the sample() method print a summary of the posterior draws. These are NOT the results you are expected to turn in with your report. You will need to change the priors in the code for the separate model.

fit_separate
Error: Fitting failed. Unable to print.

Quick model convergence check (as in assignment 6):

fit_separate$cmdstan_diagnose()
Error: No CmdStan runs finished successfully. Unable to run bin/diagnose.

2.15 (k)

2.16 (l)

For the figures below, we use the earlier draws for the separate model with bad priors. When you have implemented the pooled and hierarchical models, edit the code below to include draws from your model posterior into the figures.

Data preparation and sampling from the posterior

This block will only be visible in your HTML output, but will be hidden when rendering to PDF with quarto for the submission. Make sure that this does not get displayed in the PDF!

fit_pooled <- fit_separate
fit_hierarchical <- fit_separate

Below, we collect the corresponding posterior draws from the three models into a shared data frame using the extract_variable function. This makes plotting the posterior in a single plot easier.

# Expect the same number of posterior draws from each model.
ndraws <- nrow(fit_hierarchical$sampler_diagnostics(format = "matrix"))
Error: No chains finished successfully. Unable to retrieve the sampler diagnostics.
# Collect posterior draws and the model used to a data frame.
mean_diet_4_separate = extract_variable(fit_separate, "mean_diet[4]")
Error: No chains finished successfully. Unable to retrieve the draws.
mean_diet_4_pooled = extract_variable(fit_pooled, "mean_diet[4]")
Error: No chains finished successfully. Unable to retrieve the draws.
mean_diet_4_hierarchical = extract_variable(fit_hierarchical, "mean_diet[4]")
Error: No chains finished successfully. Unable to retrieve the draws.
posterior_mean_diet_4 <- data.frame(
  model_name = rep(c("Separate", "Pooled", "Hierarchical"),
              each = ndraws),
  mean_diet_4 = c(
   mean_diet_4_separate, mean_diet_4_pooled, mean_diet_4_hierarchical
  ))
Error in data.frame(model_name = rep(c("Separate", "Pooled", "Hierarchical"), : object 'mean_diet_4_separate' not found
predicted_weight_diet_4 <- data.frame(
  model_name = rep(c("Separate", "Pooled", "Hierarchical"),
              each = ndraws),
  predicted_weight = c(
   extract_variable(fit_separate, "weight_pred"),
   extract_variable(fit_pooled, "weight_pred"),
   extract_variable(fit_hierarchical, "weight_pred")
  ))
Error: No chains finished successfully. Unable to retrieve the draws.
# Collect posterior draws and the model used to a long data frame.
posterior_mean_diet_5 <- data.frame(
  model_name = rep(c("Separate", "Pooled", "Hierarchical"),
    each = ndraws
  ),
  mean_diet_5 = c(
    extract_variable(fit_separate, "mean_five"),
    extract_variable(fit_pooled, "mean_five"),
    extract_variable(fit_hierarchical, "mean_five")
  )
)
Error: No chains finished successfully. Unable to retrieve the draws.
# Mean observed weight per diet, these help to compare the posteriors to data.
diet_means <- sapply(
  1:4, function(diet) mean(Chick12[Chick12$Diet == diet, "weight"])
)
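
The long data frame layout used above can be tried out without a fitted model. The sketch below uses simulated placeholder draws instead of calls to extract_variable; the names `n_draws`, `toy_draws_*` and `toy_posterior`, and the numbers passed to rnorm, are hypothetical and chosen only for illustration.

```r
set.seed(1)

# Hypothetical number of draws per model (e.g. 4 chains x 1000 draws).
n_draws <- 4000

# Placeholder draws standing in for extract_variable(fit, "mean_diet[4]").
toy_draws_separate <- rnorm(n_draws, mean = 150, sd = 10)
toy_draws_pooled <- rnorm(n_draws, mean = 130, sd = 5)
toy_draws_hierarchical <- rnorm(n_draws, mean = 145, sd = 8)

# One row per draw; model_name repeats each label n_draws times so that it
# lines up with the concatenated draws.
toy_posterior <- data.frame(
  model_name = rep(c("Separate", "Pooled", "Hierarchical"), each = n_draws),
  mean_diet_4 = c(toy_draws_separate, toy_draws_pooled, toy_draws_hierarchical)
)

# Each model should contribute exactly n_draws rows.
rows_per_model <- table(toy_posterior$model_name)
```

The order of the labels in rep() has to match the order of the vectors in c(); otherwise draws are attributed to the wrong model.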

2.17 (m)

ggplot(posterior_mean_diet_4, aes(x = mean_diet_4, y = model_name)) +
  stat_dotsinterval(quantiles = 100, scale = .9) +
  vline_at(diet_means[4], size = 1, linetype = "dashed") +
  # Annotate the vline from above.
  annotate("text", label = "Observation mean", x = diet_means[4] - 5, y = .7,
           hjust = "right", size = 6) +
  # Add title and axis labels. One line to make everything so much more clear!
  labs(
    title = "Mean of diet 4",
    x = "Weight (g)",
    y = "Model"
  )
Error in ggplot(posterior_mean_diet_4, aes(x = mean_diet_4, y = model_name)): object 'posterior_mean_diet_4' not found

2.18 (n)

ggplot(predicted_weight_diet_4, aes(x = predicted_weight, y = model_name)) +
  stat_dotsinterval(quantiles = 100, scale = .9) +
  vline_at(diet_means[4], size = 1, linetype = "dashed") +
  # Annotate the vline from above.
  annotate("text", label = "Observation mean", x = diet_means[4] - 5, y = .7,
           hjust = "right", size = 6) +
  # Add title and axis labels. One line to make everything so much more clear!
  labs(
    title = "Weight of a chick with diet 4",
    x = "Weight (g)",
    y = "Model"
  )
Error in ggplot(predicted_weight_diet_4, aes(x = predicted_weight, y = model_name)): object 'predicted_weight_diet_4' not found

2.19 (o)

ggplot(posterior_mean_diet_5, aes(x = mean_diet_5, y = model_name)) +
  # Draw the mean of each diet from the data as a dashed vertical line.
  vline_at(diet_means, size = .5, linetype = "dashed") +
  # dotsinterval shows the median with 66% and 95% intervals (ggdist defaults)
  # plus a dotplot in which each dot represents 1% of the draws (quantiles = 100).
  stat_dotsinterval(quantiles = 100, scale = .9) +
  # Annotate the vline from above.
  annotate(geom = "text", label = "Means of observed diets", y = .7, x = 100,
           hjust = "right", size = 5, family = "sans") +
  # Add title and axis labels. One line to make everything so much more clear!
  labs(title = "Mean of a new diet",
       x = "Weight (g)",
       y = "Model")
Error in ggplot(posterior_mean_diet_5, aes(x = mean_diet_5, y = model_name)): object 'posterior_mean_diet_5' not found

2.20 (p)

3 Hierarchical model with BRMS (3p)

Important

We have made changes to the assignment text and some of the rubrics to make it clearer.

3.1 (a)

bayesplot::mcmc_scatter(x = fit_hierarchical$draws(variables = c("mean_diet[4]", "sd_diets")),
                        np = nuts_params(fit_hierarchical)) +
  # Log scale for sd_diets. Adding ylim() afterwards would replace this scale
  # with a linear one, so it is left out.
  scale_y_log10() +
  labs(x = expression(mean_diet[4]), y = expression(sd_diets))
Error: No chains finished successfully. Unable to retrieve the draws.

3.2 (b)

Create a brms model and sample from the posterior

brms_fit = brm(
  weight ~ 1 + (1 | Diet),
  data=Chick12,
  prior=c(
    # REPLACE WITH YOUR PRIOR DERIVED IN SECTION 2
    prior(normal(0,10), class="Intercept"),
    # YOU CAN LEAVE THE BELOW PRIORS
    prior(exponential(.02), class="sd"),
    # No trailing comma after the last prior: c() rejects an empty argument.
    prior(exponential(.02), class="sigma")
  ),
  backend = "cmdstanr",
  save_pars = save_pars(manual = c("z_1[1,4]"))
)

3.3 (c)

# Draws for mu_4
mu_4 = posterior_epred(brms_fit, newdata = data.frame(Diet=4))
Error in posterior_epred(brms_fit, newdata = data.frame(Diet = 4)): object 'brms_fit' not found
# Compute the mean, and quantiles. Remember to round your answers accordingly.
# ...
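
A minimal sketch of the summary step, using simulated placeholder draws rather than the mu_4 matrix from posterior_epred. The names `toy_mu_4`, `toy_mean` and `toy_interval`, the numbers passed to rnorm, and the choice of a central 90% interval are all assumptions made for illustration.

```r
set.seed(2)

# Placeholder draws standing in for the draws of mu_4.
toy_mu_4 <- rnorm(4000, mean = 150, sd = 8)

# Posterior mean of the placeholder draws.
toy_mean <- mean(toy_mu_4)

# Central 90% interval from the 5% and 95% quantiles.
toy_interval <- quantile(toy_mu_4, probs = c(0.05, 0.95))
```

How many digits to report depends on the Monte Carlo standard error; the posterior package provides mcse_mean and mcse_quantile for that purpose.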

3.4 (d)

Due to the non-centered parametrization, we need to compute the \(\mu_d\) term as the sum of the population intercept and the group-specific deviation from the intercept. You can choose which diet to plot by modifying the integer d in r_Diet[d,Intercept].
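
Written out, the relationship looks roughly as follows. The symbols \(\beta_0\), \(r_d\), \(\tau\) and \(z_d\) are notation introduced here for illustration; they are assumed to correspond to the brms variables b_Intercept, r_Diet[d,Intercept], sd_Diet__Intercept and z_1[1,d], respectively.

```latex
\[
\mu_d = \beta_0 + r_d,
\qquad
r_d = \tau \, z_d,
\qquad
z_d \sim \mathrm{normal}(0, 1)
\]
```

The sampler works with the standardized \(z_d\), so \(\mu_d\) is not stored directly and has to be reconstructed from the intercept and the deviation.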

draws = as_draws_df(brms_fit) |>
  posterior::mutate_variables(mean_diet_4 = `r_Diet[4,Intercept]` + b_Intercept)
Error in as_draws_df(brms_fit): object 'brms_fit' not found
bayesplot::mcmc_scatter(draws,
                        pars = c("mean_diet_4", "sd_Diet__Intercept"),
                        np = nuts_params(brms_fit)) +
  scale_y_log10() +
  xlab(expression(mean_diet[4])) +
  ylab(expression(sd_diets))
Error in posterior::is_draws(x): object 'draws' not found

3.5 (e)