SVI

Stochastic variational inference algorithms.

svi

Stochastic Variational Inference (SVI) training loop.

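SVI optimizes a quivers.inference.objectives.Objective (the ELBO, IWAE, Rényi, or VR-IWAE bound) by taking gradient steps on the guide and model parameters. The objective parameter accepts any Objective subclass. On each step, SVI calls it as objective(model, guide, x, observations) and minimizes the scalar loss tensor it returns.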
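For reference, here are the textbook forms of two of these bounds. The notation is: $p_\theta$ is the model, $q_\phi$ is the guide, $z$ stands for the latent sites, and $y$ for the observed values. This sketch does not describe how quivers estimates the bounds; estimators and normalisations, such as averaging over the batch, may differ. step() minimizes whatever the objective returns, so a bound that should be maximized presumably enters as its negation.

```latex
\mathcal{L}_{\mathrm{ELBO}}(\theta, \phi)
  = \mathbb{E}_{q_\phi(z \mid y)}\!\left[\log p_\theta(y, z) - \log q_\phi(z \mid y)\right]
  \le \log p_\theta(y)

\mathcal{L}_{K}(\theta, \phi)
  = \mathbb{E}_{z_{1:K} \sim q_\phi(z \mid y)}\!\left[\log \frac{1}{K} \sum_{k=1}^{K}
      \frac{p_\theta(y, z_k)}{q_\phi(z_k \mid y)}\right]
  \le \log p_\theta(y)

\text{loss} = -\mathcal{L}
```

With $K = 1$, the importance-weighted (IWAE) bound reduces to the ELBO.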

SVI

SVI(model: MonadicProgram, guide: Guide, optim: Optimizer, objective: Objective)

Stochastic Variational Inference optimizer.

PARAMETER DESCRIPTION
model

Generative model.

TYPE: MonadicProgram

guide

Variational guide.

TYPE: Guide

optim

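Optimizer for both model and guide parameters. SVI only stores this optimizer, so it must already be constructed over every model and guide parameter that should be trained.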

TYPE: Optimizer

objective

Variational objective (ELBO, IWAEBound, RenyiBound, VRIWAEBound, …).

TYPE: Objective
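SVI.__init__ only stores the optimizer it receives, so the caller decides which parameters get trained. The sketch below makes a hypothetical assumption: that the model and guide expose their parameters through torch.nn.Module.parameters(). Plain nn.Linear layers stand in for them here. The sketch chains both parameter iterators into one optimizer and checks that every parameter is covered.

```python
import itertools

import torch
from torch import nn

# Hypothetical stand-ins for a MonadicProgram and a Guide; check how your
# actual objects register their learnable tensors.
model = nn.Linear(3, 2)  # 6 weights + 2 biases
guide = nn.Linear(3, 4)  # 12 weights + 4 biases

# One optimizer over both parameter sets, matching "optimizer for both
# model and guide parameters".
optim = torch.optim.Adam(
    itertools.chain(model.parameters(), guide.parameters()),
    lr=1e-2,
)

n_model = sum(p.numel() for p in model.parameters())
n_guide = sum(p.numel() for p in guide.parameters())
n_optim = sum(p.numel() for group in optim.param_groups for p in group["params"])
print(n_model, n_guide, n_optim)  # 8 16 24
```

To keep the model fixed and train only the guide, pass guide.parameters() alone.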

Source code in src/quivers/inference/svi.py
def __init__(
    self,
    model: MonadicProgram,
    guide: Guide,
    optim: torch.optim.Optimizer,
    objective: Objective,
) -> None:
    self.model = model
    self.guide = guide
    self.optim = optim
    self.objective = objective

step

step(x: Tensor, observations: dict[str, Tensor]) -> float

One SVI step.

PARAMETER DESCRIPTION
x

Program input. Shape (batch, ...).

TYPE: Tensor

observations

Observed values for sample sites, plus any additional host data. Keys that do not name a sample site are exposed to the trace through the condition machinery.

TYPE: dict[str, Tensor]

RETURNS DESCRIPTION
float

Scalar loss for this step as a Python float, evaluated before the optimizer update.

Source code in src/quivers/inference/svi.py
def step(
    self,
    x: torch.Tensor,
    observations: dict[str, torch.Tensor],
) -> float:
    """One SVI step.

    Parameters
    ----------
    x : torch.Tensor
        Program input. Shape ``(batch, ...)``.
    observations : dict[str, torch.Tensor]
        Observed variable values + host data (the
        non-site keys are exposed to the trace via the
        ``condition`` machinery).

    Returns
    -------
    float
        Scalar loss for this step.
    """
    self.optim.zero_grad()
    loss_val = self.objective(self.model, self.guide, x, observations)
    loss_val.backward()
    self.optim.step()
    return loss_val.item()
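The step() body above follows the usual PyTorch update pattern: zero the gradients, evaluate the objective, backpropagate, and update. Below is a minimal, self-contained sketch of a training loop around that pattern. It does not import quivers. MiniSVI is a hypothetical stand-in that copies the step() body. ToyModel, ToyGuide, toy_neg_elbo, and the observed-site name "y" are illustrative assumptions, not quivers classes. The toy objective is a single-sample reparameterized negative ELBO, so minimizing it maximizes the bound.

```python
import torch
from torch import nn
from torch.distributions import Normal


class ToyModel(nn.Module):
    """Hypothetical model: z ~ N(0, 1), y | x, z ~ N(a * x + z, 1)."""

    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Parameter(torch.tensor(0.0))


class ToyGuide(nn.Module):
    """Hypothetical mean-field guide: q(z) = N(m, exp(log_s) ** 2)."""

    def __init__(self) -> None:
        super().__init__()
        self.m = nn.Parameter(torch.tensor(0.0))
        self.log_s = nn.Parameter(torch.tensor(0.0))


def toy_neg_elbo(model, guide, x, observations):
    """One-sample reparameterized negative ELBO, averaged over the batch."""
    y = observations["y"]  # hypothetical observed-site name
    q = Normal(guide.m, guide.log_s.exp())
    z = q.rsample(x.shape)  # one latent per batch element; gradients reach the guide
    log_joint = Normal(0.0, 1.0).log_prob(z) + Normal(model.a * x + z, 1.0).log_prob(y)
    return -(log_joint - q.log_prob(z)).mean()


class MiniSVI:
    """Stand-in with the same constructor arguments and step() body as SVI."""

    def __init__(self, model, guide, optim, objective):
        self.model = model
        self.guide = guide
        self.optim = optim
        self.objective = objective

    def step(self, x, observations):
        self.optim.zero_grad()
        loss_val = self.objective(self.model, self.guide, x, observations)
        loss_val.backward()
        self.optim.step()
        return loss_val.item()


torch.manual_seed(0)
x = torch.randn(256)  # program input, shape (batch,)
y = 3.0 * x + torch.randn(256)  # synthetic observations for the sketch

model, guide = ToyModel(), ToyGuide()
optim = torch.optim.Adam(list(model.parameters()) + list(guide.parameters()), lr=0.05)
svi = MiniSVI(model, guide, optim, toy_neg_elbo)

losses = [svi.step(x, {"y": y}) for _ in range(400)]
print(f"first loss {losses[0]:.2f}, last loss {losses[-1]:.2f}, a = {model.a.item():.2f}")
```

SVI leaves the choice of bound entirely to the objective. Switching to a different bound therefore only changes the objective argument, and the loop stays the same.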