Variational Guides

Variational guide distributions for approximate inference. The shipped guides (AutoNormalGuide, AutoDeltaGuide, AutoMultivariateNormalGuide, AutoLowRankMultivariateNormalGuide, AutoLaplaceApproximation, AutoNormalizingFlow, AutoIAFGuide, AutoNeuralSplineGuide, AutoMixtureGuide) live as submodules of quivers.inference.guides and share the Guide ABC and the LatentRegistry introspection layer.

guides

Variational guide families.

Public surface (re-exported by the parent quivers.inference package): one ABC (Guide) plus a zoo of concrete Auto*Guide subclasses spanning the standard variational-family ladder from mean-field Normal to normalizing-flow stacks and hierarchical / mixture / structured guides.

Every concrete guide is built against a single quivers.inference.registry.LatentRegistry and obeys the shape contract documented on Guide.

Guide

Bases: Module, ABC

Abstract variational guide.

Subclasses MUST implement rsample and log_prob and expose latent_names. They MAY override registry if they construct their registry lazily, but the default implementation expects self._registry to be set in __init__.

registry property

registry: LatentRegistry

The LatentRegistry this guide was built against.

latent_names abstractmethod property

latent_names: list[str]

Names of latent variables this guide covers.

build_registry classmethod

build_registry(model: MonadicProgram, observed_names: set[str] | frozenset[str]) -> LatentRegistry

Convenience wrapper around LatentRegistry.from_model so guide constructors can do self._registry = self.build_registry(model, obs) without an extra import.

Source code in src/quivers/inference/guides/base.py
@classmethod
def build_registry(
    cls,
    model: MonadicProgram,
    observed_names: set[str] | frozenset[str],
) -> LatentRegistry:
    """Convenience wrapper around
    `LatentRegistry.from_model` so guide constructors
    can do ``self._registry = self.build_registry(model, obs)``
    without an extra import."""
    return LatentRegistry.from_model(model, observed_names)

rsample abstractmethod

rsample(x: Tensor) -> dict[str, Tensor]

Reparameterized sample from :math:`q_\phi(z \mid x)`.

PARAMETER DESCRIPTION
x

Program input. Shape (batch, ...). Used only for its batch dim and device; the variational parameters are stored on the guide itself.

TYPE: Tensor

RETURNS DESCRIPTION
dict[str, Tensor]

Per-site constrained samples shaped to match the model's trace-side convention.

Source code in src/quivers/inference/guides/base.py
@abstractmethod
def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Reparameterized sample from :math:`q_\\phi(z \\mid x)`.

    Parameters
    ----------
    x : torch.Tensor
        Program input. Shape ``(batch, ...)``. Used only for
        its batch dim and device; the variational parameters
        are stored on the guide itself.

    Returns
    -------
    dict[str, torch.Tensor]
        Per-site constrained samples shaped to match the
        model's trace-side convention.
    """
    ...

log_prob abstractmethod

log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor

Log density of sites under :math:`q_\phi(z \mid x)`, with the change-of-variables Jacobian correction baked in.

RETURNS DESCRIPTION
Tensor

Shape (batch,).

Source code in src/quivers/inference/guides/base.py
@abstractmethod
def log_prob(
    self,
    x: torch.Tensor,
    sites: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Log density of ``sites`` under :math:`q_\\phi(z \\mid x)`,
    with the change-of-variables Jacobian correction baked in.

    Returns
    -------
    torch.Tensor
        Shape ``(batch,)``.
    """
    ...
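
The Jacobian correction named in the log_prob contract can be checked by hand on a one-dimensional case. The sketch below is dependency-free and is not the library's implementation: it assumes a single positive-valued site whose bijector is exp (the real per-site bijectors come from the registry), and all helper names are hypothetical.

```python
import math


def normal_log_pdf(z: float, loc: float, scale: float) -> float:
    """Log density of Normal(loc, scale) evaluated at z."""
    return (
        -0.5 * ((z - loc) / scale) ** 2
        - math.log(scale)
        - 0.5 * math.log(2.0 * math.pi)
    )


def constrained_log_prob(v: float, loc: float, scale: float) -> float:
    """Log density of v = exp(z) where z ~ Normal(loc, scale).

    Pull v back to unconstrained space, score it under the base Normal,
    then add the inverse bijector's log-Jacobian log|dz/dv| = -log(v).
    """
    z = math.log(v)
    inverse_log_det = -math.log(v)
    return normal_log_pdf(z, loc, scale) + inverse_log_det


def lognormal_log_pdf(v: float, loc: float, scale: float) -> float:
    """Closed-form LogNormal log density, used only as a cross-check."""
    return (
        -math.log(v * scale * math.sqrt(2.0 * math.pi))
        - (math.log(v) - loc) ** 2 / (2.0 * scale**2)
    )
```

Dropping the `inverse_log_det` term would score `v` as if it lived in unconstrained space, which is the error the contract guards against.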

AutoDeltaGuide

AutoDeltaGuide(model: MonadicProgram, observed_names: set[str], init_value: float = 0.0)

Bases: Guide

Dirac-delta MAP guide with per-site constrained bijector.

PARAMETER DESCRIPTION
model

Generative model.

TYPE: MonadicProgram

observed_names

Variable names treated as observations.

TYPE: set[str]

init_value

Initial unconstrained-space coordinate for every latent. Default 0.0; for the standard bijectors this maps to a sensible interior point of each support (the median of a HalfNormal, the centre of the unit interval, the uniform Dirichlet, etc.). Gaussian noise with standard deviation 0.01 is added to every coordinate so that no two coordinates start at exactly the same value.

TYPE: float DEFAULT: 0.0

Source code in src/quivers/inference/guides/delta.py
def __init__(
    self,
    model: MonadicProgram,
    observed_names: set[str],
    init_value: float = 0.0,
) -> None:
    super().__init__()
    self._registry = self.build_registry(model, observed_names)

    for site in self._registry.sites.values():
        self.register_parameter(
            f"unconstrained_{site.name}",
            nn.Parameter(
                torch.full(site.unconstrained_shape, init_value)
                + torch.randn(site.unconstrained_shape) * 0.01
            ),
        )
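
To see why an unconstrained coordinate of 0.0 is a reasonable starting point, the sketch below pushes it through three common bijectors. The bijector choices (exp for positive support, sigmoid for the unit interval, softmax for the simplex) are assumptions made for illustration; which bijector a site actually gets, and therefore which interior point it lands on, is decided by the registry. The last lines mimic the 0.01-scaled jitter from the constructor.

```python
import math
import random


def sigmoid(z: float) -> float:
    """Map the real line onto the open unit interval."""
    return 1.0 / (1.0 + math.exp(-z))


def softmax(zs: list) -> list:
    """Map a real vector onto the probability simplex."""
    shifted = [math.exp(z - max(zs)) for z in zs]
    total = sum(shifted)
    return [s / total for s in shifted]


init_value = 0.0

# Where the initial coordinate lands under each assumed bijector.
positive_point = math.exp(init_value)        # positive support
unit_interval_point = sigmoid(init_value)    # centre of (0, 1)
simplex_point = softmax([init_value] * 3)    # uniform weights

# Jitter in the style of the constructor: init_value + 0.01 * noise.
rng = random.Random(0)
jittered = [init_value + 0.01 * rng.gauss(0.0, 1.0) for _ in range(4)]
```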

rsample

rsample(x: Tensor) -> dict[str, Tensor]

Return the learned point estimates in the prior's support.

Source code in src/quivers/inference/guides/delta.py
def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Return the learned point estimates in the prior's
    support."""
    batch = x.shape[0]
    return {
        site.name: self._push_through_bijector(site, batch)
        for site in self._registry.sites.values()
    }

log_prob

log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor

Delta log-density: zero everywhere (the delta term and its Jacobian cancel in the ELBO under the standard score-function trick).

Source code in src/quivers/inference/guides/delta.py
def log_prob(
    self,
    x: torch.Tensor,
    sites: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Delta log-density: zero everywhere (the delta term and
    its Jacobian cancel in the ELBO under the standard
    score-function trick)."""
    return torch.zeros(x.shape[0], device=x.device)

AutoIAFGuide

AutoIAFGuide(model: MonadicProgram, observed_names: set[str], num_flows: int = 4, hidden_dim: int | None = None, num_hidden_layers: int = 2)

Bases: AutoNormalizingFlow

Inverse-autoregressive-flow guide.

Default normalizing-flow guide for variational inference (Kingma-Salimans-Jozefowicz et al. 2016). Stack of InverseAutoregressiveTransform layers, each separated by a reverse permutation so successive layers have different autoregressive orderings.

Sampling is parallel (one MLP forward per layer); density evaluation of an arbitrary point is sequential (one coordinate at a time per layer). This guide therefore suits objectives that mostly draw samples from the flow rather than score externally supplied points (ELBO, IWAE).
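
The asymmetry between sampling and scoring can be seen in a toy affine autoregressive layer. This is a minimal sketch, not the library's InverseAutoregressiveTransform: the conditioner `shift_and_log_scale` is a hypothetical stand-in for a MADE, and plain Python lists stand in for tensors.

```python
import math


def shift_and_log_scale(prefix: list) -> tuple:
    """Hypothetical conditioner: reads only the earlier coordinates."""
    s = sum(prefix)
    return 0.5 * s, 0.1 * math.tanh(s)


def iaf_sample(z0: list) -> list:
    """Sampling direction.

    Every output coordinate depends only on the input noise z0, which is
    fully known up front, so the iterations are independent of one another
    and could run in parallel.
    """
    out = []
    for i, z in enumerate(z0):
        mu, log_sigma = shift_and_log_scale(z0[:i])
        out.append(z * math.exp(log_sigma) + mu)
    return out


def iaf_invert(y: list) -> list:
    """Density direction.

    Recovering z0[i] needs z0[:i], which is only available after the
    earlier coordinates have been inverted, so this loop is sequential.
    """
    z0 = []
    for y_i in y:
        mu, log_sigma = shift_and_log_scale(z0)
        z0.append((y_i - mu) * math.exp(-log_sigma))
    return z0
```

Under these assumptions, scoring a sample the flow has just drawn can reuse the known `z0`, whereas scoring an external point has to pay for `iaf_invert`.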

PARAMETER DESCRIPTION
model

Generative model.

TYPE: MonadicProgram

observed_names

Variable names treated as observations.

TYPE: set[str]

num_flows

Number of IAF blocks in the stack. Default 4.

TYPE: int DEFAULT: 4

hidden_dim

Hidden width of every MADE inside the stack. Default max(8, 2 * D), where D is the total unconstrained latent dimension.

TYPE: int | None DEFAULT: None

num_hidden_layers

Number of hidden layers in each MADE. Default 2.

TYPE: int DEFAULT: 2

Source code in src/quivers/inference/guides/flow.py
def __init__(
    self,
    model: MonadicProgram,
    observed_names: set[str],
    num_flows: int = 4,
    hidden_dim: int | None = None,
    num_hidden_layers: int = 2,
) -> None:
    if num_flows < 1:
        raise ValueError(f"AutoIAFGuide: num_flows must be >= 1, got {num_flows}")
    registry = Guide.build_registry(model, observed_names)
    D_total = registry.total_unconstrained_dim
    if hidden_dim is None:
        hidden_dim = max(8, 2 * D_total)
    if D_total < 2:
        raise ValueError(
            f"AutoIAFGuide: model must have >= 2 unconstrained "
            f"latent dimensions for an IAF (got {D_total}); use "
            f"AutoNormalGuide for 1-D models"
        )
    layers: list[TransformModule] = []
    for i in range(num_flows):
        made = MADE(
            dim=D_total,
            n_per_dim=2,
            hidden=hidden_dim,
            n_hidden_layers=num_hidden_layers,
        )
        layers.append(InverseAutoregressiveTransform(made))
        if i < num_flows - 1:
            layers.append(_ReversePermutation(D_total))
    super().__init__(model, observed_names, layers)
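
The stack layout and the hidden-width default built by this constructor can be traced without torch. In the sketch below, strings stand in for the transform modules and `iaf_layer_layout` is a hypothetical helper that mirrors only the control flow shown above.

```python
from __future__ import annotations


def iaf_layer_layout(
    num_flows: int,
    latent_dim: int,
    hidden_dim: int | None = None,
) -> tuple[list[str], int]:
    """Return the layer sequence and the resolved hidden width."""
    if num_flows < 1:
        raise ValueError(f"num_flows must be >= 1, got {num_flows}")
    if latent_dim < 2:
        raise ValueError(f"need >= 2 latent dimensions, got {latent_dim}")
    if hidden_dim is None:
        hidden_dim = max(8, 2 * latent_dim)
    layers: list[str] = []
    for i in range(num_flows):
        layers.append("iaf")
        # A reverse permutation sits between blocks, never after the last one.
        if i < num_flows - 1:
            layers.append("reverse")
    return layers, hidden_dim
```

With the defaults, a stack of `num_flows` blocks therefore holds `2 * num_flows - 1` transforms in total.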

AutoNeuralSplineGuide

AutoNeuralSplineGuide(model: MonadicProgram, observed_names: set[str], num_flows: int = 4, num_bins: int = 8, tail_bound: float = 3.0, hidden_dim: int | None = None, num_hidden_layers: int = 2)

Bases: AutoNormalizingFlow

Neural-spline-flow guide (Durkan-Bekasov-Murray-Papamakarios 2019).

Stack of monotone rational-quadratic spline coupling layers (NeuralSplineCouplingTransform) with alternating half-masks. Sharper than IAF for posteriors with bounded support or sharp modes; comparable runtime.

PARAMETER DESCRIPTION
model

Generative model.

TYPE: MonadicProgram

observed_names

Variable names treated as observations.

TYPE: set[str]

num_flows

Number of coupling layers. Default 4.

TYPE: int DEFAULT: 4

num_bins

Number of spline bins per coordinate. Default 8.

TYPE: int DEFAULT: 8

tail_bound

Inputs outside [-tail_bound, tail_bound] pass through as identity. Default 3.0.

TYPE: float DEFAULT: 3.0

hidden_dim

Hidden width of the coupling MLPs. Default max(64, 2*D).

TYPE: int | None DEFAULT: None

num_hidden_layers

Hidden layers in each coupling MLP. Default 2.

TYPE: int DEFAULT: 2

Source code in src/quivers/inference/guides/flow.py
def __init__(
    self,
    model: MonadicProgram,
    observed_names: set[str],
    num_flows: int = 4,
    num_bins: int = 8,
    tail_bound: float = 3.0,
    hidden_dim: int | None = None,
    num_hidden_layers: int = 2,
) -> None:
    if num_flows < 1:
        raise ValueError(
            f"AutoNeuralSplineGuide: num_flows must be >= 1, got {num_flows}"
        )
    registry = Guide.build_registry(model, observed_names)
    D_total = registry.total_unconstrained_dim
    if D_total < 2:
        raise ValueError(
            f"AutoNeuralSplineGuide: model must have >= 2 "
            f"unconstrained latent dimensions for a spline "
            f"coupling flow (got {D_total})"
        )
    if hidden_dim is None:
        hidden_dim = max(64, 2 * D_total)
    layers: list[TransformModule] = []
    for i in range(num_flows):
        mask = alternating_mask(D_total, even=(i % 2 == 0))
        num_unmasked = int(mask.sum().item())
        num_masked = D_total - num_unmasked
        # Net produces (3 * num_bins - 1) parameters per masked coord.
        out_dim = num_masked * (3 * num_bins - 1)
        net = make_coupling_mlp(
            n_in=num_unmasked,
            n_out=out_dim,
            hidden=hidden_dim,
            n_hidden_layers=num_hidden_layers,
        )
        layers.append(
            NeuralSplineCouplingTransform(
                dim=D_total,
                net=net,
                mask=mask,
                num_bins=num_bins,
                tail_bound=tail_bound,
            )
        )
    super().__init__(model, observed_names, layers)
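
The input and output widths of each coupling network follow directly from the mask. The sketch below reproduces that arithmetic; `alternating_mask` here is a hypothetical stand-in (the library's helper is not shown in this section), assumed to put 1 on even positions when `even=True` and on odd positions otherwise. The factor `3 * num_bins - 1` matches the usual rational-quadratic parameterisation of `num_bins` widths, `num_bins` heights, and `num_bins - 1` interior derivatives.

```python
def alternating_mask(dim: int, even: bool) -> list:
    """Hypothetical stand-in: 1 marks a coordinate fed to the network."""
    start = 0 if even else 1
    return [1 if (j - start) % 2 == 0 else 0 for j in range(dim)]


def coupling_net_sizes(dim: int, num_flows: int, num_bins: int) -> list:
    """Return (n_in, n_out) for the coupling network of each layer."""
    sizes = []
    for i in range(num_flows):
        mask = alternating_mask(dim, even=(i % 2 == 0))
        num_unmasked = sum(mask)           # coordinates the net reads
        num_masked = dim - num_unmasked    # coordinates the spline transforms
        sizes.append((num_unmasked, num_masked * (3 * num_bins - 1)))
    return sizes
```

For an odd `dim` the two mask parities split the coordinates unevenly, so consecutive layers get networks of different shapes.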

AutoNormalizingFlow

AutoNormalizingFlow(model: MonadicProgram, observed_names: set[str], transforms: list[TransformModule])

Bases: Guide

Normalising-flow variational guide over the flat latent vector.

PARAMETER DESCRIPTION
model

Generative model to build a guide for.

TYPE: MonadicProgram

observed_names

Variable names treated as observations.

TYPE: set[str]

transforms

Flow stack applied to the standard-Normal base. Each TransformModule must accept a (..., D)-shaped tensor where D is the registry's total unconstrained dimension.

TYPE: list[TransformModule]

Source code in src/quivers/inference/guides/flow.py
def __init__(
    self,
    model: MonadicProgram,
    observed_names: set[str],
    transforms: list[TransformModule],
) -> None:
    super().__init__()
    self._registry = self.build_registry(model, observed_names)
    D_total = self._registry.total_unconstrained_dim
    if D_total == 0:
        raise ValueError(
            f"{type(self).__name__}: registry has zero total "
            f"unconstrained dimension; model has no continuous "
            f"latents to guide"
        )
    if not transforms:
        raise ValueError(
            f"{type(self).__name__}: transforms list must be non-empty"
        )
    self._D = D_total
    self.flow = nn.ModuleList(transforms)
    self.register_buffer("base_loc", torch.zeros(D_total))
    self.register_buffer("base_scale", torch.ones(D_total))

rsample

rsample(x: Tensor) -> dict[str, Tensor]

One flow draw, unflattened and bijected to constrained space.

Source code in src/quivers/inference/guides/flow.py
def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """One flow draw, unflattened and bijected to constrained space."""
    batch = x.shape[0]
    z0 = self._base_dist().rsample()
    z_K, _ = self._forward(z0)
    per_site = self._registry.unflatten_unconstrained(z_K)
    result: dict[str, torch.Tensor] = {}
    for site in self._registry.sites.values():
        z_site = per_site[site.name]
        if not site.is_plate:
            z_site = z_site.unsqueeze(0).expand(batch, *site.unconstrained_shape)
        v = site.bijector(z_site)
        if site.constrained_dim == 1 and v.dim() >= 1 and v.shape[-1] == 1:
            v = v.squeeze(-1)
        result[site.name] = v
    return result

log_prob

log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor

Log-density at the supplied constrained sites.

Source code in src/quivers/inference/guides/flow.py
def log_prob(
    self,
    x: torch.Tensor,
    sites: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Log-density at the supplied constrained sites."""
    batch = x.shape[0]
    unconstrained_per_site: dict[str, torch.Tensor] = {}
    bijector_log_det = torch.zeros((), device=x.device)
    for site in self._registry.sites.values():
        if site.name not in sites:
            raise KeyError(
                f"{type(self).__name__}.log_prob: missing site {site.name!r}"
            )
        v = sites[site.name]
        if site.constrained_dim == 1 and v.dim() == (1 if site.is_plate else 1):
            v_e = v.unsqueeze(-1)
        else:
            v_e = v
        if not site.is_plate and v_e.dim() == len(site.unconstrained_shape) + 1:
            v_e = v_e[0]
        z_site = site.bijector.inv(v_e)
        unconstrained_per_site[site.name] = z_site
        bijector_log_det = bijector_log_det + (
            site.bijector.inv.log_abs_det_jacobian(v_e, z_site).sum()
        )

    z_K = self._registry.flatten_unconstrained(unconstrained_per_site)
    z_0, flow_log_det = self._inverse(z_K)
    log_p_base = self._base_dist().log_prob(z_0)
    # log q(z_K) = log p_0(z_0) - log|det dT/dz_0|
    #            = log p_0(z_0) - flow_log_det
    log_q_z = log_p_base - flow_log_det
    return (log_q_z + bijector_log_det).expand(batch)
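
The change-of-variables identity in the comment above, log q(z_K) = log p_0(z_0) - log|det dT/dz_0|, can be verified on the simplest possible flow. The sketch assumes a one-dimensional affine transform z_K = a * z_0 + b over a standard-Normal base; the function names are hypothetical and nothing here calls the library.

```python
import math


def normal_log_pdf(z: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Log density of Normal(loc, scale) evaluated at z."""
    return (
        -0.5 * ((z - loc) / scale) ** 2
        - math.log(scale)
        - 0.5 * math.log(2.0 * math.pi)
    )


def affine_flow_log_prob(z_k: float, a: float, b: float) -> float:
    """Score z_k under the pushforward of N(0, 1) through z -> a * z + b."""
    z_0 = (z_k - b) / a                  # invert the flow
    flow_log_det = math.log(abs(a))      # log|dT/dz_0| of the forward map
    return normal_log_pdf(z_0) - flow_log_det
```

The result agrees with the closed-form density of Normal(b, |a|), which is what the pushforward of a standard Normal through an affine map must be.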

AutoLaplaceApproximation

AutoLaplaceApproximation(model: MonadicProgram, observed_names: set[str], init_value: float = 0.0)

Bases: Guide

Laplace-approximation guide.

PARAMETER DESCRIPTION
model

Generative model.

TYPE: MonadicProgram

observed_names

Variable names treated as observations.

TYPE: set[str]

init_value

Initial unconstrained-space MAP estimate. Default 0.0.

TYPE: float DEFAULT: 0.0

Source code in src/quivers/inference/guides/laplace.py
def __init__(
    self,
    model: MonadicProgram,
    observed_names: set[str],
    init_value: float = 0.0,
) -> None:
    super().__init__()
    self._registry = self.build_registry(model, observed_names)
    D_total = self._registry.total_unconstrained_dim
    if D_total == 0:
        raise ValueError(
            f"{type(self).__name__}: registry has zero total "
            f"unconstrained dimension; model has no continuous "
            f"latents to guide"
        )
    self._D = D_total
    # MAP estimate in flat unconstrained space.
    init = torch.full((D_total,), float(init_value))
    init = init + 0.01 * torch.randn(D_total)
    self.map_z = nn.Parameter(init)
    # Hessian-phase parameters; initialized to identity scale_tril
    # but only used after fit_hessian() is called.
    self.register_buffer("_hessian_fitted", torch.zeros((), dtype=torch.bool))
    self.register_buffer(
        "_scale_tril",
        torch.eye(D_total) * 1e-3,
        persistent=True,
    )

hessian_fitted property

hessian_fitted: bool

Whether fit_hessian has been called.

fit_hessian

fit_hessian(model: MonadicProgram, x: Tensor, observations: dict[str, Tensor], *, jitter: float = 0.0001) -> None

Compute and cache the Hessian-derived Cholesky factor.

Computes the Hessian of the negative log joint at the current MAP (in flat unconstrained space, with the bijector log-Jacobian terms included), symmetrizes and eigendecomposes it, and clamps every eigenvalue below jitter up to jitter, so the resulting Gaussian is always positive-definite. The covariance is the inverse of the clamped Hessian with jitter added to its diagonal; its lower-triangular Cholesky factor is stored as the posterior scale_tril.

Call this after MAP optimisation has converged. Subsequent rsample calls draw from, and log_prob calls evaluate, :math:`\mathcal{N}(z^\star, H^{-1})`.

Source code in src/quivers/inference/guides/laplace.py
def fit_hessian(
    self,
    model: MonadicProgram,
    x: torch.Tensor,
    observations: dict[str, torch.Tensor],
    *,
    jitter: float = 1e-4,
) -> None:
    """Compute and cache the Hessian-derived Cholesky factor.

    Solves the eigenproblem of the negative-log-joint Hessian at
    the current MAP, projects negative eigenvalues to ``jitter``
    (so the resulting Gaussian is always positive-definite), and
    stores the matching lower-triangular Cholesky factor of the
    inverse Hessian as the posterior scale_tril.

    Call this after MAP optimisation has converged. Subsequent
    `rsample` / `log_prob` calls sample from
    :math:`\\mathcal{N}(z^\\star, H^{-1})`.
    """

    def neg_log_joint(z_flat: torch.Tensor) -> torch.Tensor:
        per_site_unconstrained = self._registry.unflatten_unconstrained(z_flat)
        constrained: dict[str, torch.Tensor] = {}
        log_det_sum = torch.zeros((), device=z_flat.device)
        for site in self._registry.sites.values():
            z_site = per_site_unconstrained[site.name]
            if not site.is_plate:
                z_site = z_site.unsqueeze(0)
            v = site.bijector(z_site)
            log_det_sum = log_det_sum + (
                site.bijector.log_abs_det_jacobian(z_site, v).sum()
            )
            if site.constrained_dim == 1 and v.dim() >= 1 and v.shape[-1] == 1:
                v = v.squeeze(-1)
            constrained[site.name] = v
        log_p = model.log_joint(x, {**constrained, **observations})
        return -(log_p.sum() + log_det_sum)

    H = torch.autograd.functional.hessian(neg_log_joint, self.map_z.detach())
    H = 0.5 * (H + H.t())
    eigvals, eigvecs = torch.linalg.eigh(H)
    eigvals_clamped = eigvals.clamp(min=jitter)
    # Σ = (V Λ V^T)^{-1} = V Λ^{-1} V^T; scale_tril is its Cholesky.
    inv_eigvals = 1.0 / eigvals_clamped
    sigma = (eigvecs * inv_eigvals.unsqueeze(0)) @ eigvecs.t()
    sigma = 0.5 * (sigma + sigma.t())
    sigma = sigma + jitter * torch.eye(self._D, device=sigma.device)
    L = torch.linalg.cholesky(sigma)
    self._scale_tril.copy_(L)
    self._hessian_fitted.fill_(True)
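
The clamping and jitter steps are easiest to follow in one dimension, where the Hessian is a single curvature and the Cholesky factor is a square root. The sketch below is an illustration under that assumption only: it uses a finite-difference second derivative instead of autograd, and `laplace_scale` is a hypothetical name.

```python
import math


def second_derivative(f, z: float, h: float = 1e-4) -> float:
    """Central finite-difference estimate of f''(z)."""
    return (f(z + h) - 2.0 * f(z) + f(z - h)) / (h * h)


def laplace_scale(neg_log_joint, z_map: float, jitter: float = 1e-4) -> float:
    """1-D analogue of the stored scale_tril.

    Clamp the curvature from below at `jitter`, invert it to get a
    variance, add `jitter` to that variance, and take the square root.
    """
    curvature = max(second_derivative(neg_log_joint, z_map), jitter)
    variance = 1.0 / curvature + jitter
    return math.sqrt(variance)


def well_behaved(z: float) -> float:
    """Negative log joint with curvature 4 at its minimum z = 1."""
    return 2.0 * (z - 1.0) ** 2


def concave(z: float) -> float:
    """Negative curvature everywhere, so the clamp takes over."""
    return -(z**2)
```

In the concave case the clamp turns a meaningless negative variance into a very wide Gaussian (scale near 100 for the default jitter), a sign that the MAP phase has not reached a proper mode.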

rsample

rsample(x: Tensor) -> dict[str, Tensor]

Sample from the Laplace posterior, unflatten, and biject.

Source code in src/quivers/inference/guides/laplace.py
def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Sample from the Laplace posterior, unflatten, and biject."""
    batch = x.shape[0]
    z_flat = self._sample_unconstrained()
    per_site = self._registry.unflatten_unconstrained(z_flat)

    result: dict[str, torch.Tensor] = {}
    for site in self._registry.sites.values():
        z_site = per_site[site.name]
        if not site.is_plate:
            z_site = z_site.unsqueeze(0).expand(batch, *site.unconstrained_shape)
        v = site.bijector(z_site)
        if site.constrained_dim == 1 and v.dim() >= 1 and v.shape[-1] == 1:
            v = v.squeeze(-1)
        result[site.name] = v
    return result

log_prob

log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor

Log-density at the supplied constrained sites.

Returns zero before fit_hessian (MAP-phase delta convention); after fit_hessian returns the Gaussian log-density plus the per-site bijector Jacobian correction.

Source code in src/quivers/inference/guides/laplace.py
def log_prob(
    self,
    x: torch.Tensor,
    sites: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Log-density at the supplied constrained sites.

    Returns zero before `fit_hessian` (MAP-phase delta
    convention); after `fit_hessian` returns the Gaussian
    log-density plus the per-site bijector Jacobian correction.
    """
    batch = x.shape[0]
    if not bool(self._hessian_fitted):
        return torch.zeros(batch, device=x.device)

    unconstrained_per_site: dict[str, torch.Tensor] = {}
    bijector_log_det = torch.zeros((), device=x.device)
    for site in self._registry.sites.values():
        if site.name not in sites:
            raise KeyError(
                f"{type(self).__name__}.log_prob: missing site {site.name!r}"
            )
        v = sites[site.name]
        if site.constrained_dim == 1 and v.dim() == (1 if site.is_plate else 1):
            v_e = v.unsqueeze(-1)
        else:
            v_e = v
        if not site.is_plate and v_e.dim() == len(site.unconstrained_shape) + 1:
            v_e = v_e[0]
        z_site = site.bijector.inv(v_e)
        unconstrained_per_site[site.name] = z_site
        bijector_log_det = bijector_log_det + (
            site.bijector.inv.log_abs_det_jacobian(v_e, z_site).sum()
        )
    z_flat = self._registry.flatten_unconstrained(unconstrained_per_site)
    gauss = D.MultivariateNormal(self.map_z, scale_tril=self._scale_tril)
    log_q_z = gauss.log_prob(z_flat)
    return (log_q_z + bijector_log_det).expand(batch)

AutoMixtureGuide

AutoMixtureGuide(components: list[Guide], init_temperature: float = 1.0)

Bases: Guide

Finite mixture variational guide.

PARAMETER DESCRIPTION
components

Component guides. All components must share the same LatentRegistry (i.e. be built against the same model + observed-name set).

TYPE: list[Guide]

init_temperature

Initial Gumbel-Softmax temperature. Default 1.0; anneal toward zero for sharper component selection.

TYPE: float DEFAULT: 1.0

Source code in src/quivers/inference/guides/mixture.py
def __init__(
    self,
    components: list[Guide],
    init_temperature: float = 1.0,
) -> None:
    super().__init__()
    if len(components) < 2:
        raise ValueError(
            f"AutoMixtureGuide: need at least 2 components, got {len(components)}"
        )
    reference = components[0]
    ref_names = tuple(reference.registry.names)
    for i, comp in enumerate(components[1:], 1):
        comp_names = tuple(comp.registry.names)
        if comp_names != ref_names:
            raise ValueError(
                f"AutoMixtureGuide: component {i} has different "
                f"latent names {comp_names!r} than component 0 "
                f"{ref_names!r}"
            )
    if init_temperature <= 0.0:
        raise ValueError(
            f"AutoMixtureGuide: init_temperature must be positive, "
            f"got {init_temperature}"
        )
    self._registry = reference.registry
    self.components = nn.ModuleList(components)
    self.mixture_logits = nn.Parameter(torch.zeros(len(components)))
    self._temperature: torch.Tensor
    self.register_buffer("_temperature", torch.tensor(float(init_temperature)))

temperature property

temperature: float

Current Gumbel-Softmax temperature.

set_temperature

set_temperature(value: float) -> None

Anneal the Gumbel-Softmax temperature.

Source code in src/quivers/inference/guides/mixture.py
def set_temperature(self, value: float) -> None:
    """Anneal the Gumbel-Softmax temperature."""
    if value <= 0.0:
        raise ValueError(
            f"AutoMixtureGuide.set_temperature: must be positive, got {value}"
        )
    self._temperature.fill_(float(value))
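
Because set_temperature rejects non-positive values, an annealing schedule has to stay strictly above zero. One possible schedule, sketched here as an assumption rather than anything the library ships, is exponential decay with a floor; `annealed_temperature` and its parameters are hypothetical.

```python
def annealed_temperature(
    step: int,
    initial: float = 1.0,
    decay: float = 0.999,
    floor: float = 0.05,
) -> float:
    """Exponentially decayed temperature that never drops below `floor`."""
    if floor <= 0.0:
        raise ValueError("floor must be positive")
    return max(floor, initial * decay**step)


# Intended use inside a training loop (guide is an AutoMixtureGuide):
#     guide.set_temperature(annealed_temperature(step))
```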

rsample

rsample(x: Tensor) -> dict[str, Tensor]

Reparameterized mixture draw via Gumbel-Softmax.

Each call samples a Gumbel-Softmax weight vector :math:`w \in \Delta^{K-1}` and returns :math:`\sum_k w_k \cdot v^{(k)}` per site, where :math:`v^{(k)}` is component k's constrained-space sample. Because the constrained-space sites' supports are not in general convex (e.g. a Cholesky factor on torch.distributions.constraints.corr_cholesky), the soft mixture can drift outside any single component's support during training; the categorical-pick fallback in hard_rsample returns a single component's sample for use at inference time.

Source code in src/quivers/inference/guides/mixture.py
def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Reparameterized mixture draw via Gumbel-Softmax.

    Each call samples a Gumbel-Softmax weight vector
    :math:`w \\in \\Delta^{K-1}` and returns
    :math:`\\sum_k w_k \\cdot v^{(k)}` per site, where
    :math:`v^{(k)}` is component ``k``'s constrained-space
    sample. Because the constrained-space sites' supports are
    not in general convex (e.g. a Cholesky factor on
    `torch.distributions.constraints.corr_cholesky`), the
    soft mixture can drift outside any single component's
    support during training; the categorical-pick fallback in
    `hard_rsample` returns a single component's sample
    for use at inference time.
    """
    gumbel_logits = (
        self.mixture_logits
        - torch.empty_like(self.mixture_logits).exponential_().log()
    )
    w = F.softmax(gumbel_logits / self._temperature, dim=-1)
    component_samples = [comp.rsample(x) for comp in self.components]

    result: dict[str, torch.Tensor] = {}
    for site_name in self._registry.names:
        stacked = torch.stack(
            [comp_samples[site_name] for comp_samples in component_samples],
            dim=0,
        )
        # Broadcast w against the stacked shape.
        broadcast_shape = (self.num_components,) + (1,) * (stacked.dim() - 1)
        result[site_name] = (w.reshape(broadcast_shape) * stacked).sum(dim=0)
    return result
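
The weight computation can be reproduced with the standard library to see how the temperature shapes the mixture. The sketch below follows the same recipe as the method above (Gumbel noise obtained as minus the log of an Exponential(1) draw, then a tempered softmax) but on plain floats; the function names are hypothetical.

```python
import math
import random


def gumbel_perturb(logits: list, rng: random.Random) -> list:
    """Add standard Gumbel noise: -log(E) with E ~ Exponential(1)."""
    return [logit - math.log(rng.expovariate(1.0)) for logit in logits]


def tempered_softmax(scores: list, temperature: float) -> list:
    """Softmax of scores / temperature, computed stably."""
    scaled = [s / temperature for s in scores]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_mixture(weights: list, component_samples: list) -> float:
    """Convex combination of one scalar sample per component."""
    return sum(w * v for w, v in zip(weights, component_samples))


rng = random.Random(0)
noisy = gumbel_perturb([0.0, 0.0, 0.0], rng)
w_warm = tempered_softmax(noisy, temperature=1.0)
w_cold = tempered_softmax(noisy, temperature=0.1)
blended = soft_mixture(w_warm, [1.0, 2.0, 3.0])
```

For a fixed noise draw, lowering the temperature moves the weights toward a one-hot vector, so the blended sample approaches a single component's draw.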

hard_rsample

hard_rsample(x: Tensor) -> dict[str, Tensor]

Categorical-pick variant: sample a component index and return that component's draw verbatim. Use at inference time when soft-mixture interpolation would violate a support constraint.

Source code in src/quivers/inference/guides/mixture.py
def hard_rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Categorical-pick variant: sample a component index and
    return that component's draw verbatim. Use at inference
    time when soft-mixture interpolation would violate a
    support constraint."""
    probs = F.softmax(self.mixture_logits, dim=-1)
    k = int(torch.distributions.Categorical(probs=probs).sample().item())
    return self.components[k].rsample(x)

log_prob

log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor

Mixture log-density via logsumexp over components.

Source code in src/quivers/inference/guides/mixture.py
def log_prob(
    self,
    x: torch.Tensor,
    sites: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Mixture log-density via logsumexp over components."""
    log_pi = F.log_softmax(self.mixture_logits, dim=-1)
    component_log_probs = torch.stack(
        [comp.log_prob(x, sites) for comp in self.components], dim=0
    )
    return torch.logsumexp(log_pi.unsqueeze(-1) + component_log_probs, dim=0)
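
The same computation on plain floats shows why the mixture density is assembled in log space. This is a minimal stdlib sketch for a single batch element, with hypothetical helper names, assuming each component has already reported its own log density.

```python
import math


def logsumexp(values: list) -> float:
    """Stable log(sum(exp(values)))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def log_softmax(logits: list) -> list:
    """Normalised log mixture weights."""
    norm = logsumexp(logits)
    return [logit - norm for logit in logits]


def mixture_log_prob(mixture_logits: list, component_log_probs: list) -> float:
    """log sum_k pi_k * q_k, computed without leaving log space."""
    log_pi = log_softmax(mixture_logits)
    return logsumexp([lp + lq for lp, lq in zip(log_pi, component_log_probs)])
```

Working in log space keeps the result finite even when every component log density is far below the range where `math.exp` underflows.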

AutoLowRankMultivariateNormalGuide

AutoLowRankMultivariateNormalGuide(model: MonadicProgram, observed_names: set[str], rank: int = 5, init_scale: float = 0.1)

Bases: _MVNCommon

Low-rank-plus-diagonal multivariate-Normal guide.

Covariance :math:`\Sigma = W W^\top + \mathrm{diag}(\sigma^2)` with W of shape :math:`(D, r)` and :math:`\sigma \in \mathbb{R}^{D}_{>0}`. Memory :math:`O(Dr)`; sampling and log-density via Woodbury / matrix-determinant lemma in torch.distributions.LowRankMultivariateNormal.

Captures the dominant r posterior correlation directions while remaining tractable for D in the hundreds-to-thousands range, where full-rank is infeasible.
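
A small numeric check makes the structure concrete. The sketch below builds the covariance for D = 3 and r = 1 with plain lists, confirms the rank-one matrix-determinant lemma against a direct 3x3 determinant, and compares covariance parameter counts. All names and the example numbers are illustrative assumptions; the guide itself delegates this work to torch.

```python
import math


def low_rank_covariance(w: list, sigma: list) -> list:
    """Sigma = W W^T + diag(sigma^2) for W given as D rows of r entries."""
    dim = len(w)
    cov = [
        [sum(a * b for a, b in zip(w[i], w[j])) for j in range(dim)]
        for i in range(dim)
    ]
    for i in range(dim):
        cov[i][i] += sigma[i] ** 2
    return cov


def det3(m: list) -> float:
    """Determinant of a 3x3 matrix by cofactor expansion."""
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


def rank1_determinant(w_col: list, sigma: list) -> float:
    """Matrix-determinant lemma for r = 1: det(D) * (1 + w^T D^-1 w)."""
    diag = [s * s for s in sigma]
    return math.prod(diag) * (1.0 + sum(x * x / d for x, d in zip(w_col, diag)))


def covariance_param_counts(dim: int, rank: int) -> tuple:
    """(low-rank-plus-diagonal, full-rank Cholesky) parameter counts."""
    return dim * rank + dim, dim * (dim + 1) // 2


w_col = [1.0, 2.0, 0.5]
sigma = [0.5, 1.0, 2.0]
cov = low_rank_covariance([[x] for x in w_col], sigma)
```

At D = 1000 and the default rank of 5, the low-rank form stores 6000 covariance parameters where a full Cholesky factor would need 500500.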

PARAMETER DESCRIPTION
model

Generative model to build a guide for.

TYPE: MonadicProgram

observed_names

Variable names treated as observations.

TYPE: set[str]

rank

Number of correlated directions. Default 5.

TYPE: int DEFAULT: 5

init_scale

Initial diagonal scale. Default 0.1. W is initialized at zero.

TYPE: float DEFAULT: 0.1

Source code in src/quivers/inference/guides/multivariate_normal.py
def __init__(
    self,
    model: MonadicProgram,
    observed_names: set[str],
    rank: int = 5,
    init_scale: float = 0.1,
) -> None:
    super().__init__(model, observed_names)
    if rank < 1:
        raise ValueError(
            f"AutoLowRankMultivariateNormalGuide: rank must be >= 1, got {rank}"
        )
    if rank > self._D:
        raise ValueError(
            f"AutoLowRankMultivariateNormalGuide: rank ({rank}) "
            f"cannot exceed total unconstrained dimension "
            f"({self._D})"
        )
    self._rank = rank
    init_diag = torch.full((self._D,), float(init_scale))
    init_diag_raw = torch.log(torch.expm1(init_diag.clamp(min=1e-6)))
    self.cov_diag_raw = nn.Parameter(init_diag_raw)
    self.cov_factor = nn.Parameter(torch.zeros(self._D, rank))

AutoMultivariateNormalGuide

AutoMultivariateNormalGuide(model: MonadicProgram, observed_names: set[str], init_scale: float = 0.1)

Bases: _MVNCommon

Full-rank multivariate-Normal variational guide.

Parameterises a joint Gaussian over the registry's flat unconstrained vector with a learnable lower-triangular Cholesky factor. Captures every pairwise posterior correlation across every latent site — the right choice when posterior couplings are strong (hierarchical regression with crossed random effects, parameter pairs with multiplicative interaction).

PARAMETER DESCRIPTION
model

Generative model to build a guide for.

TYPE: MonadicProgram

observed_names

Variable names treated as observations.

TYPE: set[str]

init_scale

Initial diagonal of the Cholesky factor. Default 0.1; the off-diagonal entries start at 0.

TYPE: float DEFAULT: 0.1

Source code in src/quivers/inference/guides/multivariate_normal.py
def __init__(
    self,
    model: MonadicProgram,
    observed_names: set[str],
    init_scale: float = 0.1,
) -> None:
    super().__init__(model, observed_names)
    init_diag = torch.full((self._D,), float(init_scale))
    init_diag_raw = torch.log(torch.expm1(init_diag.clamp(min=1e-6)))
    self.scale_diag_raw = nn.Parameter(init_diag_raw)
    self.scale_offdiag = nn.Parameter(torch.zeros(self._D, self._D))
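
The constructor stores the Cholesky diagonal and the off-diagonal entries as two separate parameters. One way they might be combined is sketched below in plain Python. It assumes the diagonal passes through softplus and that only the strictly lower triangle of scale_offdiag is used; neither step appears in this listing, and assemble_cholesky and covariance are hypothetical helpers, not part of the library.

```python
import math


def softplus(y: float) -> float:
    return math.log1p(math.exp(y))


def assemble_cholesky(diag_raw, offdiag):
    """Lower-triangular L: softplus(diag_raw) on the diagonal,
    strictly-lower entries of offdiag below it, zeros above."""
    d = len(diag_raw)
    L = [[0.0] * d for _ in range(d)]
    for i in range(d):
        for j in range(i):
            L[i][j] = offdiag[i][j]
        L[i][i] = softplus(diag_raw[i])
    return L


def covariance(L):
    """Sigma = L @ L^T, written out with plain loops."""
    d = len(L)
    return [
        [sum(L[i][k] * L[j][k] for k in range(d)) for j in range(d)]
        for i in range(d)
    ]


D = 3
init_scale = 0.1
diag_raw = [math.log(math.expm1(init_scale))] * D  # same init as the constructor
offdiag = [[0.0] * D for _ in range(D)]            # zeros, as in the constructor

sigma_init = covariance(assemble_cholesky(diag_raw, offdiag))

# A single nonzero strictly-lower entry couples dimensions 0 and 1.
offdiag[1][0] = 0.05
sigma_corr = covariance(assemble_cholesky(diag_raw, offdiag))
```

Under these assumptions the initial covariance is init_scale squared times the identity, so the full-rank guide starts out equivalent to a mean-field Gaussian and learns correlations only as scale_offdiag moves away from zero.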

AutoNormalGuide

AutoNormalGuide(model: MonadicProgram, observed_names: set[str], init_scale: float = 0.1)

Bases: Guide

Mean-field Normal guide with per-site constrained-support bijector.

PARAMETER DESCRIPTION
model

Generative model to build a guide for.

TYPE: MonadicProgram

observed_names

Variable names treated as observations (skipped in the guide; their values flow through the conditioning data dict at trace time).

TYPE: set[str]

init_scale

Initial scale (in unconstrained space) of every latent. Default 0.1; small enough to keep the guide near its prior at the start of optimisation.

TYPE: float DEFAULT: 0.1

Source code in src/quivers/inference/guides/normal.py
def __init__(
    self,
    model: MonadicProgram,
    observed_names: set[str],
    init_scale: float = 0.1,
) -> None:
    super().__init__()
    self._registry = self.build_registry(model, observed_names)
    init_log_scale = float(torch.tensor(init_scale).log().item())

    for site in self._registry.sites.values():
        self.register_parameter(
            f"loc_{site.name}",
            nn.Parameter(torch.zeros(site.unconstrained_shape)),
        )
        self.register_parameter(
            f"log_scale_{site.name}",
            nn.Parameter(torch.full(site.unconstrained_shape, init_log_scale)),
        )

rsample

rsample(x: Tensor) -> dict[str, Tensor]

Reparameterized mean-field Normal-then-bijector sample.

Source code in src/quivers/inference/guides/normal.py
def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Reparameterized mean-field Normal-then-bijector sample."""
    batch = x.shape[0]
    result: dict[str, torch.Tensor] = {}
    for site in self._registry.sites.values():
        _, v = self._sample_site(site, batch)
        result[site.name] = v
    return result
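
The per-site helper _sample_site is not shown in this listing. As a rough illustration of what a reparameterized Normal-then-bijector draw looks like, the sketch below samples standard noise, shifts and scales it, then pushes the result through a bijector. It assumes an exp bijector (positive support) purely for illustration; sample_site is a hypothetical stand-in, not the library's method.

```python
import math
import random


def sample_site(loc, log_scale, forward, rng):
    """Draw z = loc + exp(log_scale) * eps with eps ~ N(0, 1),
    then map z to the constrained value v = forward(z)."""
    eps = rng.gauss(0.0, 1.0)
    z = loc + math.exp(log_scale) * eps
    return z, forward(z)


rng = random.Random(0)
loc, log_scale = 0.0, math.log(0.1)  # matches the constructor's initial values
draws = [sample_site(loc, log_scale, math.exp, rng) for _ in range(5)]

for z, v in draws:
    print(f"z={z:+.4f}  v={v:.4f}")
```

Because the randomness enters only through eps, each draw is a deterministic function of loc and log_scale, which is what lets gradients flow back to the guide parameters.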

log_prob

log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor

Pushforward log-density at constrained values sites.

Uses the change-of-variables identity:

log q(v) = log Normal(z; loc, scale) + log|det J_{T^{-1}}(v)|

where z = bijector.inv(v). The plate / scalar shape dispatch matches rsample's convention.

Source code in src/quivers/inference/guides/normal.py
def log_prob(
    self,
    x: torch.Tensor,
    sites: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Pushforward log-density at constrained values ``sites``.

    Uses the change-of-variables identity:

        log q(v) = log Normal(z; loc, scale) + log|det J_{T^{-1}}(v)|

    where ``z = bijector.inv(v)``. The plate / scalar shape
    dispatch matches `rsample`'s convention.
    """
    batch = x.shape[0]
    total = torch.zeros(batch, device=x.device)
    for site in self._registry.sites.values():
        if site.name not in sites:
            continue
        v = sites[site.name]
        if site.constrained_dim == 1 and v.dim() == 1:
            v = v.unsqueeze(-1)
        z = site.bijector.inv(v)
        loc = self._loc(site.name)
        scale = self._scale(site.name)
        if site.is_plate:
            # Plate latent: single shared sample, scalar density
            # broadcast against the batch accumulator.
            log_q_z = D.Normal(loc, scale).log_prob(z)
            log_abs_det = site.bijector.inv.log_abs_det_jacobian(v, z)
            contribution = log_q_z.reshape(-1).sum() + log_abs_det.reshape(-1).sum()
            total = total + contribution
        else:
            loc_b = loc.unsqueeze(0).expand(batch, *site.unconstrained_shape)
            scale_b = scale.unsqueeze(0).expand(batch, *site.unconstrained_shape)
            log_q_z = D.Normal(loc_b, scale_b).log_prob(z)
            log_abs_det = site.bijector.inv.log_abs_det_jacobian(v, z)
            while log_q_z.dim() > 1:
                log_q_z = log_q_z.sum(dim=-1)
            while log_abs_det.dim() > 1:
                log_abs_det = log_abs_det.sum(dim=-1)
            total = total + log_q_z + log_abs_det
    return total
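
The change-of-variables identity can be checked numerically on a single scalar site. The sketch below assumes an exp bijector, so z = log(v) and log|det J_{T^{-1}}(v)| = -log(v). It integrates the resulting density over v and compares it with the same integral computed without the Jacobian term. All function names here are illustrative and the code is plain Python, not the library's implementation.

```python
import math


def normal_log_prob(z, loc, scale):
    """Log-density of Normal(loc, scale) at z."""
    return (
        -0.5 * math.log(2.0 * math.pi)
        - math.log(scale)
        - 0.5 * ((z - loc) / scale) ** 2
    )


def log_q(v, loc, scale):
    """log q(v) for v = exp(z): base log-density at z plus log|dz/dv|."""
    z = math.log(v)                 # z = T^{-1}(v)
    log_abs_det_inv = -math.log(v)  # log|d(log v)/dv|
    return normal_log_prob(z, loc, scale) + log_abs_det_inv


def integrate(f, lo, hi, n):
    """Midpoint-rule quadrature of f over [lo, hi] with n cells."""
    h = (hi - lo) / n
    return h * sum(f(lo + (i + 0.5) * h) for i in range(n))


loc, scale = 0.0, 0.5

# With the Jacobian term, q(v) is a proper density over v > 0.
mass = integrate(lambda v: math.exp(log_q(v, loc, scale)), 1e-9, 20.0, 200_000)

# Dropping the Jacobian term gives E[exp(z)] = exp(loc + scale**2 / 2) instead of 1.
mass_no_jac = integrate(
    lambda v: math.exp(normal_log_prob(math.log(v), loc, scale)),
    1e-9, 20.0, 200_000,
)
```

The first integral comes out at 1 up to quadrature error, while the second overshoots. That gap is why log_prob adds log_abs_det for every site before summing over event dimensions.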