Weibull Survival Regression

Overview

A parametric survival regression with a Weibull baseline hazard (Klein and Moeschberger 2003). Each item carries a covariate x_i that enters a linear predictor, which multiplies the hazard through the Weibull scale. The shape parameter k governs whether the hazard is decreasing in time (k < 1), constant (k = 1, the exponential model), or increasing (k > 1).
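The three shape regimes can be checked directly from the Weibull hazard h(t) = (k / scale) * (t / scale)^(k - 1), written here in the same (scale, shape) convention as torch.distributions.Weibull. This is a minimal stdlib sketch; the helper name weibull_hazard is ours and is not part of quivers.

```python
import math


def weibull_hazard(t: float, scale: float, k: float) -> float:
    """Hazard h(t) = f(t) / S(t) of a Weibull with the given scale and shape k."""
    return (k / scale) * (t / scale) ** (k - 1.0)


times = [0.5, 1.0, 2.0]
for k in (0.5, 1.0, 2.0):
    # k = 0.5: falls with t; k = 1.0: flat (exponential); k = 2.0: rises with t.
    hazards = [weibull_hazard(t, scale=1.0, k=k) for t in times]
    print(k, [round(h, 3) for h in hazards])
```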

QVR Source

object Item : FinSet 200

program survival_weibull : Item -> Item
    sample alpha <- Normal(0.0, 5.0)
    sample beta <- Normal(0.0, 5.0)
    sample k <- Gamma(2.0, 1.0)

    let eta = alpha + beta * x
    let scale = exp(-eta / k)

    observe t : Item <- Weibull(scale, k)
    return beta

export survival_weibull
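The program's generative story can be mirrored in plain Python to see how wide the priors are before any data arrive. This sketch rests on two assumptions we have not checked against the QVR specification: Normal(mu, sigma) takes a standard deviation, and Gamma(2.0, 1.0) means (shape, rate); with a second argument of 1 the mean is 2 under either Gamma convention. It works with log t = -eta / k + (1 / k) * log E, where E ~ Exponential(1), so extreme draws of eta and k cannot overflow. The helper prior_predictive_log_t is hypothetical.

```python
import math
import random


def prior_predictive_log_t(xs, rng):
    """One prior-predictive draw of log event times for covariates xs.

    Mirrors survival_weibull, assuming Normal(mu, sd) and Gamma(shape, rate).
    """
    alpha = rng.gauss(0.0, 5.0)
    beta = rng.gauss(0.0, 5.0)
    # random.gammavariate takes (shape, scale); a scale of 1.0 equals a rate of 1.0.
    k = rng.gammavariate(2.0, 1.0)
    log_t = []
    for x in xs:
        eta = alpha + beta * x
        e = rng.expovariate(1.0)
        # t = scale * E**(1/k) with scale = exp(-eta / k), computed in log space.
        log_t.append(-eta / k + math.log(e) / k)
    return log_t


rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(64)]
draws = [prior_predictive_log_t(xs, rng) for _ in range(200)]
spread = max(max(d) for d in draws) - min(min(d) for d in draws)
print(f"prior-predictive event times span {spread / math.log(10):.1f} decades")
```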

Walkthrough

The identifier x is the exogenous covariate. It is never declared inside the program, so the runtime resolves it from the observations dict at trace time, where the caller supplies the per-item predictor. The reparameterisation scale = exp(-eta / k) is the Weibull proportional-hazards convention: because the Weibull hazard is h(t) = (k / scale) * (t / scale)^(k - 1), this choice gives h(t) = k * t^(k - 1) * exp(eta), so the linear predictor eta = alpha + beta * x multiplies the hazard by exp(eta). Positive shifts in eta therefore increase the hazard and shorten survival times, and exp(beta) is the hazard ratio for a one-unit increase in x. The shape k has a Gamma(2.0, 1.0) prior with mean 2. The observed event times t are uncensored Weibull draws; right-censoring is handled at the inference layer by substituting the Weibull survival function S(t) = exp(-(t / scale)^k) for the density on censored rows.
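Both claims in this paragraph can be checked numerically. The stdlib sketch below uses the program's scale = exp(-eta / k) and confirms that the hazard ratio between x = 1 and x = 0 equals exp(beta) at every t. It also writes one standard form of a right-censored Weibull log-likelihood: log density for observed events, log survival for censored rows. The function names and the 0/1 event indicator are hypothetical; this is an illustration, not the quivers inference layer's code.

```python
import math


def weibull_log_density(t, scale, k):
    # log f(t) = log k - log scale + (k - 1) * log(t / scale) - (t / scale)**k
    z = t / scale
    return math.log(k) - math.log(scale) + (k - 1.0) * math.log(z) - z ** k


def weibull_log_survival(t, scale, k):
    # log S(t) = -(t / scale)**k
    return -((t / scale) ** k)


def ph_scale(alpha, beta, x, k):
    # Proportional-hazards reparameterisation used in the QVR program.
    return math.exp(-(alpha + beta * x) / k)


def hazard(t, scale, k):
    # h(t) = f(t) / S(t) = (k / scale) * (t / scale)**(k - 1)
    return (k / scale) * (t / scale) ** (k - 1.0)


def censored_log_lik(ts, events, xs, alpha, beta, k):
    """Right-censored log-likelihood; events[i] is 1 for an observed event, 0 if censored."""
    total = 0.0
    for t, d, x in zip(ts, events, xs):
        s = ph_scale(alpha, beta, x, k)
        total += weibull_log_density(t, s, k) if d else weibull_log_survival(t, s, k)
    return total


alpha, beta, k = 0.5, 1.0, 2.0
# The hazard ratio between x = 1 and x = 0 does not depend on t.
for t in (0.3, 1.0, 2.5):
    h1 = hazard(t, ph_scale(alpha, beta, 1.0, k), k)
    h0 = hazard(t, ph_scale(alpha, beta, 0.0, k), k)
    print(f"t={t}: hazard ratio {h1 / h0:.4f}, exp(beta) = {math.exp(beta):.4f}")

print(censored_log_lik([0.4, 1.2, 0.9], [1, 0, 1], [0.2, -0.5, 1.0], alpha, beta, k))
```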

Try it

The SVI step counts and NUTS warmup, sample, and chain budgets in the snippets below are illustrative: each block is sized to run in tens of seconds and demonstrate the API surface. Production fits typically need 10x to 100x more SVI steps, longer NUTS warmup, and multiple chains to actually converge to the data-generating parameters.

Generating synthetic data

import torch
from quivers.dsl import load

torch.manual_seed(0)
prog = load("docs/examples/source/survival_weibull.qvr")
model = prog.morphism

N = 64
true_alpha = 0.5
true_beta = 1.0
true_k = 2.0

x = torch.randn(N)
eta_true = true_alpha + true_beta * x
scale_true = torch.exp(-eta_true / true_k)
t = torch.distributions.Weibull(scale_true, true_k).sample()

observations = {"x": x, "t": t}
x_in = torch.zeros(N, 1)
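torch.distributions.Weibull(scale, concentration) has CDF F(t) = 1 - exp(-(t / scale)^k) and draws t = scale * E^(1/k) with E ~ Exponential(1), so each simulated t_i has mean scale_i * Gamma(1 + 1/k). The stdlib sketch below reproduces that sampler by inverting the CDF and checks the mean, which can help sanity-check synthetic data without torch. The helper sample_weibull is hypothetical.

```python
import math
import random


def sample_weibull(scale, k, rng):
    # Inverse CDF: solve u = 1 - exp(-(t / scale)**k) for t.
    u = rng.random()
    return scale * (-math.log(1.0 - u)) ** (1.0 / k)


rng = random.Random(0)
scale, k = 1.0, 2.0
draws = [sample_weibull(scale, k, rng) for _ in range(100_000)]
empirical_mean = sum(draws) / len(draws)
theoretical_mean = scale * math.gamma(1.0 + 1.0 / k)  # sqrt(pi) / 2 when k = 2
print(f"empirical {empirical_mean:.4f} vs theoretical {theoretical_mean:.4f}")
```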

SVI fit

from quivers.inference import AutoNormalGuide, ELBO, SVI

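# Mean per-item negative log-likelihood of t under the true parameters.
# The SVI loss is an ELBO (it also carries prior and guide terms) and may be
# summed over items, so treat this as a rough reference, not an exact target.
oracle_nll = float(
    -torch.distributions.Weibull(scale_true, true_k).log_prob(t).mean()
)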

torch.manual_seed(1)
guide = AutoNormalGuide(model, observed_names={"x", "t"})
optim = torch.optim.Adam(
    list(model.parameters()) + list(guide.parameters()), lr=5e-2,
)
svi = SVI(model, guide, optim, ELBO(num_particles=1))

losses = []
for _ in range(300):
    losses.append(svi.step(x_in, observations))

print(f"initial loss: {losses[0]:.2f}")
print(f"final loss:   {losses[-1]:.2f}")
print(f"oracle NLL:   {oracle_nll:.2f}")
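When k is known, beta can also be recovered without any inference machinery, which gives an independent check on what the SVI fit should find. Since t = scale * E^(1/k) with E ~ Exponential(1), k * log t = -alpha - beta * x + log E, and log E has mean minus the Euler-Mascheroni constant. Ordinary least squares of k * log t on x therefore estimates -beta as the slope and -(alpha + 0.5772...) as the intercept. The sketch below simulates fresh stdlib data with the same true values; it assumes k is known, which the QVR model does not.

```python
import math
import random

EULER_GAMMA = 0.5772156649015329

rng = random.Random(3)
true_alpha, true_beta, true_k = 0.5, 1.0, 2.0
n = 20_000

xs, ys = [], []
for _ in range(n):
    x = rng.gauss(0.0, 1.0)
    scale = math.exp(-(true_alpha + true_beta * x) / true_k)
    # random.weibullvariate(alpha, beta) takes (scale, shape).
    t = rng.weibullvariate(scale, true_k)
    xs.append(x)
    ys.append(true_k * math.log(t))  # equals -alpha - beta * x + log E

# Closed-form simple linear regression of y on x.
x_bar = sum(xs) / n
y_bar = sum(ys) / n
sxy = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys))
sxx = sum((x - x_bar) ** 2 for x in xs)
slope = sxy / sxx
intercept = y_bar - slope * x_bar

beta_hat = -slope
alpha_hat = -intercept - EULER_GAMMA
print(f"beta_hat = {beta_hat:.3f}, alpha_hat = {alpha_hat:.3f}")
```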

NUTS posterior

from quivers.inference import MCMC, NUTSKernel

N_mcmc = 32
x_mcmc = x[:N_mcmc]
t_mcmc = t[:N_mcmc]
obs_mcmc = {"x": x_mcmc, "t": t_mcmc}
x_in_mcmc = torch.zeros(N_mcmc, 1)

torch.manual_seed(2)
kernel = NUTSKernel(step_size=0.05, max_tree_depth=3, target_accept=0.8)
mc = MCMC(kernel, num_warmup=20, num_samples=20, num_chains=1)
result = mc.run(model, x_in_mcmc, obs_mcmc)

print(f"acceptance:  {float(result.acceptance_rates.mean()):.2f}")
print(f"divergences: {int(result.divergence_counts.sum())}")
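Acceptance rate and divergence count are the only diagnostics printed above. A common next check is split-R-hat, which splits each chain in half and compares between-half and within-half variance; values close to 1 suggest the halves agree, and it works even with num_chains=1. The stdlib sketch below takes plain lists of draws. Extracting per-chain draws of beta from result is left out because we do not want to guess attribute names on the quivers result object.

```python
import math
import random


def split_rhat(chains):
    """Split-R-hat for a list of equal-length chains (lists of floats)."""
    halves = []
    for chain in chains:
        n = len(chain) // 2
        halves.append(chain[:n])
        halves.append(chain[-n:])
    m, n = len(halves), len(halves[0])
    means = [sum(h) / n for h in halves]
    grand = sum(means) / m
    # Between-half variance B and mean within-half variance W.
    b = n * sum((mu - grand) ** 2 for mu in means) / (m - 1)
    w = sum(
        sum((v - mu) ** 2 for v in h) / (n - 1) for h, mu in zip(halves, means)
    ) / m
    var_plus = (n - 1) / n * w + b / n
    return math.sqrt(var_plus / w)


rng = random.Random(4)
mixed = [[rng.gauss(0.0, 1.0) for _ in range(200)] for _ in range(4)]
stuck = [[rng.gauss(3.0 * c, 1.0) for _ in range(200)] for c in range(4)]
print(f"well-mixed chains: {split_rhat(mixed):.3f}")
print(f"chains stuck at different values: {split_rhat(stuck):.3f}")
```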

Categorical Perspective

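The model denotes a morphism in the Kleisli category of the Giry monad: each item is sent to a probability distribution over the positive reals, the support of the event time t. The Weibull generalizes the exponential distribution (recovered at k = 1) with a shape parameter. For fixed k it is an exponential family with sufficient statistic t^k, because t^k is exponentially distributed with rate exp(eta); the proportional-hazards parameterisation is a log link on that rate. That link is not the canonical one (the natural parameter is -exp(eta), not eta), and once k is inferred jointly, as it is here, the Weibull is no longer an exponential family.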
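The exponential-family structure for fixed k can be made explicit by substituting the program's scale λ = exp(-η/k), with η = α + βx, into the Weibull density. This is standard algebra; b(t) denotes the base measure.

```latex
f(t \mid \lambda, k) = \frac{k}{\lambda}\left(\frac{t}{\lambda}\right)^{k-1}
  \exp\!\left[-\left(\frac{t}{\lambda}\right)^{k}\right],
\qquad \lambda = e^{-\eta/k} \;\Longrightarrow\; \lambda^{-k} = e^{\eta}

f(t \mid \eta, k) = k\,t^{k-1} \exp\!\left(\eta - e^{\eta}\, t^{k}\right)
  = b(t)\,\exp\!\bigl(\theta\,T(t) - A(\theta)\bigr)

b(t) = k\,t^{k-1}, \qquad T(t) = t^{k}, \qquad
\theta = -e^{\eta}, \qquad A(\theta) = -\log(-\theta)
```

Expanding θT(t) - A(θ) gives -e^η t^k + η, which matches the exponent on the second line, and the hazard f/S is k t^(k-1) e^η.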

References

  • John P. Klein and Melvin L. Moeschberger. 2003. Survival Analysis: Techniques for Censored and Truncated Data, 2nd edition. Springer.