Multi-Output Zero-Inflated Poisson Regression

Overview

The zero-inflated Poisson regression (Lambert 1992) is a two-component mixture of a point mass at zero and a Poisson distribution. It fits count data that contain more zeros than a plain Poisson likelihood can explain, attributing the excess to structural zeros. Each output dimension carries its own zero-inflation coefficients and rate coefficients (an intercept and a slope for each), and the per-cell zero-inflation indicator is integrated out by a scoped marginalize block.
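For reference, the marginal likelihood of one cell can be written out for a discrete indicator (the Bernoulli limit, not the relaxed model in the source below). This sketch follows the convention of the QVR source and the oracle NLL snippet further down, where pi is the probability that the Poisson component is active and 1 - pi is the structural-zero probability. The symbol \lambda for the Poisson rate and the superscripted coefficient names are notation introduced here, not taken from the source.

```latex
\begin{aligned}
\pi_{n,d} &= \operatorname{sigmoid}\!\left(\alpha^{\text{zero}}_{d} + \beta^{\text{zero}}_{d}\, x_{n,d}\right), \qquad
\lambda_{n,d} = \exp\!\left(\alpha^{\text{rate}}_{d} + \beta^{\text{rate}}_{d}\, x_{n,d}\right), \\
P(y_{n,d} = 0) &= (1 - \pi_{n,d}) + \pi_{n,d}\, e^{-\lambda_{n,d}}, \\
P(y_{n,d} = k) &= \pi_{n,d}\, \frac{\lambda_{n,d}^{k}\, e^{-\lambda_{n,d}}}{k!}, \qquad k \ge 1.
\end{aligned}
```

The zero case has two sources: a structural zero with probability 1 - pi, and an ordinary Poisson zero with probability pi * exp(-lambda).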

QVR Source

object Item : FinSet 200
object Out : FinSet 2
object Resp : FinSet 400

program zip_regression : Resp -> Resp
    sample alpha_zero : Out <- Normal(0.0, 5.0)
    sample beta_zero : Out <- Normal(0.0, 5.0)
    sample alpha_rate : Out <- Normal(0.0, 5.0)
    sample beta_rate : Out <- Normal(0.0, 5.0)

    let az = alpha_zero[out_idx]
    let bz = beta_zero[out_idx]
    let ar = alpha_rate[out_idx]
    let br = beta_rate[out_idx]
    let pi_z = sigmoid(az + bz * x)
    let rate = exp(ar + br * x)

    marginalize z : Resp <- ContinuousBernoulli(pi_z)
        let gated_rate = z * rate
        observe y : Resp <- Poisson(gated_rate)

    return beta_rate

export zip_regression

Walkthrough

The per-output coefficient plates alpha_zero and beta_zero parameterise the logit-link mixing probability pi_{n, d}, and alpha_rate and beta_rate parameterise the log-link Poisson rate rate_{n, d}. Because the indicator gates the rate as z * rate, z = 1 leaves the Poisson component active and z = 0 collapses the cell to a structural zero. So pi_{n, d} is the probability that the count component is active and 1 - pi_{n, d} is the structural-zero probability; the synthetic-data and oracle-NLL snippets below use the same convention. The indicator z is drawn per cell from a ContinuousBernoulli relaxation of the underlying Bernoulli and integrated out by the enclosing marginalize z block, which pushes the coordinate forward through the projection along the trace's z axis. The relaxation has a closed-form density on (0, 1), which lets SVI integrate the coordinate via reparameterised sampling; the canonical logsumexp marginalization over the two integer states is the limiting case as the relaxation temperature tightens.
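The discrete-limit logsumexp marginalization mentioned above can be sketched in plain Python, independent of quivers. This is a minimal illustration, not the library's implementation: the helper names (poisson_log_pmf, logsumexp, zip_log_marginal) and the values of pi_active and rate are hypothetical, and the indicator is treated as an exact Bernoulli, not a ContinuousBernoulli relaxation.

```python
import math

def poisson_log_pmf(k, rate):
    """log Poisson(k | rate); rate == 0 is treated as a point mass at zero."""
    if rate == 0.0:
        return 0.0 if k == 0 else -math.inf
    return k * math.log(rate) - rate - math.lgamma(k + 1)

def logsumexp(a, b):
    """Numerically stable log(exp(a) + exp(b)), tolerating -inf inputs."""
    m = max(a, b)
    if m == -math.inf:
        return -math.inf
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def zip_log_marginal(k, pi_active, rate):
    """Sum out z in {0, 1}: z = 0 gates the rate to zero, z = 1 keeps it."""
    log_z0 = math.log1p(-pi_active) + poisson_log_pmf(k, 0.0 * rate)
    log_z1 = math.log(pi_active) + poisson_log_pmf(k, 1.0 * rate)
    return logsumexp(log_z0, log_z1)

# Hypothetical single-cell values, chosen only for illustration.
pi_active, rate = 0.7, 2.5

# Closed forms that the two-state sum should reproduce.
closed_form_zero = math.log((1 - pi_active) + pi_active * math.exp(-rate))
closed_form_three = math.log(pi_active) + poisson_log_pmf(3, rate)

print(zip_log_marginal(0, pi_active, rate), closed_form_zero)
print(zip_log_marginal(3, pi_active, rate), closed_form_three)
```

For y = 0 both states contribute, while for y > 0 the z = 0 branch has log-probability -inf and drops out of the sum, which is why the positive-count likelihood reduces to log(pi) plus the Poisson log-pmf.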

Try it

The SVI step counts and NUTS warmup, sample, and chain budgets in the snippets below are illustrative: each block is sized to run in tens of seconds and demonstrate the API surface. Production fits typically need 10x to 100x more SVI steps, longer NUTS warmup, and multiple chains to actually converge to the data-generating parameters.

Generating synthetic data

import torch
from quivers.dsl import load

torch.manual_seed(0)
prog = load("docs/examples/source/zip_regression.qvr")
model = prog.morphism

N, D = 200, 2
ND = N * D
out_idx = torch.arange(D).repeat(N)
x = torch.randn(ND)

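# Ground-truth coefficients, one entry per output dimension.
true_az = torch.tensor([0.5, 1.0])
true_bz = torch.tensor([0.3, -0.2])
true_ar = torch.tensor([0.5, 1.0])
true_br = torch.tensor([1.0, -0.5])
# pi_true is the probability that the Poisson component is active (z = 1).
pi_true = torch.sigmoid(true_az[out_idx] + true_bz[out_idx] * x)
rate_true = torch.exp(true_ar[out_idx] + true_br[out_idx] * x)
# z_struct = 0 marks a structural zero; z_struct = 1 keeps the Poisson draw.
z_struct = torch.bernoulli(pi_true)
y = z_struct * torch.poisson(rate_true)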

observations = {"x": x, "y": y, "out_idx": out_idx}
x_in = torch.zeros(ND, 1)
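
A quick way to check that a generator of this shape really produces excess zeros is to compare the empirical zero fraction against the mixture's expected value and against a plain Poisson. The sketch below uses only the Python standard library and a single hypothetical cell (pi_active and rate are illustrative values, and sample_poisson is a helper written here, not part of quivers or torch).

```python
import math
import random

def sample_poisson(rate, rng):
    """Knuth's multiplication method; adequate for the small rates used here."""
    threshold = math.exp(-rate)
    k, prod = 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k

rng = random.Random(0)
pi_active, rate = 0.6, 1.5  # hypothetical single-cell values
n = 20000

zeros = 0
for _ in range(n):
    # z = 1 keeps the Poisson draw, z = 0 forces a structural zero.
    z = 1 if rng.random() < pi_active else 0
    y = z * sample_poisson(rate, rng)
    zeros += (y == 0)

empirical_zero_frac = zeros / n
expected_zero_frac = (1 - pi_active) + pi_active * math.exp(-rate)
poisson_only_zero_frac = math.exp(-rate)

print(f"empirical zeros:    {empirical_zero_frac:.3f}")
print(f"ZIP expectation:    {expected_zero_frac:.3f}")
print(f"plain Poisson:      {poisson_only_zero_frac:.3f}")
```

The empirical fraction should sit close to the ZIP expectation and well above the plain Poisson value; the gap between the last two numbers is the zero inflation the model is meant to capture.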

SVI fit

from quivers.inference import AutoNormalGuide, ELBO, SVI

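# Oracle NLL under the true parameters, with the indicator summed out exactly.
# A zero can be structural (prob 1 - pi) or a Poisson zero (prob pi * exp(-rate)).
log_p_y0 = torch.log((1 - pi_true) + pi_true * torch.exp(-rate_true))
# A positive count can only come from the active Poisson component.
log_p_yk = torch.log(pi_true) + torch.distributions.Poisson(rate_true).log_prob(y)
oracle_nll = float(-torch.where(y == 0, log_p_y0, log_p_yk).mean())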

torch.manual_seed(1)
guide = AutoNormalGuide(model, observed_names={"x", "y", "out_idx"})
optim = torch.optim.Adam(
    list(model.parameters()) + list(guide.parameters()), lr=5e-2,
)
svi = SVI(model, guide, optim, ELBO(num_particles=1))

losses = []
for _ in range(300):
    losses.append(svi.step(x_in, observations))

print(f"initial loss: {losses[0]:.2f}")
print(f"final loss:   {losses[-1]:.2f}")
print(f"oracle NLL:   {oracle_nll:.2f}")

NUTS posterior

from quivers.inference import MCMC, NUTSKernel

torch.manual_seed(2)
kernel = NUTSKernel(step_size=0.05, max_tree_depth=3, target_accept=0.8)
mc = MCMC(kernel, num_warmup=20, num_samples=20, num_chains=1)
result = mc.run(model, x_in, observations)

print(f"acceptance:  {float(result.acceptance_rates.mean()):.2f}")
print(f"divergences: {int(result.divergence_counts.sum())}")

Categorical Perspective

The model factors as a Kleisli composite of two kernels: a per-cell ContinuousBernoulli(pi) kernel on the unit interval and a Poisson(z * rate) kernel on the non-negative integers, conditioned on the indicator z. The scoped marginalize step pushes the joint measure forward through the projection along the trace's z axis, integrating out the indicator and leaving the marginal Poisson likelihood reweighted by the per-cell mixing weight. Categorically, the construction is a coproduct fibration over the binary indicator axis, followed by logsumexp on the accumulated log-likelihood in the discrete-limit case and by reparameterised integration in the relaxed case.
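
In the discrete-limit case, the composite can be made concrete with finite Markov kernels, where Kleisli composition is a product of row-stochastic matrices. The sketch below is illustrative only: it truncates the count support at a hypothetical K, uses an exact Bernoulli indicator in place of the relaxed kernel, and the names poisson_row and compose are introduced here and are not quivers functions.

```python
import math

K = 40  # hypothetical truncation of the count support; tail mass is negligible here

def poisson_row(rate, size):
    """One kernel row: Poisson pmf on {0, ..., size - 1}; rate 0 is a point mass at 0."""
    if rate == 0.0:
        return [1.0] + [0.0] * (size - 1)
    return [math.exp(k * math.log(rate) - rate - math.lgamma(k + 1)) for k in range(size)]

def compose(first, second):
    """Kleisli composition of finite Markov kernels: a row-stochastic matrix product."""
    return [
        [sum(row[j] * second[j][k] for j in range(len(second))) for k in range(len(second[0]))]
        for row in first
    ]

pi_active, rate = 0.7, 2.5  # hypothetical single-cell values

indicator = [[1 - pi_active, pi_active]]            # kernel 1 -> {0, 1}
gated = [poisson_row(z * rate, K) for z in (0, 1)]  # kernel {0, 1} -> {0, ..., K - 1}
marginal = compose(indicator, gated)[0]             # composite 1 -> {0, ..., K - 1}

print(f"total mass: {sum(marginal):.6f}")
print(f"P(y = 0):   {marginal[0]:.6f}")
```

Summing over the intermediate index j inside compose is the marginalization over the indicator axis; taking logs of the same sum gives the logsumexp form.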

References

  • Diane Lambert. 1992. Zero-inflated Poisson regression, with an application to defects in manufacturing. Technometrics, 34(1):1–14.