Variational Inference: MCMC and Hybrid Samplers
This page covers gradient-based MCMC (HMC, NUTS), hybrid samplers
that combine variational inference with HMC, and posterior predictive
sampling from an MCMCResult. The variational families and ELBO
objectives are documented in the SVI guide.
MCMC: HMC and NUTS
When variational families underfit the posterior, fall back to
gradient-based MCMC. The kernel operates on the latent registry's
single unconstrained parameter vector, and gradients of the log
density are computed with torch.autograd.grad.
```python
from quivers.inference import NUTSKernel, MCMC

# `model`, `x`, and `observations` are assumed to be defined already.
kernel = NUTSKernel(
    target_accept=0.8,
    max_tree_depth=10,
    mass_matrix="diagonal",
)
mcmc = MCMC(
    kernel=kernel,
    num_warmup=1000,
    num_samples=2000,
    num_chains=4,
)
result = mcmc.run(model, x, observations)

print(result.r_hat)              # per-site split R-hat (Vehtari et al. 2021)
print(result.ess)                # effective sample size
print(result.divergence_counts)  # per-chain divergence count
print(result.total_divergences)  # sum across chains
samples = result.samples  # dict[str, Tensor] of shape (chains, draws, ...)
```
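The kernel runs on an unconstrained vector, and a short dependency-free sketch makes those mechanics concrete. It simulates leapfrog trajectories for a positive latent \(\sigma \sim \text{Exponential}(1)\) sampled as \(u = \log \sigma\).

The sketch rests on several assumptions:

- It uses a unit (identity) mass matrix.
- Hand-written analytic gradients stand in for torch.autograd.grad.
- A Python list stands in for the leading chain axis.
- The function names are hypothetical and are not part of quivers.inference.

The \(+u\) term is the log-Jacobian of \(\sigma = e^u\).

```python
import math

def log_density_unconstrained(u: float) -> float:
    # Exponential(1) log density at sigma = exp(u), plus log|d sigma / du| = u.
    sigma = math.exp(u)
    return -sigma + u

def grad_log_density(u: float) -> float:
    # Analytic gradient of -exp(u) + u (stand-in for torch.autograd.grad).
    return -math.exp(u) + 1.0

def leapfrog(u: float, p: float, step_size: float, num_steps: int):
    """Run num_steps leapfrog steps with an identity mass matrix."""
    p = p + 0.5 * step_size * grad_log_density(u)  # initial half step for momentum
    for i in range(num_steps):
        u = u + step_size * p  # full step for position
        if i < num_steps - 1:
            p = p + step_size * grad_log_density(u)  # full step for momentum
    p = p + 0.5 * step_size * grad_log_density(u)  # final half step for momentum
    return u, p

def hamiltonian(u: float, p: float) -> float:
    # Potential energy is the negative log density; kinetic energy is p^2 / 2.
    return -log_density_unconstrained(u) + 0.5 * p * p

# Hypothetical (position, momentum) pairs, one per chain.
chains = [(-1.0, 0.5), (0.0, -1.0), (1.0, 0.3), (0.2, 1.2)]
results = [leapfrog(u, p, step_size=0.01, num_steps=50) for u, p in chains]
energy_errors = [abs(hamiltonian(*new) - hamiltonian(*old))
                 for old, new in zip(chains, results)]
```

Because the sampler works in \(u\), any real-valued proposal maps back to a valid \(\sigma = e^u > 0\). A small step size keeps the Hamiltonian error small, which is what makes the subsequent accept/reject step accept most proposals.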
Both HMCKernel and NUTSKernel adapt two things during warmup:

- the step size, with Nesterov dual averaging;
- the mass matrix, with online Welford variance estimates.

The leapfrog integrator treats the num_chains chains as a leading batch axis. Warmup runs unvectorized because adaptation updates internal state (it is impure). Sampling runs vectorized because the kernel is pure once adaptation ends.
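Both adaptation rules can be sketched in a few lines, without claiming this is how quivers implements them.

- **Step size.** The first class follows the dual-averaging update of Hoffman & Gelman (2014, Algorithm 5), using their default constants (\(\gamma = 0.05\), \(t_0 = 10\), \(\kappa = 0.75\)).
- **Mass matrix.** The second class is Welford's online mean and variance, whose per-coordinate variance would serve as a diagonal inverse mass matrix. The regularization that production samplers typically add is omitted.
- **Demo.** The toy acceptance curve `exp(-step_size)` is a hypothetical stand-in for real HMC acceptance statistics.

```python
import math

class DualAveraging:
    """Step-size adaptation after Hoffman & Gelman (2014), Algorithm 5."""

    def __init__(self, init_step_size, target_accept=0.8,
                 gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * init_step_size)  # shrinkage point, biased upward
        self.target_accept = target_accept
        self.gamma, self.t0, self.kappa = gamma, t0, kappa
        self.h_bar = 0.0         # running average of (target - acceptance)
        self.log_step_bar = 0.0  # averaged iterate, used after warmup
        self.m = 0

    def update(self, accept_prob):
        """Feed one acceptance statistic; return the next step size to try."""
        self.m += 1
        w = 1.0 / (self.m + self.t0)
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.target_accept - accept_prob)
        log_step = self.mu - math.sqrt(self.m) / self.gamma * self.h_bar
        eta = self.m ** (-self.kappa)
        self.log_step_bar = eta * log_step + (1.0 - eta) * self.log_step_bar
        return math.exp(log_step)

    def final_step_size(self):
        return math.exp(self.log_step_bar)

class Welford:
    """Online mean and sample variance in one pass over the draws."""

    def __init__(self):
        self.n, self.mean, self.m2 = 0, 0.0, 0.0

    def update(self, x):
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)  # uses the updated mean

    def variance(self):
        return self.m2 / (self.n - 1) if self.n > 1 else float("nan")

# Hypothetical acceptance model: larger steps are accepted less often.
adapter = DualAveraging(init_step_size=1.0, target_accept=0.8)
step = 1.0
for _ in range(1000):
    step = adapter.update(math.exp(-step))
adapted = adapter.final_step_size()  # approaches -log(0.8), about 0.223

welford = Welford()
for draw in [1.0, 2.0, 3.0, 4.0, 5.0]:
    welford.update(draw)
```

Dual averaging returns the averaged iterate rather than the last one, so the step size used for sampling is insensitive to late warmup noise.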
The MCMCResult exposes per-site
split R̂ and effective sample size (Vehtari et al.
2021) via r_hat and ess,
per-chain divergence counts via divergence_counts (with
total_divergences summing across chains), and posterior log
densities per draw via log_densities.
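To show what split R̂ measures, here is a sketch of the classic split-R̂ for one scalar site. Each chain is split in half, and the within-half variance \(W\) is compared with the between-half variance \(B\). Note that Vehtari et al. (2021) add rank normalization and folding on top of this, and those steps are omitted here. The function is hypothetical and is not the r_hat implementation in quivers.

```python
import random
import statistics

def split_r_hat(chains):
    """Classic split-R-hat; `chains` is a list of equal-length lists of draws."""
    halves = []
    for chain in chains:
        n = len(chain) // 2
        halves.append(chain[:n])               # first half
        halves.append(chain[len(chain) - n:])  # second half
    n = len(halves[0])
    means = [statistics.fmean(h) for h in halves]
    within = statistics.fmean([statistics.variance(h) for h in halves])  # W
    between = n * statistics.variance(means)  # B = n/(m-1) * sum (mean_j - grand)^2
    var_plus = (n - 1) / n * within + between / n
    return (var_plus / within) ** 0.5

rng = random.Random(0)
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
stuck = [chain[:] for chain in mixed]
stuck[3] = [v + 3.0 for v in stuck[3]]  # one chain stuck in a different mode

rhat_mixed = split_r_hat(mixed)  # close to 1
rhat_stuck = split_r_hat(stuck)  # well above 1
```

Splitting each chain also catches a single chain that drifts over time, which a plain between-chain comparison would miss.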
Hybrid samplers
AutoDAIS
Differentiable annealed importance sampling (DAIS) wraps a base
guide with \(K\) HMC steps (num_steps) along an annealing path from
the base distribution to the target posterior. The base guide's
location and scale, the step size, and the inverse temperatures are
trained jointly with SVI. This closes the parity gap with NumPyro /
Pyro AutoDAIS.
```python
from quivers.inference import AutoNormalGuide, AutoDAIS

# `model` and `observations` are assumed to be defined already.
base = AutoNormalGuide(model, observed_names={"y"})
guide = AutoDAIS(
    base,
    model=model,
    observations=observations,
    num_steps=8,
    init_step_size=0.05,
    init_temperature=0.1,
)
# Plug into SVI exactly like any other guide.
```
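The annealing path can be illustrated with plain annealed importance sampling (AIS; Neal 2001) on a 1-D toy problem. It uses the geometric path \(\log f_\beta(z) = (1-\beta)\log q(z) + \beta \log \gamma(z)\) with \(0 = \beta_0 < \dots < \beta_K = 1\).

This is a swapped-in technique, not AutoDAIS itself, in two respects:

- AutoDAIS uses uncorrected (no accept/reject) HMC steps so the estimator stays differentiable, and learns the temperatures by SVI.
- This sketch uses Metropolis-corrected random-walk moves and fixed, linearly spaced \(\beta_k\).

The base \(q = \mathcal{N}(0, 2^2)\), the target, and all names are hypothetical. The target's true normalizer is \(\sqrt{2\pi}\).

```python
import math
import random

def log_base(z):
    # Base "guide" q(z) = Normal(0, 2^2), normalized.
    return -0.5 * (z / 2.0) ** 2 - math.log(2.0) - 0.5 * math.log(2.0 * math.pi)

def log_target(z):
    # Unnormalized target exp(-(z - 1)^2 / 2); its normalizer is sqrt(2 * pi).
    return -0.5 * (z - 1.0) ** 2

def log_annealed(z, beta):
    # Geometric path: beta = 0 is the base, beta = 1 is the target.
    return (1.0 - beta) * log_base(z) + beta * log_target(z)

def ais_log_normalizer(num_particles=2000, num_steps=8, mh_steps=2,
                       proposal_scale=0.5, seed=0):
    rng = random.Random(seed)
    betas = [k / num_steps for k in range(num_steps + 1)]  # fixed, not trained
    log_weights = []
    for _ in range(num_particles):
        z = rng.gauss(0.0, 2.0)  # exact draw from the base
        log_w = 0.0
        for k in range(1, num_steps + 1):
            # Incremental weight: ratio of consecutive path densities at z.
            log_w += log_annealed(z, betas[k]) - log_annealed(z, betas[k - 1])
            # Random-walk Metropolis moves that leave f_{beta_k} invariant.
            for _ in range(mh_steps):
                proposal = z + proposal_scale * rng.gauss(0.0, 1.0)
                log_ratio = log_annealed(proposal, betas[k]) - log_annealed(z, betas[k])
                if log_ratio >= 0.0 or rng.random() < math.exp(log_ratio):
                    z = proposal
        log_weights.append(log_w)
    # Log of the mean importance weight, computed with log-sum-exp.
    top = max(log_weights)
    return top + math.log(sum(math.exp(w - top) for w in log_weights) / num_particles)

log_z_hat = ais_log_normalizer()  # estimates log sqrt(2 * pi), about 0.919
```

The mean weight is an unbiased estimate of the normalizer. DAIS treats the same weight as a differentiable bound on the evidence that SVI can maximize.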
WarmupThenHMC
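Train a variational guide with SVI (svi_steps iterations), then initialize the HMC chains from the guide's approximate posterior mean. On hierarchical models with skewed prior support this Pareto-dominates cold-start HMC, because chains that start near the posterior bulk do not spend warmup climbing out of the tails.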
```python
from quivers.inference import (
    AutoMultivariateNormalGuide, NUTSKernel, WarmupThenHMC
)

sampler = WarmupThenHMC(
    guide=AutoMultivariateNormalGuide(model, observed_names={"y"}),
    kernel=NUTSKernel(),
    svi_steps=1000,
    mcmc_warmup=500,
    mcmc_samples=2000,
)
svi_losses, result = sampler.run(model, x, observations)
```
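The two-stage idea can be sketched on a toy problem where the answer is known. The first stage fits a Gaussian guide \(q = \mathcal{N}(\text{loc}, \text{scale}^2)\) to a Gaussian target \(\mathcal{N}(3, 0.5^2)\) by stochastic gradient ascent on the ELBO, using reparameterized gradients. The second stage starts each chain near the fitted mean.

Everything here is hypothetical rather than the WarmupThenHMC internals:

- the function names;
- the small per-chain jitter, which keeps chains from starting at an identical point;
- the learning rate.

The sketch stops where NUTS warmup would begin.

```python
import math
import random

def fit_gaussian_guide(target_mean, target_std, steps=2000, lr=0.01,
                       mc_samples=16, seed=0):
    """Maximize the ELBO of q = Normal(loc, scale) for a Gaussian target."""
    rng = random.Random(seed)
    loc, log_scale = 0.0, 0.0  # optimize log(scale) so scale stays positive
    for _ in range(steps):
        scale = math.exp(log_scale)
        g_loc = g_log_scale = 0.0
        for _ in range(mc_samples):
            eps = rng.gauss(0.0, 1.0)
            z = loc + scale * eps  # reparameterization trick
            dlogp_dz = -(z - target_mean) / target_std ** 2
            g_loc += dlogp_dz
            # d/d(log_scale) of log p(z) plus the entropy term log(scale).
            g_log_scale += dlogp_dz * scale * eps + 1.0
        loc += lr * g_loc / mc_samples
        log_scale += lr * g_log_scale / mc_samples
    return loc, math.exp(log_scale)

def init_chains(loc, scale, num_chains, jitter=0.1, seed=1):
    """Start each chain near the guide mean, with a small hypothetical jitter."""
    rng = random.Random(seed)
    return [loc + jitter * scale * rng.gauss(0.0, 1.0) for _ in range(num_chains)]

loc, scale = fit_gaussian_guide(target_mean=3.0, target_std=0.5)
inits = init_chains(loc, scale, num_chains=4)  # hand these to HMC warmup
```

A cold start at 0 would put every chain six posterior standard deviations from the mode. The warm start places them inside the bulk before adaptation begins.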
Predictive with MCMC
Predictive consumes either a
Guide or an
MCMCResult. With an
MCMCResult, it iterates over posterior samples instead of calling
guide.rsample.
```python
from quivers.inference import Predictive

# `conditioned`, `result` (an MCMCResult), and `x_new` are assumed to exist.
predictive = Predictive(
    model=conditioned.model,
    posterior=result,
    num_samples=500,
)
samples = predictive(x_new)
```
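Posterior predictive sampling from MCMC output works in three steps:

1. Pool the (chains, draws) samples.
2. Pick joint draws, keeping every site from the same draw together.
3. Simulate new observations from the model at each picked draw.

A toy linear-Gaussian model illustrates this. The site names, the model, and the choice to resample uniformly with replacement are assumptions of this sketch. The page does not say how Predictive selects draws when num_samples is smaller than chains × draws.

```python
import random

def predictive_from_draws(draws, x_new, num_samples, seed=0):
    """Toy posterior predictive: y ~ Normal(intercept + slope * x, noise).

    `draws` maps site name -> list of chains, each a list of per-draw values,
    mirroring a (chains, draws) layout.
    """
    rng = random.Random(seed)
    names = ("slope", "intercept", "noise")
    # Flatten every site in the same chain-major order so draws stay aligned.
    flat = {name: [v for chain in draws[name] for v in chain] for name in names}
    pooled = list(zip(*(flat[name] for name in names)))
    chosen = [pooled[rng.randrange(len(pooled))] for _ in range(num_samples)]
    # One simulated y per chosen joint draw and per input x.
    return [
        [rng.gauss(intercept + slope * x, noise) for x in x_new]
        for slope, intercept, noise in chosen
    ]

# Hypothetical MCMC output: 2 chains x 2 draws per site.
draws = {
    "slope": [[2.0, 2.1], [1.9, 2.0]],
    "intercept": [[0.5, 0.4], [0.6, 0.5]],
    "noise": [[0.1, 0.1], [0.2, 0.1]],
}
y_pred = predictive_from_draws(draws, x_new=[0.0, 1.0, 2.0], num_samples=500)
```

Keeping the sites of one draw together matters: mixing slope and intercept from different draws would ignore their posterior correlation.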
See also
- Inference Foundations: the trace, conditioning, and LatentRegistry primitives every MCMC kernel consumes.
- SVI guide: the variational counterpart, and the SVI driver wrapping the AutoDAIS hybrid sampler.
- Analysis Pipelines: Fitting and Diagnostics: the high-level fit(...) surface that wraps the MCMC and SVI drivers under one entry point.