Variational Inference: MCMC and Hybrid Samplers

This page covers gradient-based MCMC (HMC, NUTS), the hybrid samplers that combine a variational warm-up with HMC chains, and posterior predictive sampling from an MCMCResult. The variational families and ELBO objectives live in the SVI guide.

MCMC: HMC and NUTS

When variational families underfit, fall back to gradient-based MCMC. The kernels operate on the registry's unconstrained parameter vector, and gradients of the log density are computed with torch.autograd.grad.
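The following is a minimal NumPy sketch of one HMC transition on an unconstrained vector. A positive scale is sampled as its logarithm, with the log-Jacobian added to the log density. Leapfrog integration runs over all chains at once along a leading batch axis, and a Metropolis test then accepts or rejects each chain's proposal. The toy model, the helper names (log_density, grad_log_density, leapfrog, hmc_step), and the hand-written gradient are all illustrative assumptions. quivers obtains gradients from torch.autograd.grad instead, and this is not its kernel implementation.

```python
import numpy as np

# Hypothetical toy model on the unconstrained vector z = (loc, log_scale):
#   loc ~ Normal(0, 1), scale ~ Exponential(1), scale = exp(log_scale).
# The "+ log_scale" term is the log-Jacobian of scale = exp(log_scale).
def log_density(z):
    loc, log_scale = z[..., 0], z[..., 1]
    return -0.5 * loc ** 2 + (log_scale - np.exp(log_scale))

def grad_log_density(z):
    # Analytic gradient; stands in for torch.autograd.grad in this sketch.
    loc, log_scale = z[..., 0], z[..., 1]
    return np.stack([-loc, 1.0 - np.exp(log_scale)], axis=-1)

def leapfrog(z, p, step_size, num_steps, inv_mass):
    """Leapfrog integrator; z and p have shape (num_chains, dim)."""
    p = p + 0.5 * step_size * grad_log_density(z)
    for i in range(num_steps):
        z = z + step_size * inv_mass * p
        if i < num_steps - 1:
            p = p + step_size * grad_log_density(z)
    p = p + 0.5 * step_size * grad_log_density(z)
    return z, p

def hmc_step(z, rng, step_size=0.1, num_steps=10, inv_mass=None):
    """One Metropolis-corrected HMC transition for every chain at once."""
    if inv_mass is None:
        inv_mass = np.ones(z.shape[-1])
    # Momentum ~ Normal(0, M) with diagonal mass matrix M = 1 / inv_mass.
    p0 = rng.standard_normal(z.shape) / np.sqrt(inv_mass)
    z1, p1 = leapfrog(z, p0, step_size, num_steps, inv_mass)
    # Hamiltonian H(z, p) = -log p(z) + 0.5 * p^T M^{-1} p, one value per chain.
    h0 = -log_density(z) + 0.5 * np.sum(inv_mass * p0 ** 2, axis=-1)
    h1 = -log_density(z1) + 0.5 * np.sum(inv_mass * p1 ** 2, axis=-1)
    accept_prob = np.exp(np.minimum(0.0, h0 - h1))
    accept = rng.uniform(size=z.shape[0]) < accept_prob
    return np.where(accept[:, None], z1, z), accept_prob

# Usage: 4 chains stacked on the leading axis, all advanced in one call.
rng = np.random.default_rng(0)
z = np.zeros((4, 2))
draws = []
for _ in range(2000):
    z, _ = hmc_step(z, rng)
    draws.append(z)
draws = np.stack(draws, axis=1)      # (chains, draws, dim)
scale = np.exp(draws[..., 1])        # map back to the constrained space
```

The sampler never sees the positivity constraint. It works entirely on z, and the transform plus log-Jacobian keep the target correct in the original scale space.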

```python
from quivers.inference import NUTSKernel, MCMC

# `model`, `x`, and `observations` are assumed to be set up as in the SVI guide.
kernel = NUTSKernel(
    target_accept=0.8,       # target acceptance rate for step-size adaptation
    max_tree_depth=10,
    mass_matrix="diagonal",  # adapted during warmup
)
mcmc = MCMC(
    kernel=kernel,
    num_warmup=1000,
    num_samples=2000,
    num_chains=4,
)
result = mcmc.run(model, x, observations)

print(result.r_hat)             # per-site split R-hat (Vehtari et al. 2021)
print(result.ess)               # per-site effective sample size
print(result.divergence_counts) # per-chain divergence count
print(result.total_divergences) # sum across chains
samples = result.samples        # dict[str, Tensor] of shape (chains, draws, ...)
```

During warmup, both HMCKernel and NUTSKernel adapt the step size with Nesterov dual averaging and the mass matrix with an online Welford variance estimate. Warmup runs unvectorized because adaptation is impure (it mutates the step-size and mass-matrix state). Once adaptation stops the kernel is pure, so sampling runs vectorized: the leapfrog integrator treats the num_chains chains as a leading batch axis.
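The two warmup adaptations can be sketched as follows. The first is the dual-averaging update from Hoffman & Gelman (2014, Algorithm 5), with its usual default constants (gamma=0.05, t0=10, kappa=0.75). The second is Welford's online variance recurrence, whose per-coordinate variances serve as a diagonal inverse mass matrix. The class names and the toy acceptance curve are hypothetical, and quivers' adaptation schedule (for example, any windowing of the mass-matrix estimate) may differ.

```python
import math
import numpy as np

class DualAveraging:
    """Step-size adaptation in the style of Hoffman & Gelman (2014), Alg. 5."""

    def __init__(self, init_step_size, target_accept=0.8,
                 gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * init_step_size)  # shrinkage point for log step
        self.target_accept = target_accept
        self.gamma, self.t0, self.kappa = gamma, t0, kappa
        self.h_bar = 0.0          # running average of (target - accept_prob)
        self.log_step_bar = 0.0   # averaged iterate, used after warmup
        self.m = 0

    def update(self, accept_prob):
        """Feed one acceptance statistic; return the next step size to try."""
        self.m += 1
        w = 1.0 / (self.m + self.t0)
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.target_accept - accept_prob)
        log_step = self.mu - math.sqrt(self.m) / self.gamma * self.h_bar
        eta = self.m ** (-self.kappa)
        self.log_step_bar = eta * log_step + (1.0 - eta) * self.log_step_bar
        return math.exp(log_step)

    def final_step_size(self):
        return math.exp(self.log_step_bar)

class WelfordVariance:
    """Online per-coordinate mean and variance (Welford's recurrence)."""

    def __init__(self, dim):
        self.count = 0
        self.mean = np.zeros(dim)
        self.m2 = np.zeros(dim)   # running sum of squared deviations

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def variance(self):
        return self.m2 / (self.count - 1)   # unbiased sample variance

# Usage on a toy acceptance curve: acceptance falls as the step size grows.
adapter = DualAveraging(init_step_size=1.0, target_accept=0.8)
step = 1.0
for _ in range(2000):
    step = adapter.update(math.exp(-step))
step_size = adapter.final_step_size()   # should approach -log(0.8)

rng = np.random.default_rng(1)
warmup_draws = rng.normal(scale=[1.0, 3.0], size=(500, 2))
welford = WelfordVariance(dim=2)
for row in warmup_draws:
    welford.update(row)
inv_mass = welford.variance()   # diagonal inverse mass matrix estimate
```

Welford's recurrence needs only O(dim) memory and stays numerically stable, which suits updating the estimate one warmup draw at a time.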

The MCMCResult exposes per-site split R-hat and effective sample size (Vehtari et al. 2021) via r_hat and ess, per-chain divergence counts via divergence_counts (total_divergences sums them across chains), and the posterior log density of each draw via log_densities.
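To show what r_hat measures, here is the classic split-R-hat computation for a single scalar site. It assumes draws shaped (chains, draws), like one entry of result.samples. This is the plain split R-hat from Bayesian Data Analysis (3rd ed.). Vehtari et al. (2021) additionally rank-normalize and fold the draws, so quivers' reported values may differ from this sketch. The function name is hypothetical.

```python
import numpy as np

def split_r_hat(draws):
    """Classic split R-hat for one scalar site; draws has shape (chains, num_draws)."""
    chains, num_draws = draws.shape
    n = num_draws // 2
    # Split every chain into a first and second half (dropping the middle
    # draw if num_draws is odd), giving 2 * chains sequences of length n.
    halves = np.concatenate([draws[:, :n], draws[:, num_draws - n:]], axis=0)
    seq_means = halves.mean(axis=1)
    seq_vars = halves.var(axis=1, ddof=1)
    w = seq_vars.mean()                   # within-sequence variance W
    b = n * seq_means.var(ddof=1)         # between-sequence variance B
    var_plus = (n - 1) / n * w + b / n    # pooled estimate of posterior variance
    return np.sqrt(var_plus / w)

rng = np.random.default_rng(2)
mixed = rng.standard_normal((4, 1000))        # chains agree: R-hat near 1
stuck = mixed + np.arange(4)[:, None]         # each chain centered elsewhere
```

Splitting each chain in half also catches a single chain that drifts during sampling, since its two halves then disagree.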

Hybrid samplers

AutoDAIS

Differentiable annealed importance sampling (DAIS) wraps a base guide with \(K\) HMC transitions along an annealing path from the base distribution to the target posterior. The base mean and scale, the step size, and the inverse temperatures are trained jointly via SVI. This closes the parity gap with NumPyro / Pyro AutoDAIS.
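For reference, the DAIS literature usually writes the annealing path as the geometric interpolation below, with the AIS log weight for Metropolis-corrected kernels alongside it. This is the standard textbook formulation, not a description of quivers internals. Identifying \(K\) with num_steps and \(\beta_k\) with the trained inverse temperatures is an assumption.

\[
\log \gamma_k(z) = (1 - \beta_k)\,\log q_0(z) + \beta_k\,\log p(x, z),
\qquad 0 = \beta_0 < \beta_1 < \cdots < \beta_K = 1,
\]
\[
\log w = \sum_{k=1}^{K} \bigl[\log \gamma_k(z_{k-1}) - \log \gamma_{k-1}(z_{k-1})\bigr]
= \sum_{k=1}^{K} (\beta_k - \beta_{k-1})\,\bigl[\log p(x, z_{k-1}) - \log q_0(z_{k-1})\bigr],
\]

where \(z_0 \sim q_0\) and each \(z_k\) comes from a transition that leaves \(\gamma_k\) invariant. DAIS drops the Metropolis correction so that the whole procedure is differentiable, and its weight picks up momentum terms accordingly.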

```python
from quivers.inference import AutoNormalGuide, AutoDAIS

base = AutoNormalGuide(model, observed_names={"y"})
guide = AutoDAIS(
    base,
    model=model,
    observations=observations,
    num_steps=8,
    init_step_size=0.05,
    init_temperature=0.1,
)
# Plug into SVI exactly like any other guide.
```

WarmupThenHMC

Train a variational guide for svi_steps SVI iterations, then initialize the HMC chains from the guide's posterior mean. On hierarchical models with skewed prior support, this Pareto-dominates cold-start HMC.

```python
from quivers.inference import (
    AutoMultivariateNormalGuide, NUTSKernel, WarmupThenHMC
)

sampler = WarmupThenHMC(
    guide=AutoMultivariateNormalGuide(model, observed_names={"y"}),
    kernel=NUTSKernel(),
    svi_steps=1000,
    mcmc_warmup=500,
    mcmc_samples=2000,
)
svi_losses, result = sampler.run(model, x, observations)
```

Predictive with MCMC

Predictive consumes either a Guide or an MCMCResult. With an MCMCResult, it iterates over posterior samples instead of calling guide.rsample.
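Here is a sketch of that iteration. It assumes Predictive flattens the (chains, draws, ...) sample tensors into one pool and selects num_samples joint draws from it. Whether quivers subsamples, cycles, or uses every draw is not stated here. The helper names, the toy linear model, and the NumPy arrays standing in for tensors are hypothetical.

```python
import numpy as np

def select_posterior_draws(samples, num_samples, rng):
    """Flatten (chains, draws, ...) arrays and pick num_samples joint draws."""
    flat = {name: v.reshape((-1,) + v.shape[2:]) for name, v in samples.items()}
    total = next(iter(flat.values())).shape[0]
    # Without replacement when the pool is large enough, else with replacement.
    idx = rng.choice(total, size=num_samples, replace=num_samples > total)
    # The same indices for every site keep each draw's sites together.
    return {name: v[idx] for name, v in flat.items()}

def predictive(model_fn, samples, x_new, num_samples, rng):
    draws = select_posterior_draws(samples, num_samples, rng)
    # One forward pass per posterior draw, in place of guide.rsample.
    return np.stack([
        model_fn({name: v[i] for name, v in draws.items()}, x_new, rng)
        for i in range(num_samples)
    ])

def linear_model(params, x, rng):
    # Hypothetical observation model: y = w * x + sigma * noise.
    return params["w"] * x + params["sigma"] * rng.standard_normal(x.shape)

rng = np.random.default_rng(3)
samples = {  # stand-in for result.samples, each of shape (chains, draws)
    "w": rng.normal(2.0, 0.1, size=(4, 2000)),
    "sigma": np.abs(rng.normal(0.5, 0.05, size=(4, 2000))),
}
x_new = np.linspace(0.0, 1.0, 5)
y_pred = predictive(linear_model, samples, x_new, num_samples=500, rng=rng)
```

Indexing every site with the same flat index matters. Drawing each site independently would break the posterior correlation between sites such as w and sigma.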

```python
from quivers.inference import Predictive

predictive = Predictive(
    model=conditioned.model,
    posterior=result,   # the MCMCResult returned by mcmc.run
    num_samples=500,
)
samples = predictive(x_new)
```

See also

References