Variational Guides

Variational guide distributions for approximate inference.

guide

Variational guide families for approximate posterior inference.

A guide is a parameterized distribution q(z | x) over latent variables that approximates the true posterior p(z | x, y_obs). SVI uses a guide to estimate and maximize the evidence lower bound (ELBO).
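For reference, the standard ELBO is shown below, assuming SVI here uses the usual form. φ denotes the guide's learnable parameters; this page does not otherwise use that symbol. The gap between the log evidence and the ELBO is the KL divergence from the guide to the true posterior, so maximizing the ELBO pulls q toward p(z | x, y_obs).

```latex
\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p(y_{\mathrm{obs}}, z \mid x) - \log q_\phi(z \mid x)\right]

\log p(y_{\mathrm{obs}} \mid x) - \mathcal{L}(\phi) = \mathrm{KL}\left(q_\phi(z \mid x) \,\|\, p(z \mid x, y_{\mathrm{obs}})\right) \ge 0
```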

This module provides:

  • Guide — abstract base class for all guides
  • AutoNormalGuide — mean-field Normal over all continuous latents
  • AutoDeltaGuide — point-estimate (MAP) guide

Guide

Bases: Module, ABC

Abstract variational guide.

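A guide provides a parameterized approximate posterior q(z | x) over latent variables. Subclasses must implement latent_names, rsample, and log_prob. Sampling must be reparameterized, so that gradients flow from the sampled values back to the guide's parameters. log_prob must return one log-density per batch element.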
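The following is a minimal sketch of what a subclass has to provide. It defines a hypothetical stand-in, GuideLike, that mirrors the three abstract members documented on this page, so the example runs without quivers installed. LinearNormalGuide is also hypothetical. Unlike the auto guides below, which use x only for its batch size, this guide computes its loc from x with a linear layer. That shows one place where the conditioning in q(z | x) can enter.

```python
from abc import ABC, abstractmethod

import torch
import torch.distributions as D
import torch.nn.functional as F
from torch import nn


class GuideLike(nn.Module, ABC):
    """Hypothetical stand-in with the same abstract members as Guide."""

    @property
    @abstractmethod
    def latent_names(self) -> list[str]: ...

    @abstractmethod
    def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]: ...

    @abstractmethod
    def log_prob(self, x: torch.Tensor, sites: dict[str, torch.Tensor]) -> torch.Tensor: ...


class LinearNormalGuide(GuideLike):
    """Hypothetical amortized guide: q(z | x) = Normal(W x + b, softplus(s))."""

    def __init__(self, in_dim: int, z_dim: int) -> None:
        super().__init__()
        self.net = nn.Linear(in_dim, z_dim)
        self.raw_scale = nn.Parameter(torch.zeros(z_dim))

    @property
    def latent_names(self) -> list[str]:
        return ["z"]

    def _dist(self, x: torch.Tensor) -> D.Normal:
        loc = self.net(x)                    # (batch, z_dim), depends on x
        scale = F.softplus(self.raw_scale)   # (z_dim,), strictly positive
        return D.Normal(loc, scale)          # broadcasts to (batch, z_dim)

    def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        return {"z": self._dist(x).rsample()}  # reparameterized draw

    def log_prob(self, x: torch.Tensor, sites: dict[str, torch.Tensor]) -> torch.Tensor:
        # Sum over the latent's event dimension to get one value per row.
        return self._dist(x).log_prob(sites["z"]).sum(dim=-1)  # (batch,)


torch.manual_seed(0)
guide = LinearNormalGuide(in_dim=3, z_dim=2)
x = torch.randn(5, 3)
sites = guide.rsample(x)
lp = guide.log_prob(x, sites)
lp.sum().backward()  # gradients reach the linear layer through the sample
```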

latent_names abstractmethod property

latent_names: list[str]

Names of latent variables this guide covers.

rsample abstractmethod

rsample(x: Tensor) -> dict[str, Tensor]

Sample latent variables from the guide.

PARAMETER DESCRIPTION
x

Program input. Shape (batch, ...).

TYPE: Tensor

RETURNS DESCRIPTION
dict[str, Tensor]

Sampled values for each latent variable.

Source code in src/quivers/inference/guide.py
@abstractmethod
def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Sample latent variables from the guide.

    Parameters
    ----------
    x : torch.Tensor
        Program input. Shape (batch, ...).

    Returns
    -------
    dict[str, torch.Tensor]
        Sampled values for each latent variable.
    """
    ...

log_prob abstractmethod

log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor

Log-density of latent values under the guide.

PARAMETER DESCRIPTION
x

Program input. Shape (batch, ...).

TYPE: Tensor

sites

Values for each latent variable.

TYPE: dict[str, Tensor]

RETURNS DESCRIPTION
Tensor

Total log-density. Shape (batch,).
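rsample and log_prob are the two halves of a Monte Carlo ELBO estimate. The sketch below shows one way an SVI loop might combine them. It is an assumption, not quivers' SVI implementation. ToyGuide, model_log_joint, and elbo_estimate are hypothetical. The toy model is z ~ Normal(0, 1) and y ~ Normal(z, 1) for each batch row, and the estimator draws one sample per row.

```python
import torch
import torch.distributions as D
from torch import nn


class ToyGuide(nn.Module):
    """Hypothetical guide with one scalar latent "z" per batch row."""

    def __init__(self) -> None:
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(1))
        self.log_scale = nn.Parameter(torch.zeros(1))

    def _dist(self, batch: int) -> D.Normal:
        return D.Normal(self.loc.expand(batch), self.log_scale.exp().expand(batch))

    def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        return {"z": self._dist(x.shape[0]).rsample()}  # shape (batch,)

    def log_prob(self, x: torch.Tensor, sites: dict[str, torch.Tensor]) -> torch.Tensor:
        return self._dist(x.shape[0]).log_prob(sites["z"])  # shape (batch,)


def model_log_joint(x, y, sites):
    """Hypothetical model: z ~ Normal(0, 1), y ~ Normal(z, 1); log p(y, z | x) per row."""
    z = sites["z"]
    return D.Normal(0.0, 1.0).log_prob(z) + D.Normal(z, 1.0).log_prob(y)


def elbo_estimate(x, y, guide, log_joint):
    """Single-sample Monte Carlo ELBO, averaged over the batch."""
    sites = guide.rsample(x)  # z ~ q(z | x), differentiable in the guide parameters
    return (log_joint(x, y, sites) - guide.log_prob(x, sites)).mean()


torch.manual_seed(0)
guide = ToyGuide()
opt = torch.optim.Adam(guide.parameters(), lr=0.05)
x = torch.zeros(64, 1)        # program input; only its batch size matters here
y = torch.full((64,), 2.0)    # observed values

for _ in range(500):
    opt.zero_grad()
    loss = -elbo_estimate(x, y, guide, model_log_joint)
    loss.backward()
    opt.step()
```

The toy model is conjugate, so the exact posterior for each row is Normal(1, sqrt(0.5)) when every observation equals 2. The fitted loc and exp(log_scale) can be checked against those values.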

Source code in src/quivers/inference/guide.py
@abstractmethod
def log_prob(
    self,
    x: torch.Tensor,
    sites: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Log-density of latent values under the guide.

    Parameters
    ----------
    x : torch.Tensor
        Program input. Shape (batch, ...).
    sites : dict[str, torch.Tensor]
        Values for each latent variable.

    Returns
    -------
    torch.Tensor
        Total log-density. Shape (batch,).
    """
    ...

AutoNormalGuide

AutoNormalGuide(model: MonadicProgram, observed_names: set[str], init_scale: float = 0.1)

Bases: Guide

Mean-field Normal guide with learnable loc and scale per latent.

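Inspects the model's step specs to discover latent (non-observed) sites, skipping let-bindings. It registers two learnable parameters for each site. loc_<name> is initialized to zeros, and log_scale_<name> is initialized to log(init_scale). The scale used for sampling and scoring is exp(log_scale), clamped below at 1e-6.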
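The parameter layout described above can be reproduced on a bare nn.Module. This sketch assumes the latent dimensions are already known. In AutoNormalGuide they come from the model through _infer_dim, which this page does not show, so register_mean_field_params and the latent_dims dictionary are hypothetical. The sketch uses math.log(init_scale) where the listing uses torch.tensor(init_scale).log().item(). Both compute the same value.

```python
import math

import torch
from torch import nn


def register_mean_field_params(
    module: nn.Module, latent_dims: dict[str, int], init_scale: float = 0.1
) -> None:
    """Hypothetical helper mirroring AutoNormalGuide's parameter registration."""
    for name, dim in latent_dims.items():
        # loc starts at zero; log_scale starts at log(init_scale), so init_scale must be > 0.
        module.register_parameter(f"loc_{name}", nn.Parameter(torch.zeros(dim)))
        module.register_parameter(
            f"log_scale_{name}",
            nn.Parameter(torch.full((dim,), math.log(init_scale))),
        )


guide = nn.Module()
register_mean_field_params(guide, {"mu": 1, "w": 3}, init_scale=0.1)
names = sorted(n for n, _ in guide.named_parameters())

# The scale actually used by the guide, with the same floor as the listing.
scale_w = guide.log_scale_w.exp().clamp(min=1e-6)
```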

PARAMETER DESCRIPTION
model

The generative model to build a guide for.

TYPE: MonadicProgram

observed_names

Names of observed variables (excluded from the guide).

TYPE: set[str]

init_scale

Initial standard deviation for every latent site. It is stored as a logarithm, so it should be positive.

TYPE: float DEFAULT: 0.1

Source code in src/quivers/inference/guide.py
def __init__(
    self,
    model: MonadicProgram,
    observed_names: set[str],
    init_scale: float = 0.1,
) -> None:
    super().__init__()
    self._latent_names = []

    for spec in model._step_specs:
        if isinstance(spec, _LetSpec):
            continue

        for var in spec.vars:
            if var in observed_names:
                continue

            self._latent_names.append(var)

            # determine dimension from morphism codomain
            assert model._modules[spec.morphism_name] is not None
            morph = cast(ContinuousMorphism, model._modules[spec.morphism_name])
            dim = self._infer_dim(morph, len(spec.vars))

            # register learnable parameters
            self.register_parameter(
                f"loc_{var}",
                nn.Parameter(torch.zeros(dim)),
            )
            self.register_parameter(
                f"log_scale_{var}",
                nn.Parameter(
                    torch.full((dim,), torch.tensor(init_scale).log().item())
                ),
            )

latent_names property

latent_names: list[str]

Names of latent variables this guide covers.

rsample

rsample(x: Tensor) -> dict[str, Tensor]

Draw a reparameterized sample from the mean-field Normal for each latent.

PARAMETER DESCRIPTION
x

Program input. Only its first dimension (the batch size) is used.

TYPE: Tensor

RETURNS DESCRIPTION
dict[str, Tensor]

Sampled latent values, keyed by name. Each has shape (batch, dim), squeezed to (batch,) when dim is 1.
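The per-latent step of rsample can be isolated as a hypothetical helper, mean_field_rsample, that follows the listing line by line. D.Normal(...).rsample() draws eps ~ N(0, I) and returns loc + eps * scale. Any quantity computed from the sample can therefore be differentiated with respect to loc and log_scale. This is what lets SVI train the guide with ordinary backpropagation.

```python
import torch
import torch.distributions as D


def mean_field_rsample(loc: torch.Tensor, log_scale: torch.Tensor, batch: int) -> torch.Tensor:
    """Hypothetical helper: one latent's sampling step from AutoNormalGuide.rsample."""
    scale = log_scale.exp().clamp(min=1e-6)         # same floor as the listing
    loc_b = loc.unsqueeze(0).expand(batch, -1)      # (batch, dim)
    scale_b = scale.unsqueeze(0).expand(batch, -1)  # (batch, dim)
    z = D.Normal(loc_b, scale_b).rsample()          # loc + eps * scale, eps ~ N(0, I)
    # 1-d latents come back with shape (batch,)
    return z.squeeze(-1) if z.shape[-1] == 1 else z


torch.manual_seed(0)
loc = torch.zeros(1, requires_grad=True)
log_scale = torch.full((1,), -2.3, requires_grad=True)
z = mean_field_rsample(loc, log_scale, batch=4)
z.sum().backward()  # gradients reach both parameters through the reparameterization

z_multi = mean_field_rsample(torch.zeros(3), torch.zeros(3), batch=5)
```

Since each of the 4 rows equals loc + eps_i * scale, the gradient of z.sum() with respect to loc is exactly the batch size.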

Source code in src/quivers/inference/guide.py
def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Sample from mean-field Normal for each latent.

    Parameters
    ----------
    x : torch.Tensor
        Program input (used for batch size).

    Returns
    -------
    dict[str, torch.Tensor]
        Sampled latent values.
    """
    batch = x.shape[0]
    result = {}

    for name in self._latent_names:
        loc = getattr(self, f"loc_{name}")
        log_scale = getattr(self, f"log_scale_{name}")
        scale = log_scale.exp().clamp(min=1e-6)

        # expand to batch
        loc_batch = loc.unsqueeze(0).expand(batch, -1)
        scale_batch = scale.unsqueeze(0).expand(batch, -1)

        dist = D.Normal(loc_batch, scale_batch)
        sample = dist.rsample()

        # squeeze if 1-d
        if sample.shape[-1] == 1:
            sample = sample.squeeze(-1)

        result[name] = sample

    return result

log_prob

log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor

Log-density under the mean-field Normal guide.

PARAMETER DESCRIPTION
x

Program input. Its batch size and device set the shape and device of the result.

TYPE: Tensor

sites

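Latent variable values, keyed by name. Latents missing from sites are skipped. A 1-d value of shape (batch,) is treated as a latent with dim 1.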

TYPE: dict[str, Tensor]

RETURNS DESCRIPTION
Tensor

Total log-density. Shape (batch,).
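For each latent, log_prob sums the per-dimension Normal log-densities over the last axis. The result is the log-density of a Gaussian with diagonal covariance. The sketch checks this against two built-in formulations in torch.distributions: Independent(Normal, 1), and MultivariateNormal with a diagonal scale_tril. The random loc, scale, and values are arbitrary test inputs.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
batch, dim = 4, 3
loc = torch.randn(dim)
scale = torch.rand(dim) + 0.5  # strictly positive
val = torch.randn(batch, dim)

# What AutoNormalGuide.log_prob accumulates for one latent:
# per-dimension log-densities summed over the event axis.
summed = D.Normal(loc.expand(batch, -1), scale.expand(batch, -1)).log_prob(val).sum(dim=-1)

# Two equivalent descriptions of the same diagonal Gaussian.
independent = D.Independent(D.Normal(loc, scale), 1).log_prob(val)
mvn = D.MultivariateNormal(loc, scale_tril=torch.diag(scale)).log_prob(val)
```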

Source code in src/quivers/inference/guide.py
def log_prob(
    self,
    x: torch.Tensor,
    sites: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Log-density under the mean-field Normal guide.

    Parameters
    ----------
    x : torch.Tensor
        Program input (used for batch size).
    sites : dict[str, torch.Tensor]
        Latent variable values.

    Returns
    -------
    torch.Tensor
        Total log-density. Shape (batch,).
    """
    batch = x.shape[0]
    total = torch.zeros(batch, device=x.device)

    for name in self._latent_names:
        if name not in sites:
            continue

        loc = getattr(self, f"loc_{name}")
        log_scale = getattr(self, f"log_scale_{name}")
        scale = log_scale.exp().clamp(min=1e-6)

        val = sites[name]

        if val.dim() == 1:
            val = val.unsqueeze(-1)

        loc_batch = loc.unsqueeze(0).expand(batch, -1)
        scale_batch = scale.unsqueeze(0).expand(batch, -1)

        dist = D.Normal(loc_batch, scale_batch)
        # sum over event dims
        total = total + dist.log_prob(val).sum(dim=-1)

    return total

AutoDeltaGuide

AutoDeltaGuide(model: MonadicProgram, observed_names: set[str])

Bases: Guide

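Point-estimate (MAP) guide with one learnable value per latent. Each value is registered as value_<name> and initialized from a Normal distribution with mean 0 and standard deviation 0.1.

The guide is a delta (point mass) at the learned value, so log_prob returns 0 for every site. The ELBO then reduces to the model's log joint density evaluated at the point estimate. Maximizing it is therefore equivalent to MAP estimation.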
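The following sketch shows why a delta guide turns SVI into MAP estimation. It uses a hypothetical conjugate model: z ~ Normal(0, 1), with four observations y_i ~ Normal(z, 1). Because the guide's log-density is taken as 0, the ELBO equals the log joint at the learned point. Gradient ascent on it therefore finds the posterior mode, which for this model has the closed form sum(y) / (n + 1). The sketch uses plain SGD so the iteration is easy to follow. This is not necessarily the optimizer quivers' SVI uses.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
y = torch.tensor([1.5, 2.0, 2.5, 3.0])
value = torch.nn.Parameter(torch.randn(1) * 0.1)  # same init as value_<name>
opt = torch.optim.SGD([value], lr=0.1)

for _ in range(200):
    opt.zero_grad()
    z = value  # the delta guide's "sample" is the point itself
    log_joint = D.Normal(0.0, 1.0).log_prob(z).sum() + D.Normal(z, 1.0).log_prob(y).sum()
    elbo = log_joint - 0.0  # guide log-density is 0
    (-elbo).backward()
    opt.step()

# Posterior mode of this conjugate model.
map_closed_form = (y.sum() / (len(y) + 1)).item()
```

The gradient of -elbo here is 5z - 9, so each SGD step maps z to 0.5 z + 0.9. The iterate therefore contracts geometrically toward 1.8, the closed-form mode.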

PARAMETER DESCRIPTION
model

The generative model.

TYPE: MonadicProgram

observed_names

Names of observed variables.

TYPE: set[str]

Source code in src/quivers/inference/guide.py
def __init__(
    self,
    model: MonadicProgram,
    observed_names: set[str],
) -> None:
    super().__init__()
    self._latent_names = []

    for spec in model._step_specs:
        if isinstance(spec, _LetSpec):
            continue

        for var in spec.vars:
            if var in observed_names:
                continue

            self._latent_names.append(var)

            assert model._modules[spec.morphism_name] is not None
            morph = cast(ContinuousMorphism, model._modules[spec.morphism_name])
            dim = AutoNormalGuide._infer_dim(morph, len(spec.vars))

            self.register_parameter(
                f"value_{var}",
                nn.Parameter(torch.randn(dim) * 0.1),
            )

latent_names property

latent_names: list[str]

Names of latent variables this guide covers.

rsample

rsample(x: Tensor) -> dict[str, Tensor]

Return the learned point estimates.

PARAMETER DESCRIPTION
x

Program input. Only its first dimension (the batch size) is used.

TYPE: Tensor

RETURNS DESCRIPTION
dict[str, Tensor]

Point-estimate values, keyed by name. Each is expanded without copying to shape (batch, dim), or (batch,) when dim is 1.
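The values returned by rsample are expanded views of value_<name>, not copies. In the sketch below, value stands in for one such parameter with dim = 1, and the per-row weights are arbitrary. It shows two consequences. First, the result shares storage with the parameter. Second, backpropagation sums the gradient over all batch rows into that single parameter.

```python
import torch

value = torch.nn.Parameter(torch.tensor([0.5]))  # stands in for value_<name>, dim = 1
batch = 3

expanded = value.unsqueeze(0).expand(batch, -1)  # (3, 1) view, no copy
if expanded.shape[-1] == 1:
    expanded = expanded.squeeze(-1)              # 1-d latent -> shape (batch,)

# Every row is the same point, so its gradient is the sum of the per-row gradients.
loss = (expanded * torch.tensor([1.0, 2.0, 3.0])).sum()
loss.backward()
```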

Source code in src/quivers/inference/guide.py
def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Return the learned point estimates.

    Parameters
    ----------
    x : torch.Tensor
        Program input (used for batch size).

    Returns
    -------
    dict[str, torch.Tensor]
        Point-estimate values for each latent.
    """
    batch = x.shape[0]
    result = {}

    for name in self._latent_names:
        val = getattr(self, f"value_{name}")
        expanded = val.unsqueeze(0).expand(batch, -1)

        if expanded.shape[-1] == 1:
            expanded = expanded.squeeze(-1)

        result[name] = expanded

    return result

log_prob

log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor

Log-density under the delta guide (always zero).

PARAMETER DESCRIPTION
x

Program input.

TYPE: Tensor

sites

Latent variable values (ignored).

TYPE: dict[str, Tensor]

RETURNS DESCRIPTION
Tensor

Zeros. Shape (batch,).

Source code in src/quivers/inference/guide.py
def log_prob(
    self,
    x: torch.Tensor,
    sites: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Log-density under the delta guide (always zero).

    Parameters
    ----------
    x : torch.Tensor
        Program input.
    sites : dict[str, torch.Tensor]
        Latent variable values (ignored).

    Returns
    -------
    torch.Tensor
        Zeros. Shape (batch,).
    """
    return torch.zeros(x.shape[0], device=x.device)