Multi-Output Negative Binomial Regression

Overview

A multi-output negative binomial regression for overdispersed count data. The response is parameterized by its mean and a dispersion, and the mean is tied to the linear predictor through a log link, the same convention as Poisson regression. Each output dimension carries its own coefficients and dispersion. The likelihood is the standard NB2 form with per-cell variance mu + mu^2 / dispersion, which recovers Poisson in the limit of infinite dispersion.
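As a quick numerical check of the NB2 moments and the Poisson limit, the sketch below evaluates the pmf using only the standard library. It assumes the common reading in which the count k is the number of failures before `disp` successes, each trial succeeding with probability disp / (disp + mu). The helper names are illustrative and are not part of quivers.

```python
import math


def nb2_log_pmf(k: int, mu: float, disp: float) -> float:
    """Log pmf of NB2 with mean mu and dispersion disp (illustrative helper)."""
    # success probability p = disp / (disp + mu); k counts failures
    log_p = -math.log1p(mu / disp)              # log(disp / (disp + mu))
    log_q = math.log(mu) - math.log(disp + mu)  # log(mu / (disp + mu))
    return (math.lgamma(k + disp) - math.lgamma(disp) - math.lgamma(k + 1)
            + disp * log_p + k * log_q)


def poisson_log_pmf(k: int, mu: float) -> float:
    return k * math.log(mu) - mu - math.lgamma(k + 1)


def nb2_moments(mu: float, disp: float, kmax: int = 400):
    """Total mass, mean, and variance from summing the pmf over 0..kmax-1."""
    pmf = [math.exp(nb2_log_pmf(k, mu, disp)) for k in range(kmax)]
    total = sum(pmf)
    mean = sum(k * p for k, p in enumerate(pmf))
    var = sum((k - mean) ** 2 * p for k, p in enumerate(pmf))
    return total, mean, var


total, mean, var = nb2_moments(mu=4.0, disp=5.0)
print(f"mass={total:.4f} mean={mean:.4f} var={var:.4f}")  # mass=1.0000 mean=4.0000 var=7.2000

# as disp grows, the NB2 pmf approaches the Poisson(mu) pmf
gap = max(abs(math.exp(nb2_log_pmf(k, 4.0, 1e6)) - math.exp(poisson_log_pmf(k, 4.0)))
          for k in range(30))
print(f"max pmf gap at disp=1e6: {gap:.1e}")
```

The variance 7.2 is mu + mu^2 / dispersion = 4 + 16 / 5, and the pmf gap shrinks roughly in proportion to 1 / dispersion.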

QVR Source

object Item : FinSet 200
object Out : FinSet 3
object Resp : FinSet 600

program negbin_regression : Resp -> Resp
    sample beta_0 : Out <- Normal(0.0, 5.0)
    sample beta_1 : Out <- Normal(0.0, 5.0)
    sample dispersion : Out <- Gamma(2.0, 0.5)

    let b0 = beta_0[out_idx]
    let b1 = beta_1[out_idx]
    let disp = dispersion[out_idx]
    let eta = b0 + b1 * x
    let mu = exp(eta)
    let probs = disp / (disp + mu)

    observe y : Resp <- NegativeBinomial(disp, probs)
    return beta_1

export negbin_regression

Walkthrough

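Per-output coefficient and dispersion plates are broadcast to the response cells by out_idx gathers. The per-cell linear predictor eta = b0 + b1 * x is mapped through exp, the inverse of the log link, to give the conditional mean mu. The NB2 parameterization sets probs = dispersion / (dispersion + mu). For NegativeBinomial(dispersion, probs) to have mean mu and variance mu * (1 + mu / dispersion), probs must be read as the per-trial success probability, with the count being the number of failures before dispersion successes (the convention of scipy.stats.nbinom). PyTorch's torch.distributions.NegativeBinomial instead counts successes before total_count failures, so the same distribution there takes probs = mu / (dispersion + mu). The Gamma prior on dispersion has zero density at the origin (shape 2.0) and a decaying tail, a soft preference for finite, moderate overdispersion over both extreme overdispersion and the Poisson limit; per-output dispersion permits heterogeneous count regimes across the response axis.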
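The let-bindings in the program can be mirrored in plain Python to see how the gather, the inverse link, and the probs formula fit together. This is a sketch, not how quivers evaluates the program: the helper names are hypothetical. It checks two points from the paragraph above: that reading probs as a success probability with failures counted gives mean mu, and that the probs PyTorch expects is its complement.

```python
import math


def forward(beta_0, beta_1, dispersion, out_idx, x):
    """Per-cell (mu, probs) mirroring the QVR let-bindings (hypothetical helper)."""
    mus, probs = [], []
    for o, xi in zip(out_idx, x):
        b0, b1, disp = beta_0[o], beta_1[o], dispersion[o]  # gather by output index
        mu = math.exp(b0 + b1 * xi)                         # inverse of the log link
        mus.append(mu)
        probs.append(disp / (disp + mu))                    # NB2 probs as in the QVR source
    return mus, probs


def mean_failure_counting(r, p):
    """Mean number of failures before r successes, each trial succeeding with probability p."""
    return r * (1.0 - p) / p


beta_0, beta_1, dispersion = [1.0, 0.5, 2.0], [0.5, -0.3, 0.8], [5.0, 10.0, 3.0]
out_idx = [0, 1, 2, 0, 1, 2]
x = [0.0, 0.0, 0.0, 1.0, -1.0, 0.5]
mus, probs = forward(beta_0, beta_1, dispersion, out_idx, x)

for o, mu, p in zip(out_idx, mus, probs):
    r = dispersion[o]
    torch_probs = 1.0 - p  # the probs torch.distributions.NegativeBinomial would need
    print(f"out={o} mu={mu:.3f} mean={mean_failure_counting(r, p):.3f} "
          f"torch_probs={torch_probs:.3f} mu/(r+mu)={mu / (r + mu):.3f}")
```

In every row the failure-counting mean equals mu, and the torch-side probs equals mu / (dispersion + mu).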

Try it

The SVI step counts and NUTS warmup, sample, and chain budgets in the snippets below are illustrative: each block is sized to run in tens of seconds and demonstrate the API surface. Production fits typically need 10x to 100x more SVI steps, longer NUTS warmup, and multiple chains to actually converge to the data-generating parameters.

Generating synthetic data

import torch
from quivers.dsl import load

torch.manual_seed(0)
prog = load("docs/examples/source/negbin_regression.qvr")
model = prog.morphism

N, D = 21, 3
ND = N * D
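out_idx = torch.arange(D).repeat(N)  # output index of each cell: 0, 1, 2, 0, 1, 2, ...
x = torch.randn(ND)                  # one covariate value per cell

# ground-truth per-output intercepts, slopes, and dispersions
true_b0 = torch.tensor([1.0, 0.5, 2.0])
true_b1 = torch.tensor([0.5, -0.3, 0.8])
true_disp = torch.tensor([5.0, 10.0, 3.0])
mu_true = torch.exp(true_b0[out_idx] + true_b1[out_idx] * x)
# torch's NegativeBinomial(total_count, probs) counts successes before total_count
# failures, so a mean of mu_true requires probs = mu / (dispersion + mu)
probs_true = mu_true / (true_disp[out_idx] + mu_true)
y = torch.distributions.NegativeBinomial(true_disp[out_idx], probs_true).sample()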

observations = {"x": x, "y": y, "out_idx": out_idx}
x_in = torch.zeros(ND, 1)

SVI fit

from quivers.inference import AutoNormalGuide, ELBO, SVI

oracle_nll = float(
    -torch.distributions.NegativeBinomial(true_disp[out_idx], probs_true)
    .log_prob(y)
    .mean()
)

torch.manual_seed(1)
guide = AutoNormalGuide(model, observed_names={"x", "y", "out_idx"})
optim = torch.optim.Adam(
    list(model.parameters()) + list(guide.parameters()), lr=5e-2,
)
svi = SVI(model, guide, optim, ELBO(num_particles=1))

losses = []
for _ in range(300):
    losses.append(svi.step(x_in, observations))

print(f"initial loss: {losses[0]:.2f}")
print(f"final loss:   {losses[-1]:.2f}")
print(f"oracle NLL:   {oracle_nll:.2f}")
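The quantity oracle_nll is meant to measure is the average negative log-likelihood of the counts under NB2 with the data-generating mean and dispersion: a reference for the likelihood scale, not a bound on the fitted loss. A standard-library version of that per-cell average is sketched below; the helper name and toy values are hypothetical. Whether the printed SVI loss is directly comparable depends on how quivers' ELBO aggregates over cells (sum versus mean) and on its prior KL term, neither of which the snippet shows.

```python
import math


def nb2_mean_nll(y, mu, disp):
    """Average NB2 negative log-likelihood over cells (hypothetical helper)."""
    total = 0.0
    for k, m, r in zip(y, mu, disp):
        log_p = -math.log1p(m / r)             # log(r / (r + m))
        log_q = math.log(m) - math.log(r + m)  # log(m / (r + m))
        log_pmf = (math.lgamma(k + r) - math.lgamma(r) - math.lgamma(k + 1)
                   + r * log_p + k * log_q)
        total -= log_pmf
    return total / len(y)


# toy cells: observed counts with their true means and dispersions
y = [0, 3, 7, 2]
mu = [2.0, 2.0, 5.0, 1.5]
disp = [5.0, 5.0, 3.0, 10.0]
print(f"oracle-style mean NLL: {nb2_mean_nll(y, mu, disp):.3f}")
```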

NUTS posterior

from quivers.inference import MCMC, NUTSKernel

N_mcmc, D_mcmc = 11, 3
ND_mcmc = N_mcmc * D_mcmc
out_idx_mcmc = torch.arange(D_mcmc).repeat(N_mcmc)
x_mcmc = x[:ND_mcmc]
y_mcmc = y[:ND_mcmc]
obs_mcmc = {"x": x_mcmc, "y": y_mcmc, "out_idx": out_idx_mcmc}
x_in_mcmc = torch.zeros(ND_mcmc, 1)

torch.manual_seed(2)
kernel = NUTSKernel(step_size=0.05, max_tree_depth=3, target_accept=0.8)
mc = MCMC(kernel, num_warmup=20, num_samples=20, num_chains=1)
result = mc.run(model, x_in_mcmc, obs_mcmc)

print(f"acceptance:  {float(result.acceptance_rates.mean()):.2f}")
print(f"divergences: {int(result.divergence_counts.sum())}")
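Reusing the first ND_mcmc entries of x and y with a freshly built out_idx_mcmc works because torch.arange(D).repeat(N) tiles the output indices (0, 1, 2, 0, 1, 2, ...). Any prefix whose length is a multiple of D therefore lines up with the rebuilt index and keeps every output equally represented. Below is a standard-library check of that property, with a list standing in for the tensor (the helper name is hypothetical).

```python
def tiled_out_idx(n_items: int, n_out: int) -> list:
    """List equivalent of torch.arange(n_out).repeat(n_items): 0..n_out-1, tiled."""
    return list(range(n_out)) * n_items


full = tiled_out_idx(21, 3)   # SVI layout: N, D = 21, 3
sub = tiled_out_idx(11, 3)    # NUTS layout: N_mcmc, D_mcmc = 11, 3

prefix = full[: len(sub)]
print(prefix == sub)                        # True
print([prefix.count(o) for o in range(3)])  # [11, 11, 11]

# a prefix whose length is not a multiple of D leaves the outputs unbalanced
print([full[:32].count(o) for o in range(3)])  # [11, 11, 10]
```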

Categorical Perspective

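The negative binomial is the Gamma-Poisson mixture. A Poisson(rate) kernel with rate ~ Gamma(dispersion, dispersion / mu), that is, shape dispersion and rate dispersion / mu, gives the rate mean mu and variance mu^2 / dispersion, and it marginalizes to a negative binomial with mean mu and variance mu + mu^2 / dispersion. In PyTorch's convention this is NegativeBinomial(dispersion, mu / (mu + dispersion)). If probs is instead read as the success probability and the count as the number of failures before dispersion successes, the same distribution is NegativeBinomial(dispersion, dispersion / (dispersion + mu)). The model never samples the latent per-cell rate; it observes y directly under the closed-form negative binomial. Categorically, the negative binomial kernel is the composite of the Gamma kernel with the Poisson kernel, i.e. the pushforward of the Gamma-Poisson joint kernel along the projection that discards the rate and keeps the count.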
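The mixture identity can be checked by simulation without any probabilistic-programming machinery. The sketch below draws a Gamma rate with shape dispersion and rate dispersion / mu, then draws a Poisson count at that rate, and compares the empirical mean, variance, and zero fraction with the closed-form NB2 values. It uses only Python's random and statistics modules. Note that random.gammavariate takes shape and *scale*, so the rate dispersion / mu enters as scale mu / dispersion. The Poisson sampler is Knuth's multiplication method, which is adequate for the moderate rates used here.

```python
import math
import random
import statistics


def sample_poisson(rate: float, rng: random.Random) -> int:
    """Knuth's method: count uniforms until their running product drops below exp(-rate)."""
    threshold = math.exp(-rate)
    k, prod = 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k


def gamma_poisson(mu: float, disp: float, n: int, seed: int = 0) -> list:
    rng = random.Random(seed)
    draws = []
    for _ in range(n):
        rate = rng.gammavariate(disp, mu / disp)  # Gamma(shape=disp, scale=mu/disp), mean mu
        draws.append(sample_poisson(rate, rng))   # Poisson count given the latent rate
    return draws


mu, disp = 4.0, 5.0
draws = gamma_poisson(mu, disp, n=20000)
emp_mean = statistics.fmean(draws)
emp_var = float(statistics.pvariance(draws))
emp_zero = draws.count(0) / len(draws)
print(f"mean  {emp_mean:.3f} vs {mu}")
print(f"var   {emp_var:.3f} vs {mu + mu ** 2 / disp}")
print(f"P(0)  {emp_zero:.4f} vs {(disp / (disp + mu)) ** disp:.4f}")
```

The closed-form zero probability (dispersion / (dispersion + mu))^dispersion is the NB2 pmf at k = 0.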