Predictive

Predictive inference and posterior sampling.

predictive

Posterior predictive sampling.

Given a trained posterior representation (either a variational guide or the draws recorded by an MCMC run), Predictive repeatedly samples latent values from the posterior and traces the model forward to produce posterior predictive draws of every non-deterministic site.
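As a self-contained illustration of the idea, here is plain PyTorch rather than the quivers API. A fixed Normal distribution stands in for a trained posterior over a single latent `mu`, and the model is a toy `y ~ Normal(mu * x, 1)`. Both the model and the posterior parameters are assumptions chosen only for illustration.

```python
import torch
from torch.distributions import Normal


def posterior_predictive(x: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Toy posterior predictive for y ~ Normal(mu * x, 1).

    The posterior over mu is a hypothetical stand-in: Normal(2.0, 0.1).
    """
    posterior_mu = Normal(torch.tensor(2.0), torch.tensor(0.1))
    draws = []
    for _ in range(num_samples):
        mu = posterior_mu.sample()          # 1. sample the latent from the posterior
        y = Normal(mu * x, 1.0).sample()    # 2. run the model forward given that latent
        draws.append(y)
    return torch.stack(draws, dim=0)        # shape (num_samples, batch)


torch.manual_seed(0)
x = torch.linspace(0.0, 1.0, 5)
y_pred = posterior_predictive(x, num_samples=200)
print(y_pred.shape)  # torch.Size([200, 5])
```

Each predictive draw pairs one posterior sample of the latent with one forward run of the model, which is the pattern Predictive applies to every stochastic site of a program.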

Predictive

Predictive(model: MonadicProgram, posterior: Guide | MCMCResult, num_samples: int | None = None)

Posterior predictive sampler.

Accepts either a trained Guide (variational posterior) or an MCMCResult (Monte Carlo posterior). In the variational case it draws num_samples fresh samples from the guide. In the MCMC case it iterates over the recorded posterior draws, running one forward trace per draw and stopping after num_samples draws if that argument is given.
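The two posterior types can be pictured as two ways of producing a sequence of latent dictionaries. The sketch below is a hypothetical stand-in, not quivers code: `guide_latents` takes a plain callable in place of `Guide.rsample`, and `mcmc_latents` assumes recorded draws are stored per site with shape `(num_chains, draws_per_chain, ...)` and flattened chain by chain. The actual ordering used by `Predictive._iter_mcmc_latents` is not shown on this page.

```python
from typing import Callable, Dict, Iterator

import torch

Latents = Dict[str, torch.Tensor]


def guide_latents(sample_fn: Callable[[], Latents], num_samples: int) -> Iterator[Latents]:
    # Variational case: every draw is a fresh sample from the guide.
    for _ in range(num_samples):
        yield sample_fn()


def mcmc_latents(samples: Latents, num_samples: int) -> Iterator[Latents]:
    # MCMC case: reuse recorded draws. Assumed layout per site is
    # (num_chains, draws_per_chain, ...), flattened chain by chain.
    flat = {name: v.reshape(-1, *v.shape[2:]) for name, v in samples.items()}
    available = min((v.shape[0] for v in flat.values()), default=0)
    for i in range(min(num_samples, available)):
        yield {name: v[i] for name, v in flat.items()}


recorded = {"mu": torch.arange(100.0).reshape(4, 25)}  # 4 chains x 25 draws (made-up data)
draws = list(mcmc_latents(recorded, num_samples=1000))
print(len(draws))  # 100: capped at num_chains * draws_per_chain

fresh = list(guide_latents(lambda: {"mu": torch.randn(())}, num_samples=3))
print(len(fresh))  # 3
```

A guide can supply as many independent draws as requested, whereas an MCMC result can only replay what was recorded, which is why the MCMC path has a natural upper bound.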

PARAMETER DESCRIPTION
model

Generative model.

TYPE: MonadicProgram

posterior

Trained posterior representation. Any other type raises TypeError.

TYPE: Guide or MCMCResult

num_samples

Number of predictive draws. Defaults to 100 for guides and to num_chains * num_samples for MCMC results (that is, every recorded posterior draw). An explicit integer passed with an MCMC result is capped at the available draw count. The resolved value must be at least 1, otherwise ValueError is raised.

TYPE: int or None DEFAULT: None

Source code in src/quivers/inference/predictive.py
def __init__(
    self,
    model: MonadicProgram,
    posterior: Guide | MCMCResult,
    num_samples: int | None = None,
) -> None:
    if not isinstance(posterior, (Guide, MCMCResult)):
        raise TypeError(
            f"Predictive: posterior must be Guide or MCMCResult; "
            f"got {type(posterior).__name__}"
        )
    self.model = model
    self.posterior = posterior
    if isinstance(posterior, MCMCResult):
        available = posterior.num_chains * posterior.num_samples
        self.num_samples = (
            min(num_samples, available) if num_samples is not None else available
        )
    else:
        self.num_samples = num_samples if num_samples is not None else 100
    if self.num_samples < 1:
        raise ValueError(
            f"Predictive: num_samples must be >= 1, got {self.num_samples}"
        )
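The defaulting, capping and validation rules in `__init__` reduce to a small pure function. The sketch below restates them outside the class so they can be checked in isolation; `resolve_num_samples` and its keyword arguments are hypothetical names, not part of quivers.

```python
from typing import Optional


def resolve_num_samples(
    num_samples: Optional[int],
    *,
    is_mcmc: bool,
    num_chains: int = 1,
    draws_per_chain: int = 0,
) -> int:
    """Restates the num_samples rules of Predictive.__init__ (illustrative only)."""
    if is_mcmc:
        available = num_chains * draws_per_chain
        # None means "use every recorded draw"; explicit values are capped.
        resolved = available if num_samples is None else min(num_samples, available)
    else:
        # A guide can produce unlimited fresh draws, so a fixed default is used.
        resolved = 100 if num_samples is None else num_samples
    if resolved < 1:
        raise ValueError(f"num_samples must be >= 1, got {resolved}")
    return resolved


print(resolve_num_samples(None, is_mcmc=False))                                    # 100
print(resolve_num_samples(None, is_mcmc=True, num_chains=4, draws_per_chain=500))  # 2000
print(resolve_num_samples(5000, is_mcmc=True, num_chains=4, draws_per_chain=500))  # 2000
```

Because validation runs after capping, an MCMC result with no recorded draws also fails with ValueError, even when num_samples is left as None.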

__call__

__call__(x: Tensor, observations: dict[str, Tensor] | None = None) -> dict[str, Tensor]

Draw posterior predictive samples.

PARAMETER DESCRIPTION
x

Program input. Shape (batch, ...).

TYPE: Tensor

observations

Additional observed data to condition on, keyed by site name. On a name clash these values override the posterior latents for each draw.

TYPE: dict[str, Tensor] or None DEFAULT: None
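Inside `__call__`, each posterior draw is merged with `observations` as `{**latents, **observations}` before tracing, so on a name clash the observation wins. A minimal illustration with made-up site names `mu` and `sigma`:

```python
import torch

latents = {"mu": torch.tensor(1.5), "sigma": torch.tensor(0.3)}  # drawn from the posterior
observations = {"sigma": torch.tensor(1.0)}                       # supplied by the caller

# Later unpacked mappings win, so caller-supplied observations override
# posterior latents that share a site name.
all_obs = {**latents, **observations}
print(all_obs["sigma"])  # tensor(1.)
```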

RETURNS DESCRIPTION
dict[str, Tensor]

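One key per non-deterministic site (deterministic sites are skipped). Each value stacks the per-draw site values along a new leading dimension, giving shape (num_samples, batch, ...), or num_samples followed by the trace-side shape for plate sites.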

Source code in src/quivers/inference/predictive.py
@torch.no_grad()
def __call__(
    self,
    x: torch.Tensor,
    observations: dict[str, torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Draw posterior predictive samples.

    Parameters
    ----------
    x : torch.Tensor
        Program input. Shape ``(batch, ...)``.
    observations : dict[str, torch.Tensor] or None
        Additional observed data to condition on.

    Returns
    -------
    dict[str, torch.Tensor]
        One key per site, value of shape
        ``(num_samples, batch, ...)`` (or the trace-side shape
        for plate sites).
    """
    observations = observations if observations is not None else {}
    if isinstance(self.posterior, Guide):
        latents_iter: list[dict[str, torch.Tensor]] = [
            self.posterior.rsample(x) for _ in range(self.num_samples)
        ]
    else:
        latents_iter = self._iter_mcmc_latents()

    collected: dict[str, list[torch.Tensor]] = {}
    for latents in latents_iter:
        all_obs = {**latents, **observations}
        tr = trace(self.model, x, observations=all_obs)
        for name, site in tr.sites.items():
            if site.is_deterministic:
                continue
            collected.setdefault(name, []).append(site.value)

    return {name: torch.stack(vals, dim=0) for name, vals in collected.items()}
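The loop at the end of `__call__` can be reproduced with a toy tracer to make the returned shapes concrete. Everything here except the collect-and-stack pattern is hypothetical: `Site` and `toy_trace` stand in for the quivers trace machinery, and the model is invented for the example. The real method also runs under `@torch.no_grad()`, which this sketch omits.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional

import torch


@dataclass
class Site:
    # Hypothetical stand-in for a trace site: a value and a deterministic flag.
    value: torch.Tensor
    is_deterministic: bool = False


def toy_trace(x: torch.Tensor, observations: Dict[str, torch.Tensor]) -> Dict[str, Site]:
    # Toy model: mu is latent (taken from observations if present),
    # mean = mu * x is deterministic, and y ~ Normal(mean, 1).
    mu = observations.get("mu", torch.randn(()))
    mean = mu * x
    y = observations.get("y", mean + torch.randn_like(mean))
    return {"mu": Site(mu), "mean": Site(mean, is_deterministic=True), "y": Site(y)}


def collect(
    x: torch.Tensor,
    latents_iter: Iterable[Dict[str, torch.Tensor]],
    observations: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
    observations = observations if observations is not None else {}
    collected: Dict[str, List[torch.Tensor]] = {}
    for latents in latents_iter:
        sites = toy_trace(x, {**latents, **observations})
        for name, site in sites.items():
            if site.is_deterministic:
                continue  # deterministic sites are not returned
            collected.setdefault(name, []).append(site.value)
    # Stack per-draw values along a new leading sample dimension.
    return {name: torch.stack(vals, dim=0) for name, vals in collected.items()}


x = torch.ones(3)
draws = [{"mu": torch.tensor(float(i))} for i in range(10)]
out = collect(x, draws)
print(sorted(out))                       # ['mu', 'y']
print(out["mu"].shape, out["y"].shape)   # torch.Size([10]) torch.Size([10, 3])
```

Here `mu`, a scalar per draw, comes back with shape `(10,)`, while `y` comes back as `(10, 3)`, i.e. `(num_samples, batch)`; the deterministic `mean` site is dropped.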