Monadic Programs

What is a MonadicProgram?

A MonadicProgram is a probabilistic program specified as a sequence of bind and let steps. It defines a ContinuousMorphism from a domain to a codomain via monadic composition (Kleisli bind).
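The sketch below makes Kleisli bind concrete in the simplest setting: finite distributions stored as Python dicts. It does not use the library. The names bind, kleisli, biased_coin, and noisy_copy are hypothetical and only illustrate the composition a MonadicProgram performs. A real program works with continuous kernels, not dicts.

```python
from collections import defaultdict
from typing import Any, Callable, Dict

Dist = Dict[Any, float]             # finite distribution: outcome -> probability
Kernel = Callable[[Any], Dist]      # Kleisli arrow: input -> distribution


def bind(dist: Dist, kernel: Kernel) -> Dist:
    """Monadic bind: push every outcome of dist through kernel and mix the results."""
    out: Dict[Any, float] = defaultdict(float)
    for a, p_a in dist.items():
        for b, p_b in kernel(a).items():
            out[b] += p_a * p_b
    return dict(out)


def kleisli(f: Kernel, g: Kernel) -> Kernel:
    """Kleisli composition: run f, then bind its outcome into g."""
    return lambda x: bind(f(x), g)


def biased_coin(_: Any) -> Dist:
    # Ignores its input: a "prior" kernel from a one-point domain.
    return {0: 0.2, 1: 0.8}


def noisy_copy(a: int) -> Dist:
    # Copies the bit with probability 0.9, flips it with probability 0.1.
    return {a: 0.9, 1 - a: 0.1}


composed = kleisli(biased_coin, noisy_copy)
print(composed(None))
```

The composed kernel gives P(0) = 0.2·0.9 + 0.8·0.1 = 0.26, up to floating-point rounding.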

The program syntax mirrors that of probabilistic programming languages such as Pyro, NumPyro, and Stan:

program name : domain -> codomain
    x_1 <- morphism_1
    x_2 <- morphism_2(x_1)
    let y = x_1 + x_2
    observe z <- morphism_3(y)
    return y

Each <- bind samples from a conditional distribution and binds the draw to the name on its left. The observe keyword instead conditions the program on an external observation.
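Here is a toy interpreter for the schematic program above, showing how the step kinds interact. It is a minimal, stdlib-only sketch with three assumptions:

- Every morphism is a unit-scale Normal whose mean is computed from earlier variables.
- Steps are plain tuples.
- run, normal_logpdf, and the tuple encoding are hypothetical names, not quivers API.

```python
import math
import random
from typing import Any, Dict, List, Tuple


def normal_logpdf(v: float, mean: float, std: float) -> float:
    """log N(v; mean, std^2)."""
    return -0.5 * ((v - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def run(steps: List[Tuple[Any, ...]], observations: Dict[str, float], rng: random.Random):
    """Execute steps in order; return (returned values, accumulated log density)."""
    env: Dict[str, float] = {}
    log_density = 0.0
    for step in steps:
        kind = step[0]
        if kind == "bind":          # ("bind", name, mean_fn, std): sample, then bind
            _, name, mean_fn, std = step
            mean = mean_fn(env)
            env[name] = rng.gauss(mean, std)
            log_density += normal_logpdf(env[name], mean, std)
        elif kind == "let":         # ("let", name, fn): deterministic, no density term
            _, name, fn = step
            env[name] = fn(env)
        elif kind == "observe":     # ("observe", name, mean_fn, std): clamp to data
            _, name, mean_fn, std = step
            env[name] = observations[name]
            log_density += normal_logpdf(env[name], mean_fn(env), std)
        elif kind == "return":      # ("return", names)
            return tuple(env[n] for n in step[1]), log_density
    raise ValueError("program has no return step")


# The schematic program from the text, with every morphism a unit-scale Normal.
program = [
    ("bind", "x1", lambda env: 0.0, 1.0),              # x_1 <- morphism_1
    ("bind", "x2", lambda env: env["x1"], 1.0),        # x_2 <- morphism_2(x_1)
    ("let", "y", lambda env: env["x1"] + env["x2"]),   # let y = x_1 + x_2
    ("observe", "z", lambda env: env["y"], 1.0),       # observe z <- morphism_3(y)
    ("return", ("y",)),                                # return y
]

(y,), log_density = run(program, observations={"z": 0.3}, rng=random.Random(0))
```

Binds add a log-density term for a fresh draw. Observes add one for a clamped value. Let steps add nothing.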

Program structure

A program is an nn.Module that, when called, executes forward ancestral sampling:

from quivers.continuous.programs import MonadicProgram
from quivers.continuous.families import ConditionalNormal
from quivers.continuous.spaces import Euclidean
from quivers.core.objects import FinSet
import torch

Unit = FinSet(name="Unit", cardinality=1)
R1 = Euclidean(name="R1", dim=1)

prior      = ConditionalNormal(Unit, R1)            # x ~ Normal(0, 1)
likelihood = ConditionalNormal(R1, R1)              # y ~ Normal(x, 1)

program = MonadicProgram(
    Unit,
    R1,
    steps=[
        (("x",), prior, None),                      # x <- prior
        (("y",), likelihood, ("x",)),               # y <- likelihood(x)
    ],
    return_vars=("y",),
)

# Forward pass: sampling
samples = program.rsample(
    torch.zeros(4, dtype=torch.long), sample_shape=torch.Size([100])
)  # shape (100, 4, 1)

# Log joint: sum_i log p(z_i | pa(z_i)) given every bound variable
x_val = torch.randn(4, 1)
y_val = torch.randn(4, 1)
log_joint = program.log_joint(
    torch.zeros(4, dtype=torch.long),
    {"x": x_val, "y": y_val},
)
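Suppose the two kernels sat at the parameter values in the comments: prior Normal(0, 1) and likelihood Normal(x, 1). In the real program those parameters are learnable, so this is an assumption. Under it, the per-row log joint has a closed form, sketched below in plain Python. The helper names are hypothetical.

```python
import math


def normal_logpdf(v: float, mean: float, std: float) -> float:
    """log N(v; mean, std^2) for scalars."""
    return -0.5 * ((v - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def log_joint_row(x: float, y: float) -> float:
    """log p(x) + log p(y | x) for x ~ Normal(0, 1) and y ~ Normal(x, 1)."""
    return normal_logpdf(x, 0.0, 1.0) + normal_logpdf(y, x, 1.0)


# One scalar per batch row, mirroring the (4, 1) tensors above (dim=1).
x_rows = [0.0, 1.0, -0.5, 2.0]
y_rows = [0.0, 1.0, 0.5, 1.0]
per_row = [log_joint_row(x, y) for x, y in zip(x_rows, y_rows)]
print(per_row)
```

At x = y = 0 the value is exactly -log(2π), because both squared terms vanish.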

Bind steps

A bind step x <- f or x <- f(y, z) samples from a morphism, optionally conditioned on previous variables.

Single bind:

x <- prior_f

Conditioned bind:

y <- likelihood_f(x)

Destructuring tuple bind (components stacked along the feature dimension):

(x, y) <- joint_f(z, w)

The variable names on the left side are bound in the environment. An indexed bind v : A <- F(args) declares v as an \(A\)-indexed plate of independent draws.
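Our reading of the plate semantics is that an indexed bind draws one value per element of A. Because the draws are independent, its log density is the sum over them. The sketch below shows that reading with a fixed Normal kernel. It is not the library code, and plate_bind is a hypothetical name.

```python
import math
import random
from typing import Dict, Hashable, Sequence, Tuple


def normal_logpdf(v: float, mean: float, std: float) -> float:
    return -0.5 * ((v - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def plate_bind(index_set: Sequence[Hashable], mean: float, std: float,
               rng: random.Random) -> Tuple[Dict[Hashable, float], float]:
    """Draw one independent Normal(mean, std) value per element of index_set.

    Because the draws are independent, the plate's log density is the sum
    of the per-element log densities.
    """
    draws = {a: rng.gauss(mean, std) for a in index_set}
    log_density = sum(normal_logpdf(v, mean, std) for v in draws.values())
    return draws, log_density


A = ["a0", "a1", "a2"]   # stands in for a 3-element FinSet
v, lp = plate_bind(A, mean=0.0, std=1.0, rng=random.Random(0))
```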

Let steps

Deterministic binding:

let x = y + z
let weight = 0.5

Let steps support literals, variable references, arithmetic, and the primitives of the let-expression surface; see DSL Programs and Let-Expressions for the full primitive list.

Observe keyword

Condition the program on an observation:

observe y <- likelihood(x)

This marks y as observed: during inference, it is clamped to an externally supplied value instead of being sampled. An indexed observe, observe r : N <- F(args), accumulates a batched likelihood over the index set N, with the response buffer supplied via the runtime observations dict.
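Below is a sketch of the batched likelihood an indexed observe accumulates, under stated assumptions:

- The kernel is a Normal with a shared scale.
- The per-element means are already computed from the step's arguments.
- The response buffer is a plain list keyed by variable name.

indexed_observe is a hypothetical helper, not quivers API.

```python
import math
from typing import Dict, List, Sequence


def normal_logpdf(v: float, mean: float, std: float) -> float:
    return -0.5 * ((v - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def indexed_observe(responses: Sequence[float], means: Sequence[float], std: float) -> float:
    """Summed log likelihood of one observed response per element of N.

    responses: the runtime buffer for the observed variable (one entry per n).
    means:     the kernel's prediction for each n, computed from its arguments.
    """
    if len(responses) != len(means):
        raise ValueError("need exactly one response per element of the index set")
    return sum(normal_logpdf(r, m, std) for r, m in zip(responses, means))


observations: Dict[str, List[float]] = {"r": [0.2, -1.0, 0.7]}   # keyed by name
ll = indexed_observe(observations["r"], means=[0.0, -0.5, 1.0], std=1.0)
```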

Return statement

The return statement specifies the program output, either a single variable or a tuple:

return x
return (x, y, z)

The return value's shape determines the codomain. Tuples are purely positional: the components of the resulting product space are ordered by tuple position.

Domains and codomains

Domains can be:

  • A single FinSet or ContinuousSpace.
  • A product of sets / spaces: X * Y * Z.
  • Named parameters: the domain is still the product, but steps can refer to its components by name.

Codomains are determined by the return statement shape.

rsample and log_joint

Two key operations on a compiled program:

rsample(x, sample_shape=(), observations=None)

Generate samples by executing the program:

x = torch.randn(8, 5)  # a batch of 8 inputs from a 5-dimensional domain
samples = program.rsample(x, sample_shape=torch.Size([1000]))
# shape: (1000, 8, codomain_dim)

Sampling is sequential and ancestral: steps run in program order, each draw step samples once, and its result is available to every subsequent step. The optional observations argument, a dict[str, torch.Tensor], clamps observed sites to runtime data.
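The first example suggests a shape convention: sample_shape=torch.Size([100]) with a batch of 4 inputs gave (100, 4, 1). Assuming that convention holds, the sample dimensions come before the batch and feature dimensions. The sketch below reproduces that layout with nested lists for a hypothetical two-step program. The function name is illustrative.

```python
import random
from typing import List, Sequence


def rsample_nested(inputs: Sequence[float], num_samples: int,
                   rng: random.Random) -> List[List[List[float]]]:
    """Repeat ancestral sampling num_samples times for every batch input.

    Toy program per input u: x ~ Normal(u, 1); y ~ Normal(x, 1); return [y].
    Output is nested [sample][batch][feature], i.e. shape (S, B, 1).
    """
    samples = []
    for _ in range(num_samples):
        row = []
        for u in inputs:
            x = rng.gauss(u, 1.0)   # first draw
            y = rng.gauss(x, 1.0)   # the second draw conditions on the first
            row.append([y])
        samples.append(row)
    return samples


out = rsample_nested([0.0, 0.0, 1.0, 2.0], num_samples=100, rng=random.Random(0))
shape = (len(out), len(out[0]), len(out[0][0]))
print(shape)  # (100, 4, 1)
```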

log_joint(x, intermediates)

Compute \(\log p(z_1, \ldots, z_k \mid x) = \sum_i \log p(z_i \mid \mathrm{pa}(z_i))\) given all bound-variable values:

x = torch.randn(8, 5)
# z_value, y_value: tensors holding the values of every bound variable
intermediates = {"z": z_value, "y": y_value}

log_p_joint = program.log_joint(x, intermediates)

log_joint is the core scoring kernel: it sums the log densities of the program's draw, plate-draw, and observe steps. ELBO.forward uses it after the guide has sampled the latents.
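Here is how a log joint feeds an ELBO on the Normal-Normal model from the first example. This stdlib-only sketch assumes fixed unit-scale kernels and a Normal guide q(x). Its log_joint and elbo are our own hypothetical functions, not the quivers ones. The estimator averages log p(x, y) - log q(x) over guide draws, which is the standard Monte Carlo ELBO.

```python
import math
import random


def normal_logpdf(v: float, mean: float, std: float) -> float:
    return -0.5 * ((v - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def log_joint(x: float, y: float) -> float:
    """Model: x ~ Normal(0, 1), y ~ Normal(x, 1)."""
    return normal_logpdf(x, 0.0, 1.0) + normal_logpdf(y, x, 1.0)


def elbo(y: float, guide_mean: float, guide_std: float,
         num_samples: int, rng: random.Random) -> float:
    """Monte Carlo ELBO: average of log p(x, y) - log q(x) over guide draws."""
    total = 0.0
    for _ in range(num_samples):
        x = rng.gauss(guide_mean, guide_std)   # the guide samples the latent
        total += log_joint(x, y) - normal_logpdf(x, guide_mean, guide_std)
    return total / num_samples


y_obs = 1.2
log_evidence = normal_logpdf(y_obs, 0.0, math.sqrt(2.0))   # y ~ Normal(0, 2) marginally
# Exact posterior guide Normal(y/2, sqrt(1/2)): every draw yields log p(y).
exact = elbo(y_obs, y_obs / 2, math.sqrt(0.5), num_samples=10, rng=random.Random(0))
# Prior as guide: the ELBO falls short of log p(y) by KL(q || posterior).
loose = elbo(y_obs, 0.0, 1.0, num_samples=2000, rng=random.Random(0))
```

With the exact posterior as guide, log p(x, y) - log q(x) equals log p(y) for every draw. That makes the estimate exact and is a convenient sanity check.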

The observations dict

Indexed-observe steps (observe r : N <- F(args)) read their response buffers from a runtime observations: dict[str, torch.Tensor], keyed by the observed-variable name. The dict is passed as the observations kwarg to MonadicProgram.rsample and as the final positional argument to ELBO.forward / SVI.step:

observations = {
    "cloze_resp": cloze_tensor,    # shape (n_cloze_resp,)
    "prop_resp":  prop_tensor,     # shape (n_prop_resp,)
}

samples = program.rsample(x, observations=observations)
loss = elbo(model, guide, x, observations)

There is no data block at the .qvr level: the tensors are supplied in Python at the call site, and the dict keys must match the response identifiers declared in the program body.
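A key mismatch is easy to make at the call site, so a small pre-flight check can help. The helper check_observations is hypothetical and not part of quivers. It compares names only and assumes you know the observe-site names declared in the program body.

```python
from typing import Dict, Iterable, Sequence


def check_observations(declared: Iterable[str],
                       observations: Dict[str, Sequence[float]]) -> None:
    """Raise KeyError if the runtime dict and the declared observe sites differ."""
    declared_set = set(declared)
    provided = set(observations)
    missing = declared_set - provided
    unexpected = provided - declared_set
    if missing:
        raise KeyError(f"no data for observed variables: {sorted(missing)}")
    if unexpected:
        raise KeyError(f"data for undeclared variables: {sorted(unexpected)}")


# Observe-site names as they would appear in the program body.
declared_sites = ["cloze_resp", "prop_resp"]
observations = {"cloze_resp": [1.0, 0.0, 1.0], "prop_resp": [0.25, 0.8]}
check_observations(declared_sites, observations)   # passes silently
```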

Named parameters

If the domain is a product, name the components via the params argument so steps can reference them by name:

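from quivers.continuous.programs import MonadicProgram
from quivers.core.objects import FinSet

A = FinSet(name="A", cardinality=3)
B = FinSet(name="B", cardinality=4)
Z = FinSet(name="Z", cardinality=5)

# f: a morphism from A * B to Z, assumed to be defined elsewhere
program = MonadicProgram(
    A * B,
    Z,
    steps=[
        (("x",), f, ("a", "b")),                    # x <- f(a, b)
    ],
    return_vars=("x",),
    params=("a", "b"),
)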

The program splits the product input along the feature axis at runtime and binds each slice to the corresponding name in params.
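The splitting rule can be sketched on a single row, assuming each component occupies a known, contiguous width on the feature axis. The helper names and widths below are illustrative only. This sketch does not say how quivers encodes FinSet components on that axis. Read in reverse, the same positional rule lines up the components of a tuple return or a destructuring bind.

```python
from typing import Dict, List, Sequence


def split_features(row: Sequence[float], params: Sequence[str],
                   widths: Sequence[int]) -> Dict[str, List[float]]:
    """Split one stacked feature vector into contiguous slices named by params."""
    if len(row) != sum(widths):
        raise ValueError("row width must equal the sum of the component widths")
    env: Dict[str, List[float]] = {}
    start = 0
    for name, width in zip(params, widths):
        env[name] = list(row[start:start + width])
        start += width
    return env


def stack_features(env: Dict[str, List[float]], names: Sequence[str]) -> List[float]:
    """Concatenate named slices in positional order (the inverse of split_features)."""
    return [value for name in names for value in env[name]]


env = split_features([2.0, 3.0, 0.5], params=("a", "b"), widths=(1, 2))
print(env)  # {'a': [2.0], 'b': [3.0, 0.5]}
```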

Example: a simple model

import torch

from quivers.continuous.programs import MonadicProgram
from quivers.continuous.families import (
    ConditionalNormal,
    ConditionalLogitNormal,
)
from quivers.continuous.spaces import Euclidean
from quivers.core.objects import FinSet

Unit = FinSet(name="Unit", cardinality=1)
R1 = Euclidean(name="R1", dim=1)
R2 = Euclidean(name="R2", dim=2)

prior_mu    = ConditionalNormal(Unit, R1)
prior_sigma = ConditionalLogitNormal(Unit, R1)
likelihood  = ConditionalNormal(R2, R1)

program = MonadicProgram(
    Unit,
    R1,
    steps=[
        (("mu",),    prior_mu,    None),
        (("sigma",), prior_sigma, None),
        (("y",),     likelihood,  ("mu", "sigma")),
    ],
    return_vars=("y",),
)

# Train the kernels' parameters with any torch optimizer
optimizer = torch.optim.Adam(program.parameters())

Destructuring binds

Extract multiple values from a tuple-returning sub-program:

program sub : X -> Y * Y
    (a, b) <- some_morphism
    return (a, b)

program main : X -> Z
    (u, v) <- sub
    w <- g(u, v)
    return w

The pattern (u, v) <- sub destructures the output.

Observation clamping

During inference, the condition() function clamps observations:

from quivers.inference import condition

# Condition program on external observations
observed_y = torch.tensor([1.0, -0.5, 2.0])

conditioned = condition(program, {"y": observed_y})

# Trace under the conditioning: observed sites are clamped to the data
tr = conditioned.trace(x)

Product domains and outputs

For a product domain, the component inputs are stacked along the feature dimension:

program f(x_val, y_val) : (X * Y) -> Z
    z <- g(x_val, y_val)
    return z

The bare-identifier parameters x_val and y_val name the projections of the product domain. Internally, the domain tensor is split along the feature axis and each slice is bound to its parameter name.

Integration with the DSL

MonadicPrograms are the output of .qvr DSL compilation. The DSL parser translates:

object X : FinSet 3
object Y : FinSet 4
program my_prog : X -> Y
    sample mu <- LogitNormal(0, 1)
    sample x <- Normal(mu, 1)
    return x

export my_prog

into a MonadicProgram instance that can be trained. The full DSL surface for programs lives in DSL Programs and Let-Expressions.

Where to next

  • Hierarchical Programs: parametric templates for crossed random intercepts, monotone-spline coefficients, and the grouped marginalization construct for fibred discrete latents.
  • Variational Inference: how programs feed into the variational training loop.