Monadic Programs¶
What is a MonadicProgram?¶
A MonadicProgram is a probabilistic program specified as a sequence of draw and let steps. It defines a ContinuousMorphism from a domain to a codomain via monadic composition (Kleisli bind).
The program syntax mirrors that of probabilistic programming languages (PDS, Pyro):
program name : domain -> codomain
draw x₁ ~ morphism_1
draw x₂ ~ morphism_2(x₁)
let y = x₁ + x₂
observe z ~ morphism_3(y)
return y
Each draw step samples from a conditional distribution, binding the result. The observe keyword conditions the program on an external observation.
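The monadic composition behind consecutive draw steps can be sketched in plain Python. The helper below is a hypothetical illustration, not part of the quivers API: a morphism is modeled as a function from an input to a sample, and Kleisli bind threads the first draw into the second.

```python
import random

def kleisli_bind(f, g):
    """Compose two sampling morphisms: draw x1 ~ f(x), then draw x2 ~ g(x1)."""
    def composed(x):
        x1 = f(x)        # first draw
        return g(x1)     # second draw, conditioned on the first
    return composed

prior = lambda _: random.gauss(0.0, 1.0)       # x1 ~ Normal(0, 1)
likelihood = lambda x1: random.gauss(x1, 0.5)  # x2 ~ Normal(x1, 0.5)

model = kleisli_bind(prior, likelihood)
sample = model(None)  # one ancestral sample of x2
```

Chaining more steps is just repeated bind, which is why a program of draw steps defines a single morphism from domain to codomain.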
Program Structure¶
A program is an nn.Module that, when called, executes forward ancestral sampling:
from quivers.continuous.programs import MonadicProgram
from quivers.continuous.families import ConditionalNormal
from quivers.core.objects import FinSet
import torch
# Build program manually
program = MonadicProgram(
domain=FinSet("input", 5),
codomain=FinSet("output", 3),
)
# Register morphisms
f = ConditionalNormal(...)
program.add_morphism("f", f)
# Add steps
program.add_draw("x", "f", args=None)
program.add_draw("y", "f", args=("x",))
program.add_return("y")
# Forward pass: sampling
samples = program(torch.randn(5), n_samples=100) # shape (100, 3)
# Log joint: log p(output, latents | input)
input_data = torch.randn(5)   # domain value
output_data = torch.randn(3)  # codomain value
log_joint = program.log_joint(input_data, output_data)
Draw Steps¶
A draw step draw x ~ f or draw x ~ f(y, z) samples from a morphism, optionally conditioned on previous variables.
Single draw:
draw x ~ prior_f
Conditioned draw:
draw y ~ likelihood_f(x)
Multiple arguments are stacked along the feature dimension; a tuple on the left destructures the output:
draw (x, y) ~ joint_f(z, w)
The variable names on the left side are bound in the environment.
Let Steps¶
Deterministic binding:
let x = y + z
let weight = 0.5
Supports literals, variable references, and simple callable expressions.
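One way to picture a let step is as evaluating an expression against the current environment and binding the result under a new name. The helper below is a hypothetical sketch (not the quivers API), covering the three supported expression forms.

```python
def run_let(env, name, expr):
    """Bind a deterministic value: expr is a literal, a variable name, or a callable over env."""
    if callable(expr):
        env[name] = expr(env)                  # simple callable expression
    elif isinstance(expr, str) and expr in env:
        env[name] = env[expr]                  # variable reference
    else:
        env[name] = expr                       # literal
    return env

env = {"y": 1.0, "z": 2.0}
run_let(env, "x", lambda e: e["y"] + e["z"])   # let x = y + z
run_let(env, "weight", 0.5)                    # let weight = 0.5
# env["x"] == 3.0, env["weight"] == 0.5
```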
Observe Keyword¶
Condition the program on an observation:
observe y ~ likelihood(x)
This marks y as conditioned. During inference, observations clamp these variables to external values.
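The effect of observe during inference can be sketched as follows: the variable is clamped to the external value instead of being sampled, and its conditional log density is accumulated. The helper names here are illustrative, not the quivers API; the example assumes a unit-variance Normal likelihood.

```python
import math

def normal_logpdf(v, loc, scale):
    """Log density of Normal(loc, scale) at v."""
    return -0.5 * ((v - loc) / scale) ** 2 - math.log(scale * math.sqrt(2.0 * math.pi))

def run_observe(env, name, observed, loc_var):
    """Clamp name to the observed value and score it under Normal(env[loc_var], 1)."""
    env[name] = observed                       # clamp, do not sample
    return normal_logpdf(observed, env[loc_var], 1.0)

env = {"x": 0.0}
log_lik = run_observe(env, "y", 1.5, "x")      # observe y ~ Normal(x, 1)
```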
Return Statement¶
Specify the program output. Single or tuple:
return x
return (x, y, z)
The return value's shape determines the codomain.
Domains and Codomains¶
Domains can be:
- A single FinSet or ContinuousSpace
- A product of sets/spaces: X * Y * Z
- Named parameters: the domain is the product, but variables can refer to sub-components
Codomains are determined by the return statement shape.
rsample and log_joint¶
Two key operations:
rsample(domain_values, n_samples)¶
Generate samples by executing the program:
domain_val = torch.randn(5)
samples = program.rsample(domain_val, n_samples=1000)
# shape: (1000, codomain_dim)
Sampling is sequential and ancestral: each draw step samples in order, and earlier draws are available to all subsequent steps.
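The ancestral-sampling loop can be sketched in plain Python. Each step below is a tuple of (name, sampler, argument names), and earlier draws reach later steps through a shared environment; these helpers are illustrative, not the quivers API.

```python
import random

def ancestral_sample(steps, env=None):
    """Execute draw steps in order, binding each result into the environment."""
    env = {} if env is None else dict(env)
    for name, sampler, arg_names in steps:
        args = [env[a] for a in arg_names]   # earlier draws are in scope
        env[name] = sampler(*args)           # draw name ~ sampler(*args)
    return env

steps = [
    ("mu", lambda: random.gauss(0.0, 1.0), ()),        # draw mu ~ Normal(0, 1)
    ("y", lambda mu: random.gauss(mu, 0.5), ("mu",)),  # draw y ~ Normal(mu, 0.5)
]
env = ancestral_sample(steps)  # env holds both "mu" and "y"
```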
log_joint(domain_values, codomain_values)¶
Compute \(\log p(y, z_1, \ldots, z_k | x)\), where \(x\) is domain input, \(y\) is codomain (return value), and \(z_i\) are intermediate latent draws:
x = torch.randn(5)
y = torch.randn(3) # output
log_pjoint = program.log_joint(x, y)
# scalar or batch (depending on input shapes)
Useful for variational inference: log_joint enters the ELBO computation.
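As a worked example of the quantity log_joint computes, consider the two-step chain mu ~ Normal(0, 1), y ~ Normal(mu, 0.5). The log joint is the sum of the per-step conditional log densities, which is the term that enters the ELBO as E_q[log p(y, z) − log q(z)]. This is a standalone sketch, not the quivers implementation.

```python
import math

def normal_logpdf(v, loc, scale):
    """Log density of Normal(loc, scale) at v."""
    return -0.5 * ((v - loc) / scale) ** 2 - math.log(scale * math.sqrt(2.0 * math.pi))

def log_joint(mu, y):
    """log p(y, mu) = log p(mu) + log p(y | mu) for the chain above."""
    return normal_logpdf(mu, 0.0, 1.0) + normal_logpdf(y, mu, 0.5)

value = log_joint(0.2, 0.1)  # scalar log density for one (mu, y) pair
```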
Named Parameters¶
If the domain is a product, define sub-domains:
program = MonadicProgram(
domain=FinSet("A", 3) * FinSet("B", 4),
codomain=FinSet("Z", 5),
)
program.add_param("a", FinSet("A", 3))
program.add_param("b", FinSet("B", 4))
# Now steps can reference a, b by name
program.add_draw("x", "f", args=("a", "b"))
Example: A Simple Model¶
from quivers.continuous.programs import MonadicProgram
from quivers.continuous.families import (
ConditionalNormal,
ConditionalLogitNormal,
)
from quivers.core.objects import Unit
from quivers.continuous.spaces import Euclidean
import torch
import torch.nn as nn
# Build a simple model: priors on μ and σ with a Normal likelihood
program = MonadicProgram(
domain=Unit,
codomain=Euclidean(1),
)
# Prior on μ
f_mu = ConditionalNormal(Unit, Euclidean(1))
program.add_morphism("prior_mu", f_mu)
# Prior on σ
f_sigma = ConditionalLogitNormal(Unit, Euclidean(1))
program.add_morphism("prior_sigma", f_sigma)
# Likelihood
f_like = ConditionalNormal(Euclidean(2), Euclidean(1))
program.add_morphism("likelihood", f_like)
# Steps
program.add_draw("mu", "prior_mu")
program.add_draw("sigma", "prior_sigma")
program.add_draw("y", "likelihood", args=("mu", "sigma"))
program.add_return("y")
# Use for inference
optimizer = torch.optim.Adam(program.parameters())
Destructuring Draws¶
Extract multiple values from a tuple-returning sub-program:
program sub : X -> Y * Y
draw (a, b) ~ some_morphism
return (a, b)
program main : X -> Z
draw (u, v) ~ sub
draw w ~ g(u, v)
return w
The pattern (u, v) ~ sub destructures the output.
Observation Clamping¶
During inference, the condition() function clamps observations:
from quivers.inference import condition
# Condition program on external observations
observed_y = torch.tensor([1.0, -0.5, 2.0])
conditioned = condition(program, {"y": observed_y})
# Forward pass on conditioned program uses the clamped value
log_pjoint = conditioned.log_joint(x, observed_y)
Product Domains and Outputs¶
For multiple domain inputs, stack along the feature dimension:
program f : (X * Y) -> Z
draw (x_val, y_val) from domain input
draw z ~ g(x_val, y_val)
return z
Internally, the domain tensor is reshaped to match.
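The stacking convention can be illustrated with a small helper: a product domain X * Y travels as one flat feature vector, and destructuring slices it back into its components. This is a hypothetical sketch of the reshaping, not the quivers internals.

```python
def split_domain(vec, dims):
    """Split a flat feature vector into consecutive slices of the given sizes."""
    parts, start = [], 0
    for d in dims:
        parts.append(vec[start:start + d])
        start += d
    return parts

# A 2-dim X stacked with a 3-dim Y travels as one 5-dim domain vector.
x_val, y_val = split_domain([1.0, 2.0, 3.0, 4.0, 5.0], (2, 3))
# x_val == [1.0, 2.0], y_val == [3.0, 4.0, 5.0]
```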
Integration with DSL¶
MonadicPrograms are the output of .qvr DSL compilation (see DSL guide). The DSL parser translates:
object X : 3
object Y : 4
program my_prog : X -> Y
draw mu ~ LogitNormal(0, 1)
draw x ~ Normal(mu, 1)
return x
output my_prog
into a MonadicProgram instance that can be trained.