4. Mixtures and discrete latents
When a model has a discrete latent variable, you have two options. You can sample it, in which case every gradient step pays a Monte Carlo penalty: a discrete draw can't be reparameterized, so its gradients come from the high-variance score-function estimator. Or you can marginalize it out: sum over its support, get a deterministic log-likelihood, and let gradients flow cleanly. The right choice depends on the support size: if the discrete latent takes only a handful of values per observation, marginalizing is the obvious win.
QVR makes marginalization a first-class block. The body of the block runs once per value the discrete latent can take; the runtime collects per-value log-likelihoods and combines them under the prior with a logsumexp. Mathematically this is an exact integration over the discrete latent; computationally it's the categorical-prior version of the Rao-Blackwellised gradient (Casella & Robert, 1996). The same syntax handles flat mixtures, hierarchical mixtures with grouping, and HMM-shaped models with a per-row latent.
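The enumerate-then-combine step can be sketched in plain NumPy, outside QVR: evaluate the body once per value of the discrete latent, add that value's log prior, and log-sum-exp the results. This is a minimal illustration assuming a single observation and a K-valued latent; logsumexp, normal_logpdf, and marginal_log_lik are hypothetical helpers, not part of the quivers API.

```python
import numpy as np


def logsumexp(a):
    """Numerically stable log(sum(exp(a))) over a 1-D array."""
    a = np.asarray(a, dtype=float)
    m = a.max()
    return float(m + np.log(np.sum(np.exp(a - m))))


def normal_logpdf(y, mu, sd):
    """Log density of Normal(mu, sd) at y (elementwise for arrays)."""
    return -0.5 * ((y - mu) / sd) ** 2 - np.log(sd) - 0.5 * np.log(2 * np.pi)


def marginal_log_lik(log_prior, body):
    """Exact log p(y) = logsumexp_k [log p(z = k) + log p(y | z = k)].

    log_prior: length-K array of log prior probabilities for z.
    body: callable k -> log p(y | z = k), run once per value of z.
    """
    per_value = [log_prior[k] + body(k) for k in range(len(log_prior))]
    return logsumexp(per_value)


# Two-component example: one observation, summed over both values of z.
probs = np.array([0.4, 0.6])
mu = np.array([-2.0, 2.0])
sd = np.array([0.5, 0.7])
y = 1.3
lp = marginal_log_lik(np.log(probs), lambda k: normal_logpdf(y, mu[k], sd[k]))
```

Subtracting the maximum before exponentiating keeps the sum finite even when the per-value log-likelihoods are large and negative, which is the usual case once a body contains many observations.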
A two-component Gaussian mixture
Each observation comes from one of two Gaussian clusters; we don't know which.
```
object Item : FinSet 500
object K : FinSet 2

program gmm : Item -> Item [effects=[Sample, Score, Marginal]]
    sample probs : K <- HalfNormal(1.0)
    sample mu_k : K <- Normal(0.0, 5.0)
    sample sd_k : K <- HalfNormal(1.0)
    marginalize z : K <- Categorical(probs)
        observe y : Item <- Normal(mu_k[z], sd_k[z])
    return y

export gmm
```
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate


@config_enumerate
def model(data):
    probs = pyro.sample("probs", dist.Dirichlet(torch.ones(2)))
    mu_k = pyro.sample("mu_k", dist.Normal(0., 5.).expand([2]).to_event(1))
    sd_k = pyro.sample("sd_k", dist.HalfNormal(1.).expand([2]).to_event(1))
    with pyro.plate("data", len(data)):
        # z is enumerated in parallel and summed out when fit with TraceEnum_ELBO
        z = pyro.sample("z", dist.Categorical(probs),
                        infer={"enumerate": "parallel"})
        pyro.sample("y", dist.Normal(mu_k[z], sd_k[z]), obs=data)
```
```stan
data {
  int N;
  vector[N] y;
}
parameters {
  simplex[2] probs;
  ordered[2] mu_k;
  vector<lower=0>[2] sd_k;
}
model {
  probs ~ dirichlet([1, 1]');
  mu_k ~ normal(0, 5);
  sd_k ~ normal(0, 1);
  for (n in 1:N) {
    vector[2] lp;
    for (k in 1:2)
      lp[k] = log(probs[k])
              + normal_lpdf(y[n] | mu_k[k], sd_k[k]);
    target += log_sum_exp(lp);
  }
}
```
The marginalize z : K <- Categorical(probs) block, with its body of observe steps indented underneath, is exactly the Stan log_sum_exp pattern, expressed once and instantiated for every row of the response. The effects=[Marginal] entry in the option block makes the marginalization visible at the program signature.
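For intuition, the per-row log_sum_exp loop in the Stan listing can be vectorized over all rows at once: build an (N, K) matrix of log prior plus component log-likelihood and reduce over K. This NumPy sketch assumes fixed parameter values and synthetic data resembling the fit below; gmm_log_lik is a hypothetical name, not something QVR exposes.

```python
import numpy as np


def normal_logpdf(y, mu, sd):
    """Elementwise log N(y | mu, sd); broadcasts like any NumPy expression."""
    return -0.5 * ((y - mu) / sd) ** 2 - np.log(sd) - 0.5 * np.log(2 * np.pi)


def gmm_log_lik(y, probs, mu_k, sd_k):
    """sum_n log sum_k probs[k] * N(y[n] | mu_k[k], sd_k[k])."""
    # lp[n, k] plays the role of lp[k] inside Stan's loop over rows n.
    lp = np.log(probs)[None, :] + normal_logpdf(y[:, None], mu_k[None, :], sd_k[None, :])
    # Stable log-sum-exp over the K axis: one marginal log-likelihood per row.
    m = lp.max(axis=1, keepdims=True)
    per_row = m[:, 0] + np.log(np.exp(lp - m).sum(axis=1))
    return float(per_row.sum())


rng = np.random.default_rng(0)
true_mu = np.array([-2.0, 2.0])
true_sd = np.array([0.5, 0.7])
z = rng.binomial(1, 0.6, size=500)        # latent cluster labels
y = rng.normal(true_mu[z], true_sd[z])     # observed values
ll = gmm_log_lik(y, np.array([0.4, 0.6]), true_mu, true_sd)
```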
Fitting the mixture
(See docs/examples/source/mixture_model.qvr for the full end-to-end version with grouped marginalisation and the factor patterns needed to drive the body. The snippet below shows the shape of the fit; for a running version copy from the example.)
```python
import torch
from quivers.dsl import loads
from quivers.inference import AutoNormalGuide, ELBO, SVI

GMM_SRC = """
object Item : FinSet 500
object K : FinSet 2

program gmm : Item -> Item
    sample probs : K <- HalfNormal(1.0)
    sample mu_k : K <- Normal(0.0, 5.0)
    sample sd_k : K <- HalfNormal(1.0)
    marginalize z : K <- Categorical(probs)
        observe y : Item <- Normal(mu_k[z], sd_k[z])
    return y

export gmm
"""

program = loads(GMM_SRC)
model = program.morphism

# Synthetic data: two clusters, roughly 60% of rows from the second one
torch.manual_seed(0)
true_mu = torch.tensor([-2.0, 2.0])
true_sd = torch.tensor([0.5, 0.7])
z_true = torch.bernoulli(torch.full((500,), 0.6)).long()
y_data = (torch.randn(500) * true_sd[z_true] + true_mu[z_true]).unsqueeze(0)

guide = AutoNormalGuide(model, observed_names={"y"})
elbo = ELBO(num_particles=1)
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(guide.parameters()), lr=1e-2,
)
svi = SVI(model, guide, optimizer, elbo)

x_tensor = torch.zeros(1, 1)
observations = {"y": y_data}
for _ in range(20):  # bump to ~3000 for real fits
    svi.step(x_tensor, observations)
```
The discrete latent z is summed out exactly at every SVI step, so the gradients on mu_k, sd_k, and probs flow through a smooth logsumexp. No discrete-sampling variance contaminates the ELBO.
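Why the gradient is clean can be checked by hand. Differentiating the log-sum-exp with respect to mu_k gives each component's Gaussian score weighted by the posterior responsibility p(z_n = k | y_n), a deterministic function of the data and parameters. The NumPy sketch below is independent of quivers, uses hypothetical helper names, and differentiates the marginal log-likelihood rather than the full ELBO; it compares that formula with a finite difference.

```python
import numpy as np


def log_joint(y, probs, mu_k, sd_k):
    """(N, K) matrix of log probs[k] + log N(y[n] | mu_k[k], sd_k[k])."""
    r = (y[:, None] - mu_k[None, :]) / sd_k[None, :]
    return np.log(probs)[None, :] - 0.5 * r**2 - np.log(sd_k)[None, :] - 0.5 * np.log(2 * np.pi)


def marginal_ll(y, probs, mu_k, sd_k):
    """Sum over rows of the log-sum-exp over K: z is summed out, not sampled."""
    lp = log_joint(y, probs, mu_k, sd_k)
    m = lp.max(axis=1, keepdims=True)
    return float((m[:, 0] + np.log(np.exp(lp - m).sum(axis=1))).sum())


def grad_mu(y, probs, mu_k, sd_k):
    """Analytic d marginal_ll / d mu_k."""
    lp = log_joint(y, probs, mu_k, sd_k)
    # Responsibilities: resp[n, k] = softmax_k(lp[n, :]) = p(z_n = k | y_n).
    resp = np.exp(lp - lp.max(axis=1, keepdims=True))
    resp /= resp.sum(axis=1, keepdims=True)
    # Gaussian score of each component, weighted by its responsibility.
    score = (y[:, None] - mu_k[None, :]) / sd_k[None, :] ** 2
    return (resp * score).sum(axis=0)


rng = np.random.default_rng(1)
labels = rng.integers(0, 2, size=200)
y = rng.normal(np.array([-2.0, 2.0])[labels], 0.6)
probs = np.array([0.5, 0.5])
mu_k = np.array([-1.0, 1.0])
sd_k = np.array([1.0, 1.0])
g = grad_mu(y, probs, mu_k, sd_k)
```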
Hierarchical mixtures with grouping
Suppose each observation belongs to one of G groups, and the categorical mixture proportions vary by group. The marginalization then has to respect group membership: the log-likelihood over the discrete latent is aggregated per group, not per row. The marginalize header declares the grouping plate with an [over=G] entry, and each observe inside the body carries its own [via=<idx>] entry naming the fibration from its response plate into the grouping plate.
```
object Item : FinSet 1000
object G : FinSet 20
object K : FinSet 3

program grouped_mixture : Item -> Item [effects=[Sample, Score, Marginal]]
    sample group : Item <- HalfNormal(1.0)
    sample probs : G <- HalfNormal(1.0)
    sample mu_k : K <- Normal(0.0, 5.0)
    sample sd_k : K <- HalfNormal(1.0)
    marginalize z : K <- Categorical(probs) [over=G]
        observe y : Item <- Normal(mu_k[z], sd_k[z]) [via=group]
    return y

export grouped_mixture
```
The [over=G] entry on the marginalize step's option block declares the grouping plate; the [via=group] entry on each observe says "every row of y is fibred over G by group, and the marginalization is per group, not per row." The group : Item <- HalfNormal(1.0) line names a per-row fibration into G. Today the DSL has no dedicated fibration declaration, so the canonical idiom is to declare it as if it were a per-row latent and then supply the integer indices through the observations dict at fit time (cf. docs/examples/source/mixture_model.qvr).
The block contributes
\[
\sum_{g \in G} \log \sum_{k \in K} \pi_{g,k} \prod_{i \,:\, \mathrm{group}(i) = g} \mathcal{N}\!\left(y_i \mid \mu_k, \sigma_k\right)
\]
to the log-density, where \(\pi_{g,k}\) is group g's weight on class k and \(\mu_k, \sigma_k\) are mu_k and sd_k. This is the right Kan extension along the fibration Item -> G and matches Stan's target += log_mix(probs[g], ll_item[i]) accumulation. A grouped block can contain multiple observes, each with its own [via=<idx>] entry, when several heterogeneous response axes share the same per-group class indicator; the per-axis log-likelihoods scatter-sum into the same (|G|, K) accumulator before the log-sum-exp.
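The scatter-sum-then-log-sum-exp recipe can be sketched in NumPy. This assumes the group weights are available as a (|G|, K) array probs_gk (the QVR listing declares probs over G, and how QVR shapes it is left open here) and that the per-row fibration group is an integer array, as it would be when supplied through the observations dict; grouped_log_lik is a hypothetical name.

```python
import numpy as np


def grouped_log_lik(y, group, n_groups, probs_gk, mu_k, sd_k):
    """sum_g log sum_k probs_gk[g, k] * prod_{i: group[i] == g} N(y[i] | mu_k[k], sd_k[k])."""
    # Per-item, per-class log-likelihoods, shape (N, K).
    ll_item = (-0.5 * ((y[:, None] - mu_k[None, :]) / sd_k[None, :]) ** 2
               - np.log(sd_k)[None, :] - 0.5 * np.log(2 * np.pi))
    # Scatter-sum rows into a (|G|, K) accumulator: acc[g, k] sums the items in group g.
    acc = np.zeros((n_groups, len(mu_k)))
    np.add.at(acc, group, ll_item)
    # Add the per-group log prior, log-sum-exp over K, then sum over groups.
    lp = np.log(probs_gk) + acc
    m = lp.max(axis=1, keepdims=True)
    return float((m[:, 0] + np.log(np.exp(lp - m).sum(axis=1))).sum())


rng = np.random.default_rng(2)
n_items, n_groups, K = 1000, 20, 3
group = rng.integers(0, n_groups, size=n_items)      # fibration Item -> G
class_of_group = rng.integers(0, K, size=n_groups)    # one class per group
mu_k = np.array([-3.0, 0.0, 3.0])
sd_k = np.array([0.5, 0.5, 0.5])
y = rng.normal(mu_k[class_of_group[group]], sd_k[class_of_group[group]])
probs_gk = np.full((n_groups, K), 1.0 / K)
ll = grouped_log_lik(y, group, n_groups, probs_gk, mu_k, sd_k)
```

np.add.at is used instead of acc[group] += ll_item because the fancy-indexed form would keep only one update per repeated group index.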
When to marginalize vs sample
The decision is a straight cost-benefit:
- marginalize costs roughly K × (body cost) per evaluation. The reward is exact gradients with respect to the discrete-prior parameters and zero Monte Carlo variance on the discrete latent.
- Sampling the discrete latent with ScoreFunction (REINFORCE) costs (body cost) per evaluation but pays Monte Carlo variance that scales with \(K\) and with the variance of the per-component log-likelihood. In practice, REINFORCE gradients are noisy enough that you usually need 10x to 100x more SVI steps to match a marginalized run, swamping the per-step savings.
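The variance gap shows up on a toy problem. For a categorical latent with logits theta and a fixed per-value payoff f (a stand-in for per-component log-likelihoods), the exact gradient of E[f(z)] needs all K evaluations, while a single-sample REINFORCE estimate needs one evaluation but is only correct on average. This is a NumPy sketch with no baseline; the names are hypothetical and it is not the ScoreFunction implementation.

```python
import numpy as np


def softmax(theta):
    e = np.exp(theta - theta.max())
    return e / e.sum()


def exact_grad(theta, f):
    """d/d theta of E_{z ~ Categorical(softmax(theta))}[f(z)], by enumerating z."""
    p = softmax(theta)
    # sum_k p_k f_k (onehot_k - p) simplifies to p * f - p * (p . f).
    return p * f - p * np.dot(p, f)


def reinforce_samples(theta, f, n_samples, rng):
    """Independent single-sample score-function (REINFORCE) estimates, one per row."""
    p = softmax(theta)
    z = rng.choice(len(p), size=n_samples, p=p)
    # grad_theta log p(z) = onehot(z) - p for a softmax-parameterized categorical.
    return f[z][:, None] * (np.eye(len(p))[z] - p[None, :])


rng = np.random.default_rng(3)
K = 16
theta = rng.normal(size=K)    # logits of the discrete prior
f = rng.normal(size=K)        # stand-in for per-component log-likelihoods
g_exact = exact_grad(theta, f)
est = reinforce_samples(theta, f, 200_000, rng)
```

Averaging 200,000 single-sample estimates recovers g_exact only approximately, whereas the enumerated gradient is exact after K evaluations; in SVI, each step typically sees only a handful of samples.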
The crossover is around K ≈ 32: below that, marginalize wins on wall-clock to convergence; above that, the K× factor on the body becomes painful and a relaxed continuous proxy (Gumbel-softmax, Maddison, Mnih & Teh, 2017) tends to be the better tradeoff.
| Discrete support per row | Recommendation |
|---|---|
| K ≤ 8 | marginalize, always. |
| 8 < K ≤ 32 | marginalize unless the body itself is expensive. |
| 32 < K ≤ 100 | Profile both; relaxed continuous if the body is heavy. |
| K > 100 or unbounded | Gumbel-softmax or score-function with a learned baseline. |
| Continuous-discrete mixture | marginalize the discrete part, reparameterize the continuous part. |
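For the large-K rows, the relaxation replaces a one-hot draw with a softmax over Gumbel-perturbed logits divided by a temperature, so the sample is a differentiable function of the logits. The sketch below assumes NumPy and a hypothetical gumbel_softmax helper and only draws samples; in practice you would write it in an autodiff framework so gradients can flow through it.

```python
import numpy as np


def gumbel_softmax(logits, temperature, rng, size):
    """Draw `size` relaxed one-hot samples from the Concrete / Gumbel-softmax distribution."""
    # Gumbel(0, 1) noise via -log(-log(U)); the lower bound keeps log(U) finite.
    u = rng.uniform(np.finfo(float).tiny, 1.0, size=(size, len(logits)))
    gumbel = -np.log(-np.log(u))
    scores = (logits[None, :] + gumbel) / temperature
    scores -= scores.max(axis=1, keepdims=True)   # stabilize the softmax
    e = np.exp(scores)
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(4)
probs = np.array([0.2, 0.5, 0.3])
soft = gumbel_softmax(np.log(probs), temperature=0.1, rng=rng, size=50_000)
```

The argmax of each row is an exact categorical draw (the Gumbel-max trick), so the relaxation only changes what gradients see. Lower temperatures give samples closer to one-hot but noisier gradients; higher temperatures smooth the samples at the cost of more bias.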
```mermaid
flowchart LR
    A["program block"] -- "marginalize z : K" --> B["body runs K times,<br/>once per z value"]
    B --> C["logsumexp over K"]
    C --> D["score added to ELBO"]
    A --> D
```
Try this
- Refit the GMM with K = 4 (object K : FinSet 4) and watch what happens to the recovered mu_k. (Hint: mixture models suffer from label switching, an identifiability problem (Stephens, 2000). The standard fix in Stan is ordered[K] mu_k; in QVR you'd add a let mu_k_sorted = sort(mu_k) constraint or use an ordered prior.)
- Convert the grouped mixture to a marginalize without the [over=G] and [via=group] entries and observe the difference: per-row versus per-group marginalization.
- Combine with chapter 3's plate-draws: a hierarchical mixture where each group has its own mu_k drawn from a hyperprior.
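For the first exercise, a common post-hoc workaround (simpler than the relabeling algorithms in Stephens, 2000) is to sort each posterior draw by mu_k and apply the same permutation to the other per-component arrays, which mimics Stan's ordered[K] constraint after the fact. The sketch assumes you already have draws as (S, K) NumPy arrays, for instance by sampling a fitted guide; relabel_by_mean is a hypothetical helper.

```python
import numpy as np


def relabel_by_mean(mu_draws, *others):
    """Reorder each draw's components so mu_k increases; permute `others` identically.

    mu_draws: (S, K) array of sampled component means, one row per posterior draw.
    others: (S, K) arrays (e.g. sd_k, probs) that share the component labels.
    """
    order = np.argsort(mu_draws, axis=1)
    return tuple(np.take_along_axis(a, order, axis=1) for a in (mu_draws, *others))


# Three draws of a K = 2 fit; the second draw has its labels switched.
mu = np.array([[-2.0, 2.0], [2.1, -1.9], [-2.05, 1.95]])
sd = np.array([[0.50, 0.70], [0.72, 0.48], [0.51, 0.69]])
mu_r, sd_r = relabel_by_mean(mu, sd)
```

Sorting identifies components reliably only when their means are well separated; with overlapping components, the ordering of means can flip between draws even when the other parameters have not swapped.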
Next
Chapter 5 looks at sequence-shaped models: HMMs, state-space models, and the chart-shaped deduction surface.
References
- Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. 2017. The Concrete distribution: A continuous relaxation of discrete random variables. arXiv preprint arXiv:1611.00712.
- George Casella and Christian P. Robert. 1996. Rao-Blackwellisation of sampling schemes. Biometrika, 83(1):81–94.
- Matthew Stephens. 2000. Dealing with label switching in mixture models. Journal of the Royal Statistical Society Series B: Statistical Methodology, 62(4):795–809.