ELBO

Evidence lower bound computation for variational inference.

elbo

Evidence lower bound (ELBO) computation.

The ELBO is the central objective for variational inference:

ELBO = E_q[log p(x, z) - log q(z)]

Maximizing the ELBO is equivalent to minimizing KL(q(z) || p(z | x)), since log p(x) = ELBO + KL(q(z) || p(z | x)) and log p(x) does not depend on q. This module computes the negative ELBO (a loss to minimize) using a Monte Carlo estimate of the expectation with one or more particles.
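With K particles, the expectation is estimated by averaging over samples drawn from the guide. In the same notation as above (this is the standard Monte Carlo estimator, written out for reference):

ELBO ≈ (1/K) * sum_{k=1..K} [log p(x, z_k) - log q(z_k)],   z_k ~ q(z)

Because the guide samples via rsample (the reparameterization trick, visible in the source below), this estimate stays differentiable with respect to the guide's parameters.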

ELBO

ELBO(num_particles: int = 1)

Bases: Module

Compute the negative ELBO (loss to minimize).

PARAMETER DESCRIPTION
num_particles

Number of Monte Carlo samples for estimating the expectation.

TYPE: int DEFAULT: 1

Source code in src/quivers/inference/elbo.py
def __init__(self, num_particles: int = 1) -> None:
    super().__init__()
    self.num_particles = num_particles
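
A minimal training-step sketch. The model, guide, and data loader here are hypothetical placeholders (make_model, make_guide, and loader are not part of the library); only the ELBO construction and call follow the documented interface:

import torch

loss_fn = ELBO(num_particles=4)

# hypothetical: any MonadicProgram / Guide pair and data source
model = make_model()
guide = make_guide()
optimizer = torch.optim.Adam(guide.parameters(), lr=1e-3)

for x, observations in loader:  # hypothetical loader yielding (Tensor, dict)
    optimizer.zero_grad()
    loss = loss_fn(model, guide, x, observations)  # scalar negative ELBO
    loss.backward()
    optimizer.step()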

forward

forward(model: MonadicProgram, guide: Guide, x: Tensor, observations: dict[str, Tensor]) -> Tensor

Compute negative ELBO.

PARAMETER DESCRIPTION
model

The generative model.

TYPE: MonadicProgram

guide

The variational guide.

TYPE: Guide

x

Program input. Shape (batch, ...).

TYPE: Tensor

observations

Observed variable values.

TYPE: dict[str, Tensor]

RETURNS DESCRIPTION
Tensor

Scalar negative ELBO (averaged over batch and particles).

Source code in src/quivers/inference/elbo.py
def forward(
    self,
    model: MonadicProgram,
    guide: Guide,
    x: torch.Tensor,
    observations: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Compute negative ELBO.

    Parameters
    ----------
    model : MonadicProgram
        The generative model.
    guide : Guide
        The variational guide.
    x : torch.Tensor
        Program input. Shape (batch, ...).
    observations : dict[str, torch.Tensor]
        Observed variable values.

    Returns
    -------
    torch.Tensor
        Scalar negative ELBO (averaged over batch and particles).
    """
    total = torch.tensor(0.0, device=x.device)

    for _ in range(self.num_particles):
        # sample latents from the guide
        latents = guide.rsample(x)

        # merge latents and observations for log-joint
        all_sites = {**latents, **observations}

        # model log-joint: log p(z, y_obs | x)
        model_lp = model.log_joint(x, all_sites)

        # guide log-prob: log q(z | x)
        guide_lp = guide.log_prob(x, latents)

        # elbo = E_q[log p - log q], loss = -elbo = E_q[log q - log p]
        total = total + (guide_lp - model_lp).mean()

    return total / self.num_particles
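
The loss body above assumes only two guide methods: rsample(x), returning a dict of reparameterized latent samples, and log_prob(x, latents), returning per-example log-densities of shape (batch,). A hypothetical mean-field Gaussian guide satisfying that interface, purely as a sketch (the class name, layer sizes, and single latent site "z" are illustrative, not part of the library):

import torch
import torch.nn as nn


class GaussianGuide(nn.Module):
    # Hypothetical sketch: diagonal-Gaussian posterior over one latent site "z".

    def __init__(self, in_dim: int, latent_dim: int) -> None:
        super().__init__()
        self.loc = nn.Linear(in_dim, latent_dim)
        self.log_scale = nn.Linear(in_dim, latent_dim)

    def _dist(self, x: torch.Tensor) -> torch.distributions.Normal:
        return torch.distributions.Normal(self.loc(x), self.log_scale(x).exp())

    def rsample(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        # rsample keeps the sample differentiable (reparameterization trick)
        return {"z": self._dist(x).rsample()}

    def log_prob(self, x: torch.Tensor, latents: dict[str, torch.Tensor]) -> torch.Tensor:
        # sum over latent dimensions -> one log-density per batch element
        return self._dist(x).log_prob(latents["z"]).sum(-1)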