The No-U-Turn Sampler (NUTS) / Dynamic Hamiltonian Monte Carlo

Hamiltonian Monte Carlo (HMC) enables efficient exploration of complex and confined target distributions, but its performance is largely governed by the step size \(\epsilon\) and the number of steps \(L\) used in simulating the Hamiltonian trajectories. Depending on how large \(L\) is, too small a step size may unduly increase the computation time or lead to exploration of the target distribution that resembles a pure random walk, which we want to avoid when using HMC. Too large a step size makes the simulation error of the Hamiltonian trajectory explode, which results in low acceptance rates and poor exploration. Even if we have found a stable value for \(\epsilon\), finding a good value for \(L\) is not easy in general. We would like \(L\) to be large enough that the trajectory travels far through the target distribution and yields almost uncorrelated draws. However, if \(L\) is too large, then due to the Hamiltonian dynamics the trajectory starts to turn back towards the starting point and eventually makes a full cycle. This wastes computation time and, even worse, the trajectory can end near the starting point. Too few steps, in turn, reduce the exploration to a random walk.

In the static version of HMC the user needs to hand-tune the parameters by visualizing the trajectories, performing several preliminary runs, assessing the convergence of the chains or, preferably, a combination of these. For many distributions, however, parameter values that are optimal in one region may be suboptimal in another. For example, if the target distribution is constrained in some directions, as with heavily correlated random variables, then we would like to use long trajectories in the less constrained directions. Also, as the number of parameters grows, visualizing and interpreting the trajectories becomes increasingly difficult. Fortunately, setting the parameters via trial and error can be avoided for both \(\epsilon\) and \(L\), which can be adapted dynamically instead. In this demo, however, we only consider the dynamic adaptation of the number of steps \(L\).

We would like to find an algebraic representation for the criterion of having simulated the trajectory long enough. The trajectory is considered to be “long enough” when the distance between the starting point and the proposal is maximized. Let \(\theta\) be the starting point of the simulation. Now a natural option would be to simulate deterministically as in the static HMC (BDA3 p. 300) until we are at a position-momentum pair \((\tilde{\theta}, \tilde{\phi})\) for which it holds that

\[ (\tilde{\theta} - \theta)^T \tilde{\phi} < 0 \]

That is, we should stop the simulation when the angle between the current momentum \(\tilde{\phi}\) and the vector from the starting point \(\theta\) to the current point \(\tilde{\theta}\) is greater than 90 degrees. This criterion corresponds to stopping when the trajectory starts to turn back towards the starting point and the distance between \(\theta\) and \(\tilde{\theta}\) starts to decrease. Unfortunately, the above scheme does not guarantee reversibility of the Markov chain, and thus the chain is not guaranteed to converge to the desired target distribution. Reversibility can be restored by making the simulation process random, allowing steps both forward and backward in time. This is incorporated in the NUTS algorithm (Hoffman and Gelman 2014), which we present next. Further variants of dynamic HMC are described by Betancourt (2018). The specific algorithm and implementation details of dynamic HMC in Stan differ slightly from the ones described by Hoffman and Gelman (2014) and Betancourt (2018), and unfortunately there is no detailed description of the differences available except in the C++ code. The algorithm presented here follows the original NUTS algorithm by Hoffman and Gelman (2014).
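To make the naive criterion concrete, the following minimal sketch evaluates it along an exactly simulated trajectory. It assumes a two-dimensional standard normal target with an identity mass matrix, for which the Hamiltonian flow is a rotation; the starting values and the time grid are illustrative choices and not part of the demo.

```r
# Exact Hamiltonian flow for a standard normal target with M = I (assumption):
#   theta(t) =  theta0 * cos(t) + phi0 * sin(t)
#   phi(t)   = -theta0 * sin(t) + phi0 * cos(t)
theta0 <- c(1, 0)
phi0 <- c(0, 1)
times <- seq(0, 2 * pi, by = 0.01)

# Inner product (theta(t) - theta0)^T phi(t) at each time point.
uturn_stat <- sapply(times, function(t) {
  theta_t <- theta0 * cos(t) + phi0 * sin(t)
  phi_t <- -theta0 * sin(t) + phi0 * cos(t)
  sum((theta_t - theta0) * phi_t)
})

# Distance from the starting point at each time point.
dist_from_start <- sapply(times, function(t) {
  sqrt(sum((theta0 * cos(t) + phi0 * sin(t) - theta0)^2))
})

# First time at which the criterion fires.
stop_id <- which(uturn_stat < 0)[1]
t_stop <- times[stop_id]
```

For this particular trajectory the inner product equals \(\sin(t)\), so the criterion fires just after \(t = \pi\), which is where the distance from the starting point reaches its maximum.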

Assume that \(p(\theta | y)\) is the target distribution and \(\theta^{t}\) is the current draw. Let \(\epsilon > 0\) be the step size for the leapfrog steps. The NUTS algorithm for drawing \(\theta^{t+1}\) consists of the following steps.

  1. Sample a momentum vector \(\phi \sim N(0, M)\), where \(M\) is a symmetric positive-definite mass matrix. This step is identical to the first step in static HMC.

  2. Sample a slice variable \(u | \theta^t, \phi \sim \text{Uniform}([0, e^{-H(\theta^t, \phi)}])\), where \(H(\theta, \phi) = -\log(p(\theta | y)) + \frac{1}{2} \phi^T M^{-1} \phi\) is the Hamiltonian function that specifies the “total energy” at the position-momentum pair \((\theta, \phi)\).

  3. Generate a set \(C\) of proposal points as follows. Set \(C = \{ \theta^t \}\). Starting the Hamiltonian simulation from \((\theta^t, \phi)\), for \(j=0,1,2,...\),

    • Sample \(v_j \sim \text{Uniform}(\{ -1, 1 \})\) and perform \(2^j\) leapfrog steps with the step size \(v_j \epsilon\), always appending to the trajectory from the previous iteration.

    • The newly generated \(2^j\) points are equivalent to the leaf nodes of a perfect, ordered binary tree with depth \(j\), see Figure 1 below. Now consider each of the \(2^j-1\) subtrees that have depth greater than \(0\), and for a given subtree denote the points associated with the leftmost and rightmost leaves by \((\theta^-, \phi^-)\) and \((\theta^+, \phi^+)\), respectively. If for some subtree it holds that

      \[ (\theta^+ - \theta^-)^T \phi^- < 0 \quad \text{or} \quad (\theta^+ - \theta^-)^T \phi^+ < 0, \]

      then terminate the simulation loop. If for some point \((\theta, \phi)\) from the new \(2^j\) points it holds that

      \[ H(\theta, \phi) + \log u > \Delta_{\text{max}} \]

      for some nonnegative \(\Delta_{\text{max}}\), then also terminate the simulation loop. We will use the value \(\Delta_{\text{max}} = 1000\).

    • If the U-turn criterion (the first stopping criterion above) holds for the leftmost and rightmost points of the full simulated trajectory, then for all \((\theta, \phi)\) of the newly simulated \(2^j\) points, add \(\theta\) to \(C\) if \(e^{-H(\theta, \phi)} \geq u\). Finally, terminate the simulation loop.

    • If no stopping criterion was met, then for all \((\theta, \phi)\) of the new points, add \(\theta\) to \(C\) if \(e^{-H(\theta, \phi)} \geq u\), and move on to the next iteration.

  4. Sample a value for \(\theta^{t+1}\) uniformly from \(C\).
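The role of the slice variable in steps 2 to 4 can be illustrated in isolation. The sketch below assumes hypothetical Hamiltonian values for five trajectory points (the first one being the starting point) and ignores the stopping rules; it only shows how \(u\) determines the candidate set and how the next draw is picked from it.

```r
set.seed(123)

# Hypothetical Hamiltonian values H(theta, phi) along a short trajectory.
# The first entry belongs to the starting point (theta^t, phi).
H_traj <- c(1.20, 1.25, 1.10, 9.00, 1.30)

# Step 2: slice variable u ~ Uniform(0, exp(-H(theta^t, phi))).
u <- runif(1, 0, exp(-H_traj[1]))

# Step 3 (membership test only): keep the points with exp(-H) >= u.
in_slice <- exp(-H_traj) >= u
C_ids <- which(in_slice)

# Step 4: draw the index of the next state uniformly from the candidates.
# Indexing with sample.int() avoids the sample() pitfall that occurs when
# C_ids happens to contain a single number.
next_id <- C_ids[sample.int(length(C_ids), 1)]
```

The starting point is always in the slice, because \(u\) cannot exceed \(e^{-H(\theta^t, \phi)}\), and so is any point whose Hamiltonian is lower than that of the starting point.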

Figure 1: The new points resulting from the \(2^j\) (here \(j=2\)) leapfrog steps are equivalent to the leaf nodes of a perfect, ordered binary tree with depth \(j\). The stopping conditions are checked for all the leaves and all the \(2^j-1\) subtrees with depth greater than \(0\). Note that the interior nodes are not important.
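As a small illustration of which subtrajectories are checked, the sketch below enumerates the leftmost and rightmost leaf indices of every subtree with depth greater than \(0\). The helper `subtree_spans` is hypothetical and is not used elsewhere in this demo.

```r
# Return the (leftmost, rightmost) leaf indices of every subtree with depth
# greater than 0 in a perfect binary tree with leaves first:(first + 2^j - 1).
subtree_spans <- function(j, first = 1) {
  if (j == 0) {
    # A single leaf has depth 0 and is not checked for U-turns.
    return(matrix(integer(0), ncol = 2))
  }
  half <- 2^(j - 1)
  rbind(
    subtree_spans(j - 1, first),         # subtrees in the left half
    subtree_spans(j - 1, first + half),  # subtrees in the right half
    c(first, first + 2^j - 1)            # the subtree spanning all leaves
  )
}

spans <- subtree_spans(2)
colnames(spans) <- c("leftmost", "rightmost")
spans
```

For \(j=2\) this lists the pairs (1, 2), (3, 4) and (1, 4), that is, \(2^2 - 1 = 3\) subtrees.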

In summary, after each iteration the number of leapfrog steps is doubled, and the simulation is stopped whenever any subtrajectory equivalent to a binary subtree satisfies one of the above stopping criteria. The slice variable \(u\) is introduced so that \(\theta^{t+1}\) can be drawn uniformly from \(C\). Note that if we were able to simulate the Hamiltonian trajectory exactly, \(H(\theta, \phi)\) would remain constant over the whole trajectory, and \(e^{-H(\theta, \phi)} \geq u\) would hold for all points on the trajectory. Thus, the condition \(e^{-H(\theta, \phi)} \geq u\) can be interpreted as discarding all points for which the simulation error is too large. The stopping condition \(H(\theta, \phi) + \log u > \Delta_{\text{max}}\) in step 3 likewise terminates the simulation if the simulation error becomes too large. Note that the number of leapfrog steps increases exponentially, but as long as \(\epsilon\) is small enough and the target distribution \(p(\theta | y)\) is sufficiently well-behaved, the simulation error should not grow with the number of leapfrog steps, which makes the above dynamic scheme attractive. As usual with HMC, the leapfrog steps require the gradient of the log-density, i.e. \(\frac{d \log p(\theta | y)}{d\theta}\).
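The doubling can be tabulated with a few lines of arithmetic. This is only a bookkeeping sketch; the range \(j = 0, \ldots, 5\) is an arbitrary illustrative choice.

```r
# Number of new leapfrog steps in iterations j = 0, 1, ..., 5 of step 3.
j <- 0:5
new_steps <- 2^j

# Cumulative number of leapfrog steps, and the trajectory length when the
# starting point is counted as well.
total_steps <- cumsum(new_steps)
n_points <- total_steps + 1

data.frame(j = j, new_steps = new_steps,
           total_steps = total_steps, n_points = n_points)
```

After iteration \(j\) the trajectory contains \(2^{j+1}\) points, so a trajectory with 32 points corresponds to five completed iterations (\(j = 0, \ldots, 4\)).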

Implementation and visualization of NUTS

library(latex2exp)
library(ggplot2)
theme_set(theme_minimal())
library(tidyr)
library(MASS)
library(gganimate)
library(posterior)

Next we will implement the NUTS algorithm and visualize what the trajectories look like. The following subroutines will be useful in the implementation.

# Perform a single leapfrog step.
leapfrog <- function(theta, phi, epsilon, grad_log_p, M_inv) {
  phi_tilde <- phi + 0.5 * epsilon * grad_log_p(theta)
  theta_tilde <- theta + epsilon * M_inv %*% phi_tilde
  phi_tilde <- phi_tilde + 0.5 * epsilon * grad_log_p(theta_tilde)
  list(theta=c(theta_tilde), phi=c(phi_tilde))
}

# Evaluate the Hamiltonian.
hamiltonian <- function(theta, phi, log_p, M_inv) {
  -log_p(theta) + 0.5 * sum(phi * (M_inv %*% phi))
}
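A quick way to check the two subroutines is to monitor the energy error along a simulated trajectory and to verify that a leapfrog step can be undone by a step with the opposite sign. The sketch below repeats the two functions so that it runs on its own, and it assumes an illustrative two-dimensional standard normal target with identity mass matrix; the names `log_p_std`, `grad_log_p_std` and `max_energy_error` are hypothetical.

```r
# Copies of the two subroutines above, so that the snippet runs on its own.
leapfrog <- function(theta, phi, epsilon, grad_log_p, M_inv) {
  phi_tilde <- phi + 0.5 * epsilon * grad_log_p(theta)
  theta_tilde <- theta + epsilon * M_inv %*% phi_tilde
  phi_tilde <- phi_tilde + 0.5 * epsilon * grad_log_p(theta_tilde)
  list(theta = c(theta_tilde), phi = c(phi_tilde))
}
hamiltonian <- function(theta, phi, log_p, M_inv) {
  -log_p(theta) + 0.5 * sum(phi * (M_inv %*% phi))
}

# Illustrative target (an assumption): standard normal in 2D, M = I.
log_p_std <- function(theta) -0.5 * sum(theta^2)
grad_log_p_std <- function(theta) -theta
M_inv <- diag(2)

# Largest absolute energy error over n_steps leapfrog steps.
max_energy_error <- function(epsilon, n_steps = 50) {
  theta <- c(1, 0)
  phi <- c(0, 1)
  H0 <- hamiltonian(theta, phi, log_p_std, M_inv)
  err <- 0
  for (i in seq_len(n_steps)) {
    step <- leapfrog(theta, phi, epsilon, grad_log_p_std, M_inv)
    theta <- step$theta
    phi <- step$phi
    err <- max(err, abs(hamiltonian(theta, phi, log_p_std, M_inv) - H0))
  }
  err
}

err_small <- max_energy_error(0.1)  # stable step size
err_large <- max_energy_error(2.1)  # above the stability limit (2 for this target)

# Reversibility: a step forward followed by a step with -epsilon returns
# to the starting state up to floating point error.
fwd <- leapfrog(c(1, 0), c(0, 1), 0.1, grad_log_p_std, M_inv)
back <- leapfrog(fwd$theta, fwd$phi, -0.1, grad_log_p_std, M_inv)
```

With the small step size the energy error stays close to zero over all 50 steps, whereas with the large step size it explodes, which is the behaviour the second stopping condition in step 3 guards against.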

Step 3 of the NUTS algorithm above is the trickiest to implement. In practice, instead of considering all the subtrees only after simulating the \(2^j\) leapfrog steps, it’s easier to perform the steps and construct the trees recursively at the same time.

# Perform the leapfrog steps and construct the subtrees recursively.
build_tree <- function(theta, phi, u, v, j, epsilon, log_p, grad_log_p, M_inv, trajectory) {
  if (j == 0) {
    # Base case. Take just one leapfrog step.
    step <- leapfrog(theta, phi, v*epsilon, grad_log_p, M_inv)
    theta <- step$theta
    phi <- step$phi
    
    # Color the new theta red if it's certainly not a valid proposal, and
    # color it yellow otherwise.
    if (u <= exp(-hamiltonian(theta, phi, log_p, M_inv))) {
      color <- "yellow"
    } else {
      color <- "red"
    }
    
    # Update the trajectory.
    if (v == -1) {
      trajectory$thetas <- rbind(theta, trajectory$thetas)
      trajectory$colors <- c(color, trajectory$colors)
    } else {
      trajectory$thetas <- rbind(trajectory$thetas, theta)
      trajectory$colors <- c(trajectory$colors, color)
    }
    
    # Check the second stopping condition in step 3. We use Delta_max=1000.
    terminate <- hamiltonian(theta, phi, log_p, M_inv) + log(u) > 1000
    
    return(list(theta_minus=theta, phi_minus=phi, theta_plus=theta, phi_plus=phi,
                trajectory=trajectory, terminate=terminate))
  } else {
    # Build the left and right subtrees recursively.
    tree1 <- build_tree(theta, phi, u, v, j-1, epsilon, 
                        log_p, grad_log_p, M_inv, trajectory)
    theta_minus <- tree1$theta_minus
    phi_minus <- tree1$phi_minus
    theta_plus <- tree1$theta_plus
    phi_plus <- tree1$phi_plus
    if (v == -1) {
      tree2 <- build_tree(theta_minus, phi_minus, u, v, j-1, epsilon,
                          log_p, grad_log_p, M_inv, tree1$trajectory)
      theta_minus <- tree2$theta_minus
      phi_minus <- tree2$phi_minus
    } else {
      tree2 <- build_tree(theta_plus, phi_plus, u, v, j-1, epsilon,
                          log_p, grad_log_p, M_inv, tree1$trajectory)
      theta_plus <- tree2$theta_plus
      phi_plus <- tree2$phi_plus
    }
    
    # Check the U-turn stopping condition and also take into account
    # the stopping flags from building the subtrees.
    condition1 <- sum((theta_plus - theta_minus) * phi_minus) < 0
    condition2 <- sum((theta_plus - theta_minus) * phi_plus) < 0
    terminate <- condition1 || condition2 || tree1$terminate || tree2$terminate
    
    return(list(theta_minus=theta_minus, phi_minus=phi_minus, 
                theta_plus=theta_plus, phi_plus=phi_plus,
                trajectory=tree2$trajectory, terminate=terminate))
  }
}

We are now ready to implement the sampling function.

# Draw the given number of samples from the target distribution using NUTS.
#
# Parameters:
#   - theta         Starting point of the chain.
#   - log_p         Log target distribution function.
#   - grad_log_p    Gradient of the log target distribution function.
#   - n_iter        Number of draws.
#   - M             Symmetric positive definite mass matrix. Default: identity matrix.
#   - epsilon       Step size for the leapfrog steps. Default: 0.1
#   - store_paths   Whether to store and return the individual trajectories
#                   for visualization purposes. Default: TRUE
#
# If store_paths=TRUE, returns a list consisting of the n_iter draws and their
# corresponding trajectories.
# If store_paths=FALSE, returns just the draws.
NUTS <- function(theta, log_p, grad_log_p, n_iter, M=NULL, epsilon=0.1, store_paths=TRUE) {
  if (is.null(M)) {
    M <- diag(length(theta))
  }
  M_inv <- solve(M)
  
  chain <- list(draws=matrix(nrow=n_iter+1, ncol=length(theta)), 
                trajectories=list())
  chain$draws[1,] <- theta
  
  for (i in 1:n_iter) {
    # Step 1. Draw a momentum vector from N(0, M).
    phi <- mvrnorm(n=1, rep(0, length(theta)), M)
  
    # Step 2. Draw a slice variable u.
    u <- runif(1, 0, exp(-hamiltonian(chain$draws[i,], phi, log_p, M_inv)))
    
    # Step 3. Perform the leapfrog steps. Color the starting point black.
    theta_minus <- chain$draws[i,]
    phi_minus <- phi
    theta_plus <- chain$draws[i,]
    phi_plus <- phi
    trajectory <- list(thetas=matrix(theta_minus, nrow=1), colors=c("black"))
    j <- 0
    terminate <- FALSE
    while (!terminate) {
      # Draw a direction.
      v <- sample(c(-1, 1), 1)
      if (v == -1) {
        tree <- build_tree(theta_minus, phi_minus, u, v, j, epsilon,
                           log_p, grad_log_p, M_inv, trajectory)
        theta_minus <- tree$theta_minus
        phi_minus <- tree$phi_minus
      } else {
        tree <- build_tree(theta_plus, phi_plus, u, v, j, epsilon,
                           log_p, grad_log_p, M_inv, trajectory)
        theta_plus <- tree$theta_plus
        phi_plus <- tree$phi_plus
      }
      trajectory$thetas <- tree$trajectory$thetas
      trajectory$colors <- tree$trajectory$colors
      
      # If a stopping condition was met, the newly generated subtrajectory
      # cannot be used in the sampling and is thus colored red.
      if (tree$terminate) {
        for (k in 1:2^j) {
          if (v == 1) {
            trajectory$colors[length(trajectory$colors)+1-k] <- "red"
          } else {
            trajectory$colors[k] <- "red"
          }
        }
      }
      
      # Check the stopping conditions.
      condition1 <- sum((theta_plus - theta_minus) * phi_minus) < 0
      condition2 <- sum((theta_plus - theta_minus) * phi_plus) < 0
      terminate <- tree$terminate || condition1 || condition2
      
      j <- j + 1
    }
    
    # Step 4. Draw the next sample uniformly amongst the starting point and
    # the yellow-colored points. The chosen point is colored green.
    valid_indices <- 1:length(trajectory$colors)
    valid_indices <- valid_indices[(trajectory$colors == "yellow") |
                                   (trajectory$colors == "black")]
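    # Index into valid_indices instead of calling sample(valid_indices, 1),
    # which would draw from 1:k if valid_indices were the single number k.
    draw_id <- valid_indices[sample.int(length(valid_indices), 1)]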
    trajectory$colors[draw_id] <- "green"
    chain$draws[i+1,] <- trajectory$thetas[draw_id,]
    
    if (store_paths) {
      chain$trajectories[[length(chain$trajectories)+1]] <- trajectory
    }
  }
  
  if (store_paths) {
    return(chain)
  } else {
    return(chain$draws)
  }
}

Let’s then visualize the NUTS algorithm. As the target distribution we use a bivariate normal distribution \(p(\theta | y) = N(\theta | 0, \Sigma)\), where

\[ \Sigma = \begin{bmatrix} 1 & 0.8 \\ 0.8 & 1 \end{bmatrix} \]

Since \(p(\theta | y) \propto \exp(-\frac{1}{2} \theta^T \Sigma^{-1} \theta)\), the log-density is given, up to an additive constant, by \(\log p(\theta | y) = -\frac{1}{2} \theta^T \Sigma^{-1} \theta\), and the gradient of the log-density is given by \(\frac{d \log p(\theta | y)}{d\theta} = -\Sigma^{-1} \theta\).

Sigma <- matrix(c(1, 0.8, 0.8, 1), nrow=2)
Sigma_inv <- solve(Sigma)

log_p <- function(theta) {
  -0.5 * sum(theta * (Sigma_inv %*% theta))
}

grad_log_p <- function(theta) {
  -Sigma_inv %*% theta
}
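Since the leapfrog steps rely on the gradient, it can be worth comparing the analytic gradient against a finite-difference approximation before sampling. The sketch below repeats the definitions above so that it runs on its own; `numeric_grad` is a hypothetical helper and the step `h` is an arbitrary choice.

```r
Sigma <- matrix(c(1, 0.8, 0.8, 1), nrow = 2)
Sigma_inv <- solve(Sigma)
log_p <- function(theta) -0.5 * sum(theta * (Sigma_inv %*% theta))
grad_log_p <- function(theta) -Sigma_inv %*% theta

# Central finite-difference approximation of the gradient of f at theta.
numeric_grad <- function(f, theta, h = 1e-5) {
  sapply(seq_along(theta), function(k) {
    e <- replace(numeric(length(theta)), k, h)
    (f(theta + e) - f(theta - e)) / (2 * h)
  })
}

theta_test <- c(-2.5, 2.5)
max_abs_diff <- max(abs(numeric_grad(log_p, theta_test) -
                        c(grad_log_p(theta_test))))
```

At \(\theta = (-2.5, 2.5)\) the analytic gradient is \((12.5, -12.5)\), and the two versions should agree up to rounding error.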

Let \(\theta^0 = (-2.5, 2.5)\) be the starting point. Simulate 2000 draws.

# Draw samples from the target distribution to visualize its 90%
# HPD region with ggplot's stat_ellipse function.
target_draws <- data.frame(mvrnorm(100000, c(0, 0), Sigma))

theta_0 <- c(-2.5, 2.5)
n_iter <- 2000
nuts_chain <- NUTS(theta_0, log_p, grad_log_p, n_iter)
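Because the target is known here, a simple sanity check is to compare the sample mean and covariance of the draws with \(0\) and \(\Sigma\). The sketch below uses a stand-in matrix of independent draws so that it runs on its own; in the demo one would pass `nuts_chain$draws` instead. The helper name `summarize_chain` is hypothetical.

```r
set.seed(1)
Sigma <- matrix(c(1, 0.8, 0.8, 1), nrow = 2)

# Stand-in for nuts_chain$draws (an assumption, so that the snippet is
# self-contained): independent draws from N(0, Sigma) via the Cholesky factor.
draws <- matrix(rnorm(2 * 20000), ncol = 2) %*% chol(Sigma)

# Hypothetical helper: sample mean and covariance of a matrix of draws.
summarize_chain <- function(draws) {
  list(mean = colMeans(draws), cov = cov(draws))
}

s <- summarize_chain(draws)
max_mean_error <- max(abs(s$mean))
max_cov_error <- max(abs(s$cov - Sigma))
```

For actual MCMC draws the errors are typically somewhat larger than for independent draws, since the chain is autocorrelated and starts away from the bulk of the distribution.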

The following example shows how to plot a single trajectory.

# Try to find a trajectory with at least 32 points.
t <- nuts_chain$trajectories[[100]]
for (i in 1:length(nuts_chain$trajectories)) {
  if (length(nuts_chain$trajectories[[i]]$colors) >= 32) {
    t <- nuts_chain$trajectories[[i]]
    break
  }
}
df <- data.frame(theta1     = t$thetas[,1],
                 theta2     = t$thetas[,2],
                 color      = t$colors,
                 theta1_end = c(t$thetas[1,1], t$thetas[-nrow(t$thetas), 1]),
                 theta2_end = c(t$thetas[1,2], t$thetas[-nrow(t$thetas), 2]))
ggplot() +
  stat_ellipse(data = target_draws, aes(x = X1, y = X2, color = "HPD"), level = 0.9) +
  # The segment color is a constant, so it is set outside aes().
  geom_segment(data = df, aes(x = theta1, xend = theta1_end,
                              y = theta2, yend = theta2_end),
               color = "gray", alpha = 0.5) +
  geom_point(data = df, aes(theta1, theta2, color = color), size = 2) + 
  coord_cartesian(xlim = c(-4, 4), ylim = c(-4, 4)) + 
  labs(x = TeX("$\\theta_1$"), y = TeX("$\\theta_2$")) +
  scale_color_manual(values = c("black" = "black", 
                                "red" = "red", 
                                "yellow" = "yellow3", 
                                "green" = "green4", 
                                "HPD" = "blue"),
                     labels = c("black" = "Starting point",
                                "red" = "Outside of the slice",
                                "yellow" = "Within the slice",
                                "green" = "Draw from the slice",
                                "HPD" = "90% HPD")) +
  guides(color = guide_legend(override.aes = list(shape = c(16, 16, 16, 16, NA),
                                                  linetype = c(0,0,0,0,1)))) +
  theme(legend.position = "bottom", legend.title = element_blank())

The following example displays how to animate the trajectories. Let’s consider the first 50 draws.

nframes <- 50 # Set this to the number of iterations you wish to visualize.
df_trajectories <- data.frame()
df_draws <- data.frame()
for (i in 1:nframes) {
  t <- nuts_chain$trajectories[[i]]
  df_trajectories <- rbind(df_trajectories, 
                           list(rep(i, nrow(t$thetas)),
                                t$thetas[,1],
                                t$thetas[,2],
                                t$colors,
                                c(t$thetas[1,1], t$thetas[-nrow(t$thetas), 1]),
                                c(t$thetas[1,2], t$thetas[-nrow(t$thetas), 2])))
  df_draws <- rbind(df_draws, list(rep(i, i),
                                   nuts_chain$draws[1:i, 1],
                                   nuts_chain$draws[1:i, 2]))
}
names(df_trajectories) <- c("iter", "theta1", "theta2", 
                            "color", "theta1_end", "theta2_end")
names(df_draws) <- c("iter", "theta1", "theta2")

p <- ggplot() +
  stat_ellipse(data = target_draws, aes(x = X1, y = X2, color = "HPD"), level = 0.9) +
  geom_point(data = df_draws, aes(x = theta1, y = theta2, color = "black"),
             size = 1, alpha = 0.5) +
  # The segment color is a constant, so it is set outside aes().
  geom_segment(data = df_trajectories, aes(x = theta1, xend = theta1_end,
                                           y = theta2, yend = theta2_end),
               color = "gray", alpha = 0.5) +
  geom_point(data = df_trajectories, aes(theta1, theta2, color = color), size = 2) +
  coord_cartesian(xlim = c(-4, 4), ylim = c(-4, 4)) +
  labs(x = TeX("$\\theta_1$"), y = TeX("$\\theta_2$")) +
  scale_color_manual(values = c("black" = "black", 
                                "red" = "red", 
                                "yellow" = "yellow3", 
                                "green" = "green4", 
                                "HPD" = "blue"),
                     labels = c("black" = "Starting point",
                                "red" = "Outside of the slice",
                                "yellow" = "Within the slice",
                                "green" = "Draw from the slice",
                                "HPD" = "90% HPD")) +
  guides(color = guide_legend(override.aes = list(shape = c(16, 16, 16, 16, NA),
                                                  linetype = c(0,0,0,0,1)))) +
  theme(legend.position = "bottom", legend.title = element_blank()) 
  
# Make sure that nframes is equal to the number of iterations that you wish to visualize.
anim <- animate(p + transition_time(iter), nframes = nframes, fps = 1)

Show the animation

anim