MCMC

Markov-chain Monte Carlo: the HMC and NUTS kernels, the MCMC runner, and the MCMCResult summary.

The runner targets a MonadicProgram directly. For models that declare every latent as a sample site this is immediate; for models with nn.Parameters or intermediate latent sites the bayesian_lift_parameters lift produces the matching MonadicProgram.

mcmc

MCMC kernels and driver.

Public surface (also re-exported from quivers.inference):

  • MCMCKernel — ABC for Markov kernels on the flat unconstrained latent vector.
  • HMCKernel — Hamiltonian Monte Carlo with leapfrog integration, dual-averaging step-size adaptation, and Welford mass-matrix adaptation.
  • NUTSKernel — No-U-Turn Sampler with multinomial sampling.
  • MCMC — Chain orchestrator with warmup, multiple independent chains, and posterior diagnostics (split-$\hat{R}$, effective sample size).
  • MCMCResult — Posterior samples + per-chain diagnostics.

MCMC

MCMC(kernel: MCMCKernel, num_warmup: int, num_samples: int, num_chains: int = 4, init_strategy: InitStrategy = 'prior')

MCMC chain runner.

PARAMETER DESCRIPTION
kernel

Markov kernel (e.g. HMCKernel, NUTSKernel).

TYPE: MCMCKernel

num_warmup

Number of adaptation steps. The kernel's adaptation machinery (dual averaging, Welford covariance) runs over this prefix.

TYPE: int

num_samples

Post-warmup samples per chain.

TYPE: int

num_chains

Independent chains. Default 4 (Stan / NumPyro default).

TYPE: int DEFAULT: 4

init_strategy

How to pick each chain's initial position.

TYPE: ('prior', 'zero', 'guide') DEFAULT: "prior"

Source code in src/quivers/inference/mcmc/driver.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def __init__(
    self,
    kernel: MCMCKernel,
    num_warmup: int,
    num_samples: int,
    num_chains: int = 4,
    init_strategy: InitStrategy = "prior",
) -> None:
    """Validate the chain configuration and store it on the runner.

    Parameters
    ----------
    kernel
        Markov kernel driven by the runner; stored unchanged.
    num_warmup
        Warmup (adaptation) steps per chain. Must be ``>= 0``.
    num_samples
        Post-warmup samples per chain. Must be ``>= 1``.
    num_chains
        Number of chains. Must be ``>= 1``.
    init_strategy
        How each chain's initial position is picked; stored unchanged
        (it is not validated here).

    Raises
    ------
    ValueError
        If one of the three counts is below its minimum. The counts are
        checked in the order ``num_warmup``, ``num_samples``,
        ``num_chains`` and the first violation is reported.
    """
    # (argument name, supplied value, smallest accepted value)
    lower_bounds = (
        ("num_warmup", num_warmup, 0),
        ("num_samples", num_samples, 1),
        ("num_chains", num_chains, 1),
    )
    for label, value, minimum in lower_bounds:
        if value < minimum:
            raise ValueError(f"MCMC: {label} must be >= {minimum}, got {value}")

    self.kernel = kernel
    self.num_warmup = num_warmup
    self.num_samples = num_samples
    self.num_chains = num_chains
    self.init_strategy = init_strategy

run

run(model: MonadicProgram, x: Tensor, observations: dict[str, Tensor], guide: Guide | None = None) -> MCMCResult

Run the configured kernel for num_chains chains of num_warmup + num_samples steps each.

Source code in src/quivers/inference/mcmc/driver.py
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
def run(
    self,
    model: MonadicProgram,
    x: torch.Tensor,
    observations: dict[str, torch.Tensor],
    guide: Guide | None = None,
) -> MCMCResult:
    """Run the configured kernel for ``num_chains`` chains of
    ``num_warmup + num_samples`` steps each.

    Chains are run one after another on the same kernel instance and
    share a single ``PotentialFn``.

    Parameters
    ----------
    model
        Program whose latent sites are sampled.
    x
        Model input, forwarded to the potential and to ``kernel.init``.
    observations
        Observed values keyed by site name; the keys are the names
        treated as observed when the latent registry is built.
    guide
        Optional guide, forwarded to ``_initial_position``.

    Returns
    -------
    MCMCResult
        Constrained draws stacked to ``(num_chains, num_samples, ...)``
        per site, the log-density at every draw, per-chain acceptance
        rates and divergence counts (both post-warmup only), and the
        diagnostics computed by ``_diagnostics`` on the stacked draws.
    """
    observed_names = set(observations.keys())
    registry = LatentRegistry.from_model(model, observed_names)
    potential = PotentialFn(model, registry, x, observations)

    # Only the keys are used below, to pre-create one bucket per site.
    site_shapes: dict[str, tuple[int, ...]] = {
        site.name: site.constrained_shape or (1,)
        for site in registry.sites.values()
    }
    per_chain_samples: dict[str, list[torch.Tensor]] = {n: [] for n in site_shapes}
    per_chain_log_density = torch.empty(self.num_chains, self.num_samples)
    per_chain_accept = torch.empty(self.num_chains)
    per_chain_divergences = torch.empty(self.num_chains)

    for chain in range(self.num_chains):
        init_pos = self._initial_position(registry, guide, chain)
        state = self.kernel.init(registry, model, x, observations, init_pos)
        # Warmup. Warmup divergences are not part of the result, so they
        # are not counted.
        if self.num_warmup > 0:
            self.kernel.start_adaptation()
            for _ in range(self.num_warmup):
                state = self.kernel.step(state, potential)
            self.kernel.stop_adaptation()
        # Reset accept-count so the reported rate excludes warmup.
        sampling_state = KernelState(
            position=state.position,
            log_density=state.log_density,
            grad_log_density=state.grad_log_density,
            step_count=state.step_count,
            accept_count=0,
            diverged=False,
            extras=state.extras,
        )
        chain_samples: dict[str, list[torch.Tensor]] = {n: [] for n in site_shapes}
        sampling_divergences = 0
        for s in range(self.num_samples):
            sampling_state = self.kernel.step(sampling_state, potential)
            if sampling_state.diverged:
                sampling_divergences += 1
            draws = self._to_constrained(registry, sampling_state.position)
            for n, v in draws.items():
                chain_samples[n].append(v)
            per_chain_log_density[chain, s] = sampling_state.log_density
        # The kernel's own accept counter (zeroed above) is the single
        # source of truth for the post-warmup acceptance rate.
        per_chain_accept[chain] = sampling_state.accept_count / float(
            self.num_samples
        )
        per_chain_divergences[chain] = float(sampling_divergences)
        for n, draws_list in chain_samples.items():
            stacked = torch.stack(draws_list, dim=0)
            per_chain_samples[n].append(stacked)

    # Stack per-chain → (num_chains, num_samples, *site_shape).
    samples: dict[str, torch.Tensor] = {}
    for n, chain_draws in per_chain_samples.items():
        samples[n] = torch.stack(chain_draws, dim=0)

    r_hat, ess = _diagnostics(samples)
    return MCMCResult(
        samples=samples,
        log_densities=per_chain_log_density,
        acceptance_rates=per_chain_accept,
        divergence_counts=per_chain_divergences,
        r_hat=r_hat,
        ess=ess,
        num_warmup=self.num_warmup,
        num_samples=self.num_samples,
    )

MCMCResult dataclass

MCMCResult(samples: dict[str, Tensor], log_densities: Tensor, acceptance_rates: Tensor, divergence_counts: Tensor, r_hat: dict[str, Tensor], ess: dict[str, Tensor], num_warmup: int, num_samples: int)

Posterior samples and per-chain diagnostics.

ATTRIBUTE DESCRIPTION
samples

Per-site posterior draws on the constrained support. Shape (num_chains, num_samples, *site_shape).

TYPE: dict[str, Tensor]

log_densities

Unconstrained-space log-density (Jacobian-corrected) at every posterior draw. Shape (num_chains, num_samples).

TYPE: Tensor

acceptance_rates

Per-chain post-warmup acceptance rate. Shape (num_chains,).

TYPE: Tensor

divergence_counts

Per-chain post-warmup divergence count. Shape (num_chains,).

TYPE: Tensor

r_hat

Per-site split-$\hat{R}$. Each site's tensor has the site's shape (one scalar per coordinate).

TYPE: dict[str, Tensor]

ess

Per-site effective sample size. Same shape convention as r_hat.

TYPE: dict[str, Tensor]

num_warmup

TYPE: int

num_samples

TYPE: int

HMCKernel

HMCKernel(step_size: float = 0.1, num_steps: int = 10, mass_matrix: MassMatrixKind = 'identity', target_accept: float = 0.65, divergence_threshold: float = 1000.0, adapt_step_size: bool = True, adapt_mass_matrix: bool = True)

Bases: MCMCKernel

Hamiltonian Monte Carlo kernel with fixed trajectory length.

PARAMETER DESCRIPTION
step_size

Leapfrog step size. Adapted during warmup when adapt_step_size is true.

TYPE: float DEFAULT: 0.1

num_steps

Leapfrog steps per proposal.

TYPE: int DEFAULT: 10

mass_matrix

Mass-matrix shape. Diagonal / dense are adapted during warmup from the empirical covariance of warmup samples.

TYPE: ('identity', 'diagonal', 'dense') DEFAULT: "identity"

target_accept

Target Metropolis acceptance for dual averaging. Default 0.65 (Beskos et al.'s HMC-optimal acceptance for product-form targets).

TYPE: float DEFAULT: 0.65

divergence_threshold

Energy-error threshold for marking a proposal as divergent. Divergent steps still respect Metropolis correctness but are reported separately so the user can spot pathological regions.

TYPE: float DEFAULT: 1000.0

Source code in src/quivers/inference/mcmc/hmc.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
def __init__(
    self,
    step_size: float = 0.1,
    num_steps: int = 10,
    mass_matrix: MassMatrixKind = "identity",
    target_accept: float = 0.65,
    divergence_threshold: float = 1000.0,
    adapt_step_size: bool = True,
    adapt_mass_matrix: bool = True,
) -> None:
    """Validate and store the HMC hyper-parameters.

    Parameters
    ----------
    step_size
        Leapfrog step size. Must be ``> 0``.
    num_steps
        Leapfrog steps per proposal. Must be ``>= 1``.
    mass_matrix
        Mass-matrix kind; stored unchanged.
    target_accept
        Target acceptance rate. Must lie strictly inside ``(0, 1)``.
    divergence_threshold
        Energy-error threshold; stored unchanged (not validated here).
    adapt_step_size
        Whether step-size adaptation is enabled; stored unchanged.
    adapt_mass_matrix
        Whether mass-matrix adaptation is requested. The stored flag is
        only truthy when ``mass_matrix`` is not ``"identity"``.

    Raises
    ------
    ValueError
        If ``step_size``, ``num_steps`` or ``target_accept`` is out of
        range (checked in that order).
    """
    if step_size <= 0:
        step_msg = f"HMCKernel: step_size must be > 0, got {step_size}"
        raise ValueError(step_msg)
    if num_steps < 1:
        steps_msg = f"HMCKernel: num_steps must be >= 1, got {num_steps}"
        raise ValueError(steps_msg)
    accept_in_open_unit_interval = 0.0 < target_accept < 1.0
    if not accept_in_open_unit_interval:
        accept_msg = f"HMCKernel: target_accept must be in (0, 1), got {target_accept}"
        raise ValueError(accept_msg)

    # Adaptation helpers start out unset; nothing in this block fills them.
    self._mass: _MassMatrix | None = None
    self._dual_avg: DualAveraging | None = None
    self._welford: WelfordCovariance | None = None

    # Integrator configuration.
    self._step_size = step_size
    self._num_steps = num_steps
    self._mass_kind = mass_matrix
    self._divergence_threshold = divergence_threshold

    # Adaptation switches; an identity mass matrix is never adapted.
    self._target_accept = target_accept
    self._adapt_step_size = adapt_step_size
    self._adapt_mass_matrix = adapt_mass_matrix and mass_matrix != "identity"

NUTSKernel

NUTSKernel(step_size: float = 0.1, max_tree_depth: int = 10, mass_matrix: MassMatrixKind = 'diagonal', target_accept: float = 0.8, divergence_threshold: float = 1000.0, adapt_step_size: bool = True, adapt_mass_matrix: bool = True)

Bases: MCMCKernel

No-U-Turn Sampler with multinomial sampling and the standard U-turn termination (Hoffman-Gelman 2014 algorithms 3 + 6, Betancourt 2017's generalized slice variant for multinomial sampling).

PARAMETER DESCRIPTION
step_size

Initial leapfrog step size; adapted via dual averaging.

TYPE: float DEFAULT: 0.1

max_tree_depth

Maximum tree doubling depth. Default 10 (Stan default).

TYPE: int DEFAULT: 10

target_accept

Target tree-averaged Metropolis acceptance for dual averaging. Default 0.8.

TYPE: float DEFAULT: 0.8

divergence_threshold

Energy-error threshold above which a leapfrog substep is marked divergent and terminates the tree on its branch.

TYPE: float DEFAULT: 1000.0

Source code in src/quivers/inference/mcmc/hmc.py
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
def __init__(
    self,
    step_size: float = 0.1,
    max_tree_depth: int = 10,
    mass_matrix: MassMatrixKind = "diagonal",
    target_accept: float = 0.8,
    divergence_threshold: float = 1000.0,
    adapt_step_size: bool = True,
    adapt_mass_matrix: bool = True,
) -> None:
    """Validate and store the NUTS hyper-parameters.

    Parameters
    ----------
    step_size
        Initial leapfrog step size. Must be ``> 0``.
    max_tree_depth
        Maximum tree doubling depth. Must be ``>= 1``.
    mass_matrix
        Mass-matrix kind; stored unchanged.
    target_accept
        Target acceptance rate. Must lie strictly inside ``(0, 1)``.
    divergence_threshold
        Energy-error threshold; stored unchanged (not validated here).
    adapt_step_size
        Whether step-size adaptation is enabled; stored unchanged.
    adapt_mass_matrix
        Whether mass-matrix adaptation is requested. The stored flag is
        only truthy when ``mass_matrix`` is not ``"identity"``.

    Raises
    ------
    ValueError
        If ``step_size``, ``max_tree_depth`` or ``target_accept`` is out
        of range (checked in that order).
    """
    if step_size <= 0:
        step_msg = f"NUTSKernel: step_size must be > 0, got {step_size}"
        raise ValueError(step_msg)
    if max_tree_depth < 1:
        depth_msg = f"NUTSKernel: max_tree_depth must be >= 1, got {max_tree_depth}"
        raise ValueError(depth_msg)
    accept_in_open_unit_interval = 0.0 < target_accept < 1.0
    if not accept_in_open_unit_interval:
        accept_msg = f"NUTSKernel: target_accept must be in (0, 1), got {target_accept}"
        raise ValueError(accept_msg)

    # Adaptation helpers start out unset; nothing in this block fills them.
    self._mass: _MassMatrix | None = None
    self._dual_avg: DualAveraging | None = None
    self._welford: WelfordCovariance | None = None

    # Trajectory configuration.
    self._step_size = step_size
    self._max_depth = max_tree_depth
    self._mass_kind = mass_matrix
    self._divergence_threshold = divergence_threshold

    # Adaptation switches; an identity mass matrix is never adapted.
    self._target_accept = target_accept
    self._adapt_step_size = adapt_step_size
    self._adapt_mass_matrix = adapt_mass_matrix and mass_matrix != "identity"

KernelState dataclass

KernelState(position: Tensor, log_density: Tensor, grad_log_density: Tensor, step_count: int = 0, accept_count: int = 0, diverged: bool = False, extras: dict = dict())

Mutable container for one chain's MCMC state.

ATTRIBUTE DESCRIPTION
position

Current flat unconstrained latent vector. Shape (D,) for one chain, (num_chains, D) for parallel chains.

TYPE: Tensor

log_density

Unconstrained-space log-density (Jacobian-corrected log-joint) at position. Shape () or (num_chains,).

TYPE: Tensor

grad_log_density

Gradient of log_density with respect to position, same shape as position. Cached so leapfrog re-uses the last-evaluated gradient at the proposal endpoint.

TYPE: Tensor

step_count

Number of step calls so far (counts both warmup and post-warmup steps).

TYPE: int

accept_count

Number of proposals accepted across the chain. Useful for reporting acceptance rate.

TYPE: int

diverged

Whether the most recent step's energy error exceeded the kernel's divergence threshold. Reset by each kernel as appropriate.

TYPE: bool

extras

Per-kernel additional state (e.g. NUTS tree depth, HMC step-size adaptation cumulants).

TYPE: dict

MCMCKernel

Bases: ABC

Abstract Markov kernel on the flat unconstrained latent vector.

Concrete subclasses implement init and step. Adaptation phases (warmup) typically mutate kernel-internal state (step size, mass matrix) and freeze it for the sampling phase; the kernel's is_adapting flag tracks that.

init abstractmethod

init(registry: LatentRegistry, model: MonadicProgram, x: Tensor, observations: dict[str, Tensor], initial_position: Tensor) -> KernelState

Build the starting KernelState from the supplied initial flat unconstrained vector. The initial gradient is evaluated here so step can re-use it.

Source code in src/quivers/inference/mcmc/kernel.py
190
191
192
193
194
195
196
197
198
199
200
201
@abstractmethod
def init(
    self,
    registry: LatentRegistry,
    model: MonadicProgram,
    x: torch.Tensor,
    observations: dict[str, torch.Tensor],
    initial_position: torch.Tensor,
) -> KernelState:
    """Build the starting `KernelState` from the supplied
    initial flat unconstrained vector. The initial gradient is
    evaluated here so `step` can re-use it.

    Abstract: this declaration has no body, so each concrete kernel
    provides the implementation.

    Parameters
    ----------
    registry
        Latent-site registry for ``model``.
    model
        Program whose latents are being sampled.
    x
        Model input.
    observations
        Observed values keyed by site name.
    initial_position
        Flat unconstrained vector the chain starts from.

    Returns
    -------
    KernelState
        State for the first `step` call.
    """

step abstractmethod

step(state: KernelState, potential: PotentialFn) -> KernelState

Advance the chain one Metropolis step. The potential function is constructed once per MCMC.run and re-used across every step / chain.

Source code in src/quivers/inference/mcmc/kernel.py
203
204
205
206
207
208
209
210
211
@abstractmethod
def step(
    self,
    state: KernelState,
    potential: PotentialFn,
) -> KernelState:
    """Advance the chain one Metropolis step. The potential
    function is constructed once per `MCMC.run` and
    re-used across every step / chain.

    Abstract: this declaration has no body, so each concrete kernel
    provides the implementation.

    Parameters
    ----------
    state
        Chain state before the step.
    potential
        Density / gradient evaluator for the model being sampled.

    Returns
    -------
    KernelState
        Chain state after the step; `MCMC.run` feeds it back in as
        ``state`` for the next call.
    """

start_adaptation

start_adaptation() -> None

Enter the adaptation (warmup) phase.

Source code in src/quivers/inference/mcmc/kernel.py
213
214
215
def start_adaptation(self) -> None:
    """Enter the adaptation (warmup) phase.

    Sets ``is_adapting`` to ``True``; nothing else is changed by this
    base implementation.
    """
    self.is_adapting = True

stop_adaptation

stop_adaptation() -> None

Freeze the kernel's adapted parameters.

Source code in src/quivers/inference/mcmc/kernel.py
217
218
219
def stop_adaptation(self) -> None:
    """Freeze the kernel's adapted parameters.

    This base implementation only sets ``is_adapting`` to ``False``;
    any finalisation of adapted quantities is presumably done by
    subclasses (not visible here).
    """
    self.is_adapting = False

PotentialFn

PotentialFn(model: MonadicProgram, registry: LatentRegistry, x: Tensor, observations: dict[str, Tensor])

Callable that maps a flat unconstrained position to the unconstrained-space log-density and its gradient; the potential used by HMC and NUTS is the negation of that log-density.

HMC and NUTS need both the potential $U(z) = -\log \tilde{p}(z)$ (where $\tilde{p}(z) = p(T(z), y) \cdot |\det J_T(z)|$ is the Jacobian-corrected unconstrained-space joint) and its gradient $\nabla U(z)$. The two are computed in a single autograd pass and cached on the kernel state.

PARAMETER DESCRIPTION
model

Generative model.

TYPE: MonadicProgram

registry

Latent-site registry for model.

TYPE: LatentRegistry

x

Model input. Shape (batch, ...).

TYPE: Tensor

observations

Observed-site values and host data.

TYPE: dict[str, Tensor]

Source code in src/quivers/inference/mcmc/kernel.py
 97
 98
 99
100
101
102
103
104
105
106
107
def __init__(
    self,
    model: MonadicProgram,
    registry: LatentRegistry,
    x: torch.Tensor,
    observations: dict[str, torch.Tensor],
) -> None:
    """Store the model, its latent registry and the data to condition on.

    No computation happens here; the four arguments are kept by
    reference on private attributes for later density evaluations.

    Parameters
    ----------
    model
        Generative model.
    registry
        Latent-site registry for ``model``.
    x
        Model input.
    observations
        Observed-site values and host data.
    """
    self._model, self._registry = model, registry
    self._x, self._observations = x, observations

log_density

log_density(z: Tensor) -> Tensor

Unconstrained-space log-density (Jacobian-corrected).

Trajectories that wander to the edge of a constrained support can produce values that fall outside torch.distributions' validation envelope (e.g. exact zeros against a strictly-positive support after a long leapfrog stride). Rather than letting the resulting ValueError propagate and kill the chain, this method returns -inf for those positions; the kernel reads non-finite log-densities as divergent transitions and rejects them in the Metropolis step.

Source code in src/quivers/inference/mcmc/kernel.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def log_density(self, z: torch.Tensor) -> torch.Tensor:
    """Evaluate the Jacobian-corrected log-density at ``z``.

    The flat vector is split per site, each piece is pushed through the
    site's bijector, and the model's log-joint on the constrained values
    (merged with the observations) is summed and added to the total
    log-abs-det-Jacobian.

    A ``ValueError`` raised anywhere in that computation (for instance
    by ``torch.distributions`` argument validation when a trajectory
    reaches the edge of a constrained support), or a non-finite total,
    yields ``-inf`` instead of an exception so the caller can treat the
    position as divergent and reject it.

    Parameters
    ----------
    z
        Flat unconstrained position.

    Returns
    -------
    torch.Tensor
        The log-density, or a ``-inf`` scalar on ``z``'s device / dtype.
    """

    def _minus_infinity() -> torch.Tensor:
        # Sentinel shared by the ValueError and the non-finite paths.
        return torch.tensor(float("-inf"), device=z.device, dtype=z.dtype)

    try:
        unconstrained = self._registry.unflatten_unconstrained(z)
        values: dict[str, torch.Tensor] = {}
        jacobian_term = torch.zeros((), device=z.device, dtype=z.dtype)
        for site in self._registry.sites.values():
            raw = unconstrained[site.name]
            if not site.is_plate:
                # Non-plate sites get a leading axis before the bijector.
                raw = raw.unsqueeze(0)
            mapped = site.bijector(raw)
            site_log_det = site.bijector.log_abs_det_jacobian(raw, mapped).sum()
            jacobian_term = jacobian_term + site_log_det
            drop_trailing_axis = (
                site.constrained_dim == 1
                and mapped.dim() >= 1
                and mapped.shape[-1] == 1
            )
            if drop_trailing_axis:
                mapped = mapped.squeeze(-1)
            values[site.name] = mapped
        # Observations win over latents on a name clash.
        joint_inputs = dict(values)
        joint_inputs.update(self._observations)
        total = self._model.log_joint(self._x, joint_inputs).sum() + jacobian_term
    except ValueError:
        return _minus_infinity()
    return total if torch.isfinite(total) else _minus_infinity()

value_and_grad

value_and_grad(z: Tensor) -> tuple[Tensor, Tensor]

Return (log_density, grad_log_density) for z.

z is expected to be a detached tensor; we make a fresh leaf with requires_grad=True so gradient propagation doesn't leak into the kernel's accumulated state.

For divergent positions (where the log-density is -inf), returns a zero gradient — the kernel rejects the trajectory in the Metropolis step anyway, and a zero gradient keeps the leapfrog integrator from producing NaN downstream.

Source code in src/quivers/inference/mcmc/kernel.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
def value_and_grad(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(log_density, grad_log_density)`` for ``z``.

    ``z`` is expected to be a detached tensor; we make a fresh
    leaf with ``requires_grad=True`` so gradient propagation
    doesn't leak into the kernel's accumulated state.

    For divergent positions (where the log-density is
    ``-inf``), returns a zero gradient — the kernel rejects
    the trajectory in the Metropolis step anyway, and a zero
    gradient keeps the leapfrog integrator from producing NaN
    downstream.

    Parameters
    ----------
    z
        Flat unconstrained position. It is never modified; the
        computation runs on a detached clone.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Detached log-density and detached gradient with the shape of
        ``z``. The gradient is all zeros when the log-density or the
        gradient is non-finite. ``(-inf, zeros)`` is returned when the
        evaluation raises ``ValueError`` or ``RuntimeError``.
    """
    z_leaf = z.detach().clone().requires_grad_(True)
    try:
        ld = self.log_density(z_leaf)
        if not torch.isfinite(ld):
            return ld.detach(), torch.zeros_like(z)
        grad = torch.autograd.grad(
            ld, z_leaf, create_graph=False, allow_unused=False
        )[0]
        if grad is None or not torch.isfinite(grad).all():
            return ld.detach(), torch.zeros_like(z)
        return ld.detach(), grad.detach()
    # Parenthesised tuple: the bare ``except A, B:`` spelling is a syntax
    # error on every Python 3 release before 3.14.
    except (ValueError, RuntimeError):
        return (
            torch.tensor(float("-inf"), device=z.device, dtype=z.dtype),
            torch.zeros_like(z),
        )