Tutorial 5: Variational Inference
In this tutorial, you will fit a probabilistic program to observed data using variational inference. You will set up a model, condition it on observations, create a variational guide, define an ELBO loss, run a training loop with SVI, and use Predictive for posterior sampling.
The pipeline:
```mermaid
sequenceDiagram
    participant User
    participant trace as trace(model, ...)
    participant condition as condition(model, obs)
    participant guide as Guide
    participant svi as SVI
    participant pred as Predictive
    User->>trace: model + inputs
    trace-->>User: Trace (values, log-densities per site)
    User->>condition: observations dict
    condition-->>User: conditioned model
    User->>guide: AutoNormalGuide(model, observed_names)
    loop fit
        User->>svi: svi.step(x, observations)
        svi->>guide: rsample
        guide-->>svi: posterior draw
        svi-->>User: ELBO loss
    end
    User->>pred: Predictive(model, posterior=guide)
    pred-->>User: posterior + predictive draws
```
Concepts
- Trace: Record of a program execution with sample sites and log-densities
- Conditioning: Fixing observed variables to particular values
- Guide: A variational approximation to the posterior distribution
- ELBO: Evidence Lower Bound, a lower bound on the log marginal likelihood of the observations; its negative is the loss minimized during variational inference
- SVI: Stochastic Variational Inference, a gradient-based optimization algorithm
- Predictive: Posterior sampling using the fitted guide
Setup
```python
import torch
import torch.optim as optim

from quivers.core.objects import FinSet
from quivers.continuous.spaces import Euclidean
from quivers.continuous.families import ConditionalNormal
from quivers.continuous.programs import MonadicProgram
from quivers.inference import (
    trace,
    condition,
    AutoNormalGuide,
    ELBO,
    SVI,
    Predictive,
)
```
Building a Model
Create a simple generative model: latent variable z drives observation y.
```python
Unit = FinSet(name="Unit", cardinality=1)  # trivial one-element input
R = Euclidean(name="real", dim=1)          # one-dimensional real line

prior = ConditionalNormal(Unit, R)         # p(z)
likelihood = ConditionalNormal(R, R)       # p(y | z)

model = MonadicProgram(
    Unit, R,
    steps=[
        (("z",), prior, None),          # z <- prior(unit)
        (("y",), likelihood, ("z",)),   # y <- likelihood(z)
    ],
    return_vars=("z", "y"),
)
```
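Running the program is ancestral sampling: first draw z from the prior, then draw y given z. The plain-Python sketch below follows the same order. It is a minimal illustration under assumed parameters: prior is taken as a fixed N(prior_loc, prior_scale²), and likelihood as a normal with an affine mean weight·z + bias and a fixed noise_scale. The names sample_program, prior_loc, prior_scale, weight, bias, and noise_scale are hypothetical, not quivers API, and quivers' learned ConditionalNormal parameterization may differ.

```python
import random


def sample_program(n, rng, prior_loc=0.0, prior_scale=1.0,
                   weight=1.0, bias=0.0, noise_scale=0.5):
    """Ancestral sampling for a two-step program z <- prior, y <- likelihood(z).

    Hypothetical stand-in for MonadicProgram.rsample: returns a dict keyed
    by variable name, with one draw per input element.
    """
    zs, ys = [], []
    for _ in range(n):
        z = rng.gauss(prior_loc, prior_scale)          # z <- prior(unit)
        y = rng.gauss(weight * z + bias, noise_scale)  # y <- likelihood(z)
        zs.append(z)
        ys.append(y)
    return {"z": zs, "y": ys}


rng = random.Random(42)
draws = sample_program(50, rng)
print(len(draws["z"]), len(draws["y"]))  # 50 50
```

Each y depends only on its own z. That dependency is what the step list above encodes with ("z",) as the likelihood's input.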
Simulating Observed Data
Generate synthetic observations from the model (as if from an experiment):
```python
torch.manual_seed(42)

# Draw (z, y) jointly from the generative model, one draw per input
batch = torch.zeros(50, dtype=torch.long)  # 50 copies of the single Unit element
samples = model.rsample(batch)

# Tuple returns come back as a dict keyed by variable name.
# Detach so the simulated observations don't carry a grad history
# back into the model parameters during SVI.
y_observed = samples["y"].detach()
print(y_observed.shape)  # [50, 1]
print(y_observed.mean(), y_observed.std())
```
In practice, these observations come from real data. Here we simulate for illustration.
Tracing the Model
A trace records the values and log-densities at each site (random variable):
```python
# Trace the model at a single input point
x_single = torch.zeros(1, dtype=torch.long)
tr = trace(model, x_single)

# Inspect sites
print("Sites:", tr.sites.keys())  # site names: 'z' and 'y'

# Access a site
z_site = tr.sites["z"]
print("z value:", z_site.value)
print("z log_prob:", z_site.log_prob)
```
Each site has:
- value: the sampled value
- log_prob: the log-probability of the value under the site's distribution
- is_observed: whether this site is conditioned (fixed)
- is_deterministic: whether it is a let binding
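The plain-Python sketch below shows what a trace of this two-site program records. It assumes a standard-normal prior and a unit-scale normal likelihood centred on z. Site, trace_two_site, and normal_log_prob are hypothetical names, not quivers classes. Passing a value for y in observed mimics conditioning: the site keeps the given value and is flagged as observed. The joint log-density of one execution is the sum of the per-site log-probabilities.

```python
import math
import random


class Site:
    """Hypothetical record of one random variable in an execution."""

    def __init__(self, value, log_prob, is_observed=False, is_deterministic=False):
        self.value = value
        self.log_prob = log_prob
        self.is_observed = is_observed
        self.is_deterministic = is_deterministic


def normal_log_prob(x, loc, scale):
    """Log-density of N(loc, scale^2) at x."""
    return (-0.5 * ((x - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2 * math.pi))


def trace_two_site(rng, observed=None):
    """Run z ~ N(0, 1), y ~ N(z, 1) once and record every site.

    `observed` maps site names to fixed values (a stand-in for condition).
    """
    observed = observed or {}
    sites = {}

    z = rng.gauss(0.0, 1.0)
    sites["z"] = Site(z, normal_log_prob(z, 0.0, 1.0))

    if "y" in observed:
        y, is_obs = observed["y"], True       # fixed, not sampled
    else:
        y, is_obs = rng.gauss(z, 1.0), False  # sampled given z
    sites["y"] = Site(y, normal_log_prob(y, z, 1.0), is_observed=is_obs)
    return sites


sites = trace_two_site(random.Random(0), observed={"y": 0.3})
joint_log_prob = sum(s.log_prob for s in sites.values())
print(list(sites), sites["y"].is_observed, sites["z"].is_observed)
# ['z', 'y'] True False
```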
Conditioning on Observations
Wrap the model to fix observed variables:
```python
# Fix y to observed values
conditioned_model = condition(model, {"y": y_observed})

# Trace the conditioned model
tr_cond = conditioned_model.trace(x_single)

# Now y is fixed
print(tr_cond.sites["y"].is_observed)  # True
print(tr_cond.sites["y"].value)        # matches y_observed[0]
print(tr_cond.sites["z"].is_observed)  # False (still latent)
```
Conditioning fixes only y. The latent variable z is still sampled whenever the conditioned model runs.
Creating a Guide
A guide is a variational approximation to the posterior. It has the same interface as the model but is typically simpler (e.g., a mean-field normal distribution).
```python
# Create a guide: a normal approximation to the posterior over the latent z.
# y is declared observed, so the guide does not model it.
guide = AutoNormalGuide(model, observed_names={"y"})
```
The AutoNormalGuide:
- Identifies latent (non-observed) variables: just z in this case
- Creates learnable parameters for the mean and log-scale of a normal distribution
- Provides rsample(x) and log_prob(x, samples) methods
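The reparameterization behind a normal guide's rsample is easy to show in plain Python. A draw is loc + exp(log_scale)·ε with ε ~ N(0, 1), so the sample is a differentiable function of the learnable loc and log_scale. The class below is a hypothetical single-latent stand-in, not the AutoNormalGuide implementation. Storing log_scale rather than the scale itself keeps the scale positive without any constraint.

```python
import math
import random


class NormalGuideSketch:
    """Hypothetical mean-field normal guide over one scalar latent z."""

    def __init__(self, loc=0.0, log_scale=0.0):
        self.loc = loc              # learnable mean
        self.log_scale = log_scale  # learnable log standard deviation

    @property
    def scale(self):
        return math.exp(self.log_scale)

    def rsample(self, n, rng):
        """Reparameterized draws: z = loc + scale * eps, eps ~ N(0, 1)."""
        return {"z": [self.loc + self.scale * rng.gauss(0.0, 1.0)
                      for _ in range(n)]}

    def log_prob(self, samples):
        """Per-draw log q(z) under N(loc, scale^2)."""
        s = self.scale
        return [-0.5 * ((z - self.loc) / s) ** 2 - math.log(s)
                - 0.5 * math.log(2 * math.pi)
                for z in samples["z"]]


guide_sketch = NormalGuideSketch(loc=1.0, log_scale=math.log(0.5))
draws = guide_sketch.rsample(4, random.Random(0))
print(len(draws["z"]), len(guide_sketch.log_prob(draws)))  # 4 4
```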
Sample from the guide:
```python
x = torch.zeros(4, dtype=torch.long)
posterior_samples = guide.rsample(x)
print(posterior_samples)  # dict: {"z": tensor, ...}

posterior_z = posterior_samples["z"]
print(posterior_z.shape)  # [4]
```
Compute the guide's log-probability:
```python
log_q = guide.log_prob(x, posterior_samples)
print(log_q.shape)  # [4]
```
Passing Observations at Runtime
Programs that use indexed observes (observe r : N <- F(args)) read their response tensors at runtime from an observations dictionary of type dict[str, torch.Tensor], keyed by observed-variable name. The dictionary is passed to MonadicProgram.rsample as a keyword argument, and to ELBO.forward and SVI.step as a positional argument after the program input:
```python
observations = {"y": y_observed}  # shape matches the program's N

# Once SVI is set up (next section), pass observations into svi.step:
# loss = svi.step(domain_input, observations)
```
There is no .qvr-level data block; observation tensors live in Python at the call site.
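Observations live in a plain dictionary, so minibatching means selecting the same rows from every observed variable to keep the sites aligned. The helper below is a hypothetical plain-Python sketch: lists stand in for tensors, and subset_observations is not a quivers function. With tensors, indexing each entry as obs[name][idx] does the same job, which is what the training loop later in this tutorial does for y.

```python
def subset_observations(observations, indices):
    """Select the same rows from every observed variable.

    `observations` maps variable names to equal-length sequences; the
    result keeps all observed sites aligned on the chosen indices.
    """
    lengths = {len(v) for v in observations.values()}
    if len(lengths) > 1:
        raise ValueError(f"observed variables disagree on length: {lengths}")
    return {name: [values[i] for i in indices]
            for name, values in observations.items()}


obs = {"y": [0.1, 0.4, 0.9, 1.3], "r": [1, 0, 1, 1]}
print(subset_observations(obs, [2, 0]))
# {'y': [0.9, 0.1], 'r': [1, 1]}
```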
Setting Up Inference
Define the ELBO loss and optimizer:
```python
elbo = ELBO(num_particles=1)

# Adam over both the model's and the guide's parameters
optimizer = optim.Adam(
    list(model.parameters()) + list(guide.parameters()), lr=0.01
)
svi = SVI(model, guide, optimizer, elbo)
```
The SVI object combines four pieces:
- model: the generative MonadicProgram. Observations are supplied to each step rather than baked in via condition.
- guide: the variational posterior.
- optim: an optimizer over the model and guide parameters.
- objective: an Objective such as ELBO.
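The sketch below runs stochastic variational inference by hand to show the kind of work one svi.step does. It uses a toy conjugate model whose exact answer is known: a shared latent z ~ N(0, 1) with observations y_i ~ N(z, 1). The guide is q(z) = N(loc, exp(log_scale)²). Each step draws reparameterized samples, forms a Monte Carlo gradient of the ELBO, and takes a plain gradient-ascent step. This is an illustrative stand-in with hand-derived gradients, no autograd, and no Adam. It is not quivers' implementation, and svi_fit and its parameters are hypothetical.

```python
import math
import random


def svi_fit(ys, num_steps=3000, lr=0.01, num_particles=10, seed=0):
    """Fit q(z) = N(loc, exp(log_scale)^2) to p(z | ys) by SVI.

    Assumed toy model: z ~ N(0, 1) and y_i ~ N(z, 1) independently.
    Returns the fitted (loc, scale).
    """
    rng = random.Random(seed)
    loc, log_scale = 0.0, 0.0          # variational parameters
    n, total = len(ys), sum(ys)
    for _ in range(num_steps):
        scale = math.exp(log_scale)
        g_loc = g_log_scale = 0.0
        for _ in range(num_particles):
            eps = rng.gauss(0.0, 1.0)
            z = loc + scale * eps      # reparameterized draw from q
            # d/dz [log p(z) + sum_i log p(y_i | z)] = sum_i y_i - (n + 1) z
            dlogp_dz = total - (n + 1) * z
            g_loc += dlogp_dz                      # dz/dloc = 1
            # dz/dlog_scale = scale * eps; the entropy of q contributes +1
            g_log_scale += dlogp_dz * scale * eps + 1.0
        # Gradient ascent on the Monte Carlo ELBO estimate
        # (equivalently, descent on the negative-ELBO loss).
        loc += lr * g_loc / num_particles
        log_scale += lr * g_log_scale / num_particles
    return loc, math.exp(log_scale)


loc, scale = svi_fit([1.0, 2.0, 3.0])
print(f"loc={loc:.2f}, scale={scale:.2f}")  # exact posterior is N(1.5, 0.5^2)
```

Averaging several particles per step plays the same role as num_particles in ELBO: it lowers the variance of the gradient estimate at the cost of more model evaluations.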
Training Loop
Run inference to fit the guide's parameters, together with the model's, since the optimizer holds both:
```python
num_epochs = 5
losses = []

# y_observed holds the 50 simulated observations
batch_size = 10
n_batches = len(y_observed) // batch_size

for epoch in range(num_epochs):
    # Shuffle, then split the observations into minibatches
    indices = torch.randperm(len(y_observed))
    batch_losses = []

    for i in range(n_batches):
        batch_idx = indices[i * batch_size:(i + 1) * batch_size]
        batch_obs = {"y": y_observed[batch_idx]}

        # One SVI step: the observations dict flows through the
        # objective at evaluation time, so the model itself does not
        # need to be re-wrapped per batch.
        loss = svi.step(
            torch.zeros(batch_size, dtype=torch.long), batch_obs
        )
        batch_losses.append(loss)

    epoch_loss = sum(batch_losses) / len(batch_losses)
    losses.append(epoch_loss)
    print(f"Epoch {epoch}: loss {epoch_loss:.4f}")
```
The ELBO combines two terms:
- Expected log-likelihood: how well the model explains the observations under the guide's samples
- KL divergence: how far the guide departs from the prior (a regularization penalty)
Training minimizes the negative ELBO. This maximizes a lower bound on the log evidence log p(y), not the evidence itself, while keeping the guide close to the prior.
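The two terms come from the standard decomposition below, with q(z) the guide and p(y, z) = p(z) p(y | z) the model. The bound follows from Jensen's inequality, and its gap is the KL divergence from the guide to the exact posterior:

```latex
\begin{aligned}
\log p(y) &= \log \mathbb{E}_{q(z)}\!\left[\frac{p(y, z)}{q(z)}\right]
  \;\ge\; \mathbb{E}_{q(z)}\!\left[\log p(y, z) - \log q(z)\right]
  = \mathrm{ELBO}(q) \\
\mathrm{ELBO}(q) &= \mathbb{E}_{q(z)}\!\left[\log p(y \mid z)\right]
  - \mathrm{KL}\!\left(q(z) \,\|\, p(z)\right) \\
\log p(y) - \mathrm{ELBO}(q) &= \mathrm{KL}\!\left(q(z) \,\|\, p(z \mid y)\right)
\end{aligned}
```

Because log p(y) does not depend on q, raising the ELBO over the guide's parameters is the same as shrinking KL(q(z) || p(z | y)).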
Posterior Inference
After training, use the guide to draw posterior and posterior-predictive samples at a new input:
```python
# Create a Predictive: posterior samples from the guide
predictive = Predictive(model, guide, num_samples=100)

# Sample posterior at a new point
x_new = torch.zeros(1, dtype=torch.long)
posterior_samples = predictive(x_new)
print(posterior_samples.keys())  # dict with z and y

z_posterior = posterior_samples["z"]
print(z_posterior.shape)  # [100, 1] (num_samples, batch)
```
Analyze the posterior:
```python
# Posterior mean and std
z_mean = z_posterior.mean(dim=0)
z_std = z_posterior.std(dim=0)
print(f"z posterior: mean={z_mean.item():.3f}, std={z_std.item():.3f}")

# Posterior credible interval
z_quantile_lower = z_posterior.quantile(0.025, dim=0)
z_quantile_upper = z_posterior.quantile(0.975, dim=0)
print(f"95% CI: [{z_quantile_lower.item():.3f}, {z_quantile_upper.item():.3f}]")
```
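Posterior-predictive sampling can be sketched in plain Python: for each draw, sample z from the fitted guide, then push it through the likelihood to get y. The sketch assumes a normal guide N(guide_loc, guide_scale²) and a likelihood N(z, noise_scale²). The names posterior_predictive and summarize are hypothetical, and statistics.quantiles serves as a simple empirical stand-in for torch.quantile.

```python
import random
import statistics


def posterior_predictive(num_samples, rng, guide_loc, guide_scale,
                         noise_scale=1.0):
    """Draw z from the guide, then y | z from the likelihood.

    Hypothetical stand-in for Predictive: returns a dict keyed by site name.
    """
    zs = [rng.gauss(guide_loc, guide_scale) for _ in range(num_samples)]
    ys = [rng.gauss(z, noise_scale) for z in zs]
    return {"z": zs, "y": ys}


def summarize(values):
    """Mean, standard deviation, and a central 95% percentile interval."""
    cuts = statistics.quantiles(values, n=40)  # cut points at 2.5%, 5%, ..., 97.5%
    return statistics.fmean(values), statistics.stdev(values), (cuts[0], cuts[-1])


draws = posterior_predictive(100, random.Random(0), guide_loc=1.5, guide_scale=0.5)
z_mean, z_std, (lo, hi) = summarize(draws["z"])
print(f"z posterior: mean={z_mean:.3f}, std={z_std:.3f}, 95% CI: [{lo:.3f}, {hi:.3f}]")
```

The predictive draws of y are wider than those of z, because they add the likelihood's noise on top of posterior uncertainty about z.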
Evaluating the Guide
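Compare the learned guide with the model. Tracing the conditioned model does not perform inference: it fixes y to the observations but still draws z from its prior. Contrasting one draw from each is therefore only a rough sanity check, not a comparison with the true posterior:
```python
# Forward run of the conditioned model: y is fixed, z comes from the prior
tr_cond_new = conditioned_model.trace(x_new)
z_prior = tr_cond_new.sites["z"].value

# Draw from the fitted guide (approximate posterior)
z_guide = guide.rsample(x_new)["z"]

print(f"Prior z: {z_prior.item():.3f}")
print(f"Guide z: {z_guide.item():.3f}")
```
For a real check, plot the exact posterior density against the guide's density when the exact posterior is tractable.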
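When the exact posterior is tractable, it has a closed form to compare against. The sketch below assumes the textbook conjugate case: a shared latent z ~ N(μ₀, σ₀²) with observations y_i ~ N(z, σ²). The posterior is then normal, with precision 1/σ₀² + n/σ² and a mean that is the precision-weighted combination of μ₀ and the data. The tutorial's model matches this only if its ConditionalNormal likelihood is centred on z with a known scale. normal_posterior is a hypothetical helper for checking guides in that setting.

```python
import math


def normal_posterior(ys, prior_loc=0.0, prior_scale=1.0, noise_scale=1.0):
    """Exact posterior N(loc, scale^2) for z given ys.

    Assumes z ~ N(prior_loc, prior_scale^2) and y_i ~ N(z, noise_scale^2)
    independently (normal-normal conjugacy).
    """
    precision = 1.0 / prior_scale**2 + len(ys) / noise_scale**2
    var = 1.0 / precision
    loc = var * (prior_loc / prior_scale**2 + sum(ys) / noise_scale**2)
    return loc, math.sqrt(var)


loc, scale = normal_posterior([1.0, 2.0, 3.0])
print(loc, scale)  # 1.5 0.5
```

A well-fitted normal guide on a model of this form should report a mean and scale close to these values.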
More Complex Models
The same pattern extends to complex models. For instance, with the PDS model from Tutorial 4:
```python
from quivers.dsl import loads

prog_pds = loads("""
object Entity : FinSet 1
object Truth : FinSet 2
object Resp : FinSet 1
program factivity : Entity -> Truth * Truth * Truth * Resp
    sample theta_know <- LogitNormal(0.0, 1.0)
    sample theta_cg <- LogitNormal(0.0, 1.0)
    let cg_complement = 1
    sample tau_know <- Bernoulli(theta_know)
    sample cg_matrix <- Bernoulli(theta_cg)
    sample sigma <- Uniform(0.0, 1.0)
    observe response <- TruncatedNormal(theta_know, sigma, 0.0, 1.0)
    return (tau_know, cg_complement, cg_matrix, response)
""")
model_pds = prog_pds.morphism

# Observed response judgments from a linguistic experiment
observed_responses = torch.tensor([0.8, 0.6, 0.7, 0.9, 0.5])

# Build a guide and an SVI loop. Observations are routed through
# ``svi.step`` rather than wrapped into the model.
guide_pds = AutoNormalGuide(model_pds, observed_names={"response"})
elbo = ELBO(num_particles=1)
optimizer = optim.Adam(
    list(model_pds.parameters()) + list(guide_pds.parameters()),
    lr=0.01,
)
svi = SVI(model_pds, guide_pds, optimizer, elbo)

# Train: one input per observed response, with the observations
# dict passed to every step.
x_pds = torch.zeros(len(observed_responses), dtype=torch.long)
pds_observations = {"response": observed_responses}
for step in range(500):
    loss = svi.step(x_pds, pds_observations)
```
Summary
You have:
- Built a probabilistic model and simulated observations
- Traced a model to inspect sites and log-probabilities
- Conditioned a model on observed data
- Created an AutoNormalGuide as a posterior approximation
- Set up and run a variational inference training loop
- Used Predictive for posterior sampling
- Evaluated the inferred posterior
This workflow applies to any quivers probabilistic program, from simple Gaussian models to complex linguistic models like PDS.
Next
Tutorial 6 covers first-class transformations: MorphismTransformation and AlgebraHomomorphism as values, the >>> composition operator, and change-of-base pipelines.
Further Reading
- Inference Guide: Detailed documentation of trace, conditioning, guides, ELBO, and SVI
- Continuous Morphisms: More on distributions and spaces
- DSL Guide: Writing models in .qvr syntax