Predictive

Predictive inference and posterior sampling.

predictive

Posterior predictive sampling.

After a guide has been trained with SVI, the Predictive class draws posterior predictive samples. It repeatedly samples latents from the guide and runs the model forward with those latents held fixed.
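The loop can be sketched in plain Python. This is a minimal, hypothetical illustration, not the quivers implementation. `toy_guide`, `toy_model` and `posterior_predictive` are invented names. The guide is assumed to be a fitted Normal(2.0, 0.5) over a single latent `mu`, and the model is assumed to draw `y_i ~ Normal(mu + x_i, 1)`. Plain lists stand in for tensors.

```python
import random

# Hypothetical toy guide: an already-fitted Normal(loc, scale) over one latent "mu".
def toy_guide(rng: random.Random, loc: float = 2.0, scale: float = 0.5) -> dict:
    return {"mu": rng.gauss(loc, scale)}

# Hypothetical toy model: y_i ~ Normal(mu + x_i, 1) for each input x_i.
# A site already present in `clamped` is reused instead of being re-sampled.
def toy_model(rng: random.Random, x: list, clamped: dict) -> dict:
    mu = clamped["mu"] if "mu" in clamped else rng.gauss(0.0, 10.0)
    y = [rng.gauss(mu + xi, 1.0) for xi in x]
    return {"mu": mu, "y": y}

def posterior_predictive(x: list, num_samples: int = 100, seed: int = 0) -> dict:
    rng = random.Random(seed)
    samples: dict = {}
    for _ in range(num_samples):
        latents = toy_guide(rng)            # 1. sample latents from the guide
        sites = toy_model(rng, x, latents)  # 2. run the model with them held fixed
        for name, value in sites.items():   # 3. record every site's value
            samples.setdefault(name, []).append(value)
    return samples

draws = posterior_predictive([0.0, 1.0, 2.0], num_samples=500)
print(len(draws["y"]), len(draws["y"][0]))  # 500 3
```

Each entry of `draws` is a list of `num_samples` values, the list analogue of a tensor with a new leading sample dimension.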

Predictive

Predictive(model: MonadicProgram, guide: Guide, num_samples: int = 100)

Draw posterior predictive samples from a trained model and guide.

PARAMETER DESCRIPTION
model

The generative model.

TYPE: MonadicProgram

guide

The trained variational guide.

TYPE: Guide

num_samples

Number of posterior samples to draw.

TYPE: int DEFAULT: 100
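num_samples sets the Monte Carlo budget. Assuming the draws are independent, the standard error of a posterior predictive mean shrinks roughly as 1/sqrt(num_samples). The sketch below checks this numerically with a stand-in predictive quantity (independent Normal(2, 1) draws). `mc_mean` is a hypothetical helper, not part of quivers.

```python
import random
import statistics

def mc_mean(num_samples: int, seed: int) -> float:
    """Hypothetical helper: Monte Carlo mean of a stand-in predictive quantity."""
    rng = random.Random(seed)
    # Stand-in for one predictive site value: independent Normal(2.0, 1.0) draws.
    return statistics.fmean(rng.gauss(2.0, 1.0) for _ in range(num_samples))

# Spread of the estimate across 200 independent repetitions, per sample budget.
spread = {}
for n in (10, 100, 1000):
    estimates = [mc_mean(n, seed) for seed in range(200)]
    spread[n] = statistics.stdev(estimates)
    print(n, round(spread[n], 3))  # close to 1 / sqrt(n)
```

Under the same assumption, raising num_samples from 100 to 10,000 would cut this error by about a factor of 10. It would also cost 100 times as many guide draws and model traces.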

Source code in src/quivers/inference/predictive.py
def __init__(
    self,
    model: MonadicProgram,
    guide: Guide,
    num_samples: int = 100,
) -> None:
    self.model = model
    self.guide = guide
    self.num_samples = num_samples

__call__

__call__(x: Tensor, observations: dict[str, Tensor] | None = None) -> dict[str, Tensor]

Draw posterior predictive samples.

For each of num_samples iterations, samples latents from the guide and traces the model with those latents clamped as observations. Returns the value of every non-deterministic site, stacked along a new leading dimension.

PARAMETER DESCRIPTION
x

Program input. Shape (batch, ...).

TYPE: Tensor

observations

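Additional observed data to condition on. If a site name appears both here and among the guide's latents, the value given here takes precedence.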

TYPE: dict[str, Tensor] or None DEFAULT: None
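The source merges the two dictionaries as {**latents, **observations}. Later entries win on key collisions, so a site named in observations keeps the user-supplied value even if the guide also sampled it. The snippet below shows this with made-up values. The site names mu, sigma and y are hypothetical.

```python
# Hypothetical guide draw and user-supplied observations (plain floats/lists).
latents = {"mu": 1.7, "sigma": 0.9}
observations = {"sigma": 1.0, "y": [0.3, 1.2]}

# Same merge order as Predictive.__call__: keys from `observations`
# overwrite keys from `latents` when both define the same site.
all_obs = {**latents, **observations}
print(all_obs)  # {'mu': 1.7, 'sigma': 1.0, 'y': [0.3, 1.2]}
```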

RETURNS DESCRIPTION
dict[str, Tensor]

Maps each non-deterministic site name to a tensor of shape (num_samples, batch, ...).
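The sample axis is always dimension 0, so posterior predictive summaries reduce over that axis. With the returned tensors, this would typically be samples[name].mean(dim=0) or torch.quantile(samples[name], q, dim=0). The sketch below shows the same reduction in plain Python, using a nested list as a stand-in for a (num_samples, batch) tensor. The summarize helper is hypothetical. It uses a simple nearest-rank interval rather than interpolated quantiles.

```python
import statistics

def summarize(stacked: list, lower: float = 0.05, upper: float = 0.95) -> list:
    """Hypothetical helper: per-batch-element mean and nearest-rank interval.

    `stacked` is a list of num_samples rows, each a list of batch values,
    i.e. the plain-Python analogue of a (num_samples, batch) tensor.
    """
    num_samples = len(stacked)
    batch = len(stacked[0])
    out = []
    for b in range(batch):
        column = sorted(row[b] for row in stacked)  # all draws for element b
        lo = column[int(lower * (num_samples - 1))]
        hi = column[int(upper * (num_samples - 1))]
        out.append((statistics.fmean(column), lo, hi))
    return out

stacked_y = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0], [5.0, 50.0]]
print(summarize(stacked_y, lower=0.0, upper=1.0))
# [(3.0, 1.0, 5.0), (30.0, 10.0, 50.0)]
```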

Source code in src/quivers/inference/predictive.py
@torch.no_grad()
def __call__(
    self,
    x: torch.Tensor,
    observations: dict[str, torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Draw posterior predictive samples.

    For each of ``num_samples`` iterations, samples latents from
    the guide and traces the model with those latents as
    observations. Returns all site values stacked along a new
    leading dimension.

    Parameters
    ----------
    x : torch.Tensor
        Program input. Shape (batch, ...).
    observations : dict[str, torch.Tensor] or None
        Additional observed data to condition on.

    Returns
    -------
    dict[str, torch.Tensor]
        Each key is a site name, each value has shape
        (num_samples, batch, ...).
    """
    if observations is None:
        observations = {}

    collected: dict[str, list[torch.Tensor]] = {}

    for _ in range(self.num_samples):
        # sample latents from guide
        latents = self.guide.rsample(x)

        # merge with any fixed observations
        all_obs = {**latents, **observations}

        # trace the model with these values clamped
        tr = trace(self.model, x, observations=all_obs)

        for name, site in tr.sites.items():
            if site.is_deterministic:
                continue

            if name not in collected:
                collected[name] = []

            collected[name].append(site.value)

    # stack along a new leading dimension
    return {name: torch.stack(vals, dim=0) for name, vals in collected.items()}
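The collection step above keeps only non-deterministic sites. The sketch below mirrors that loop in plain Python over two hand-written traces. ToySite and collect are hypothetical stand-ins for the quivers trace and site objects, and are assumed only to expose value and is_deterministic as the source does. dict.setdefault replaces the explicit membership check and behaves the same way.

```python
from dataclasses import dataclass

@dataclass
class ToySite:
    """Hypothetical stand-in for a traced site."""
    value: list
    is_deterministic: bool = False

def collect(traces: list) -> dict:
    collected: dict = {}
    for sites in traces:
        for name, site in sites.items():
            if site.is_deterministic:  # deterministic sites are skipped
                continue
            collected.setdefault(name, []).append(site.value)
    return collected

traces = [
    {"mu": ToySite([0.1]), "loc": ToySite([0.2], is_deterministic=True), "y": ToySite([1.0, 2.0])},
    {"mu": ToySite([0.3]), "loc": ToySite([0.6], is_deterministic=True), "y": ToySite([1.5, 2.5])},
]
out = collect(traces)
print(sorted(out))  # ['mu', 'y']
print(out["y"])     # [[1.0, 2.0], [1.5, 2.5]]
```

In the real method, torch.stack then turns each list into one tensor with a new leading dimension. This requires every draw of a given site to have the same shape.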