Continuous Distribution Families

quivers ships 30+ built-in conditional distribution families. Each takes an input from its domain and maps it, through a parameter network with learnable weights, to the parameters of a distribution over its codomain. This page groups the registry by output shape and details the families whose event structure interacts with the axis-role surface. The spaces and the base ContinuousMorphism interface are documented under continuous spaces.

Family registry

Hand-written core families

from quivers.continuous.families import (
    ConditionalNormal,
    ConditionalLogitNormal,
    ConditionalBeta,
    ConditionalTruncatedNormal,
    ConditionalDirichlet,
)
from quivers.continuous.spaces import Euclidean
from quivers.core.objects import FinSet
import torch

# A finite set; it can serve as the domain in place of the Euclidean space below
X = FinSet(name="X", cardinality=5)
domain = Euclidean(name="X", dim=5)        # or pass the FinSet X instead
codomain = Euclidean(name="Y", dim=3)

# Conditional normal: learns linear maps mu(x), log_sigma(x)
normal = ConditionalNormal(domain, codomain)

# Sample from p(y | x)
x = torch.randn(8, 5)                              # batch of 8 inputs
samples = normal.rsample(x, sample_shape=torch.Size((100,)))  # (100, 8, 3)

# Log probability
y = torch.randn(8, 3)
log_p = normal.log_prob(x, y)  # shape (8,)

Loc-scale families

Standard reparameterizable distributions with learned location and scale:

from quivers.continuous.families import (
    ConditionalCauchy,
    ConditionalLaplace,
    ConditionalGumbel,
    ConditionalLogNormal,
    ConditionalStudentT,
)

Positive-valued families

For \(\mathbb{R}_{>0}\) output:

import torch
from quivers.continuous.families import (
    ConditionalExponential,
    ConditionalGamma,
    ConditionalWeibull,
    ConditionalPareto,
    ConditionalInverseGamma,
    ConditionalHalfCauchy,
    ConditionalHalfNormal,
)
from quivers.continuous.spaces import Euclidean

domain = Euclidean(name="X", dim=5)
codomain = Euclidean(name="Y", dim=3)
x = torch.randn(8, 5)

gamma = ConditionalGamma(domain, codomain)
samples = gamma.rsample(x)  # positive

Unit-interval families

For \((0, 1)\) output:

from quivers.continuous.families import (
    ConditionalBeta,
    ConditionalKumaraswamy,
    ConditionalContinuousBernoulli,
)

Multivariate families

from quivers.continuous.families import (
    ConditionalMultivariateNormal,
    ConditionalLowRankMVN,
    ConditionalDirichlet,
    ConditionalWishart,
    ConditionalInverseWishart,
    ConditionalMatrixNormal,
    ConditionalLKJCholesky,
    ConditionalGaussianProcess,
    ConditionalHorseshoe,
)
from quivers.continuous.spaces import Euclidean

domain = Euclidean(name="X", dim=5)

# Multivariate normal with learned mean and cov
mvn = ConditionalMultivariateNormal(domain, Euclidean(name="Y", dim=5))

Discrete / categorical families

from quivers.continuous.families import (
    ConditionalBernoulli,
    ConditionalCategorical,
    ConditionalRelaxedBernoulli,
    ConditionalRelaxedOneHotCategorical,
)

Event rank and the axis-role surface

event_rank per family controls the axis-role surface in the DSL:

| Family | Event rank | Categorical reading |
| --- | --- | --- |
| Normal, Beta, Gamma, Exponential, LogNormal, LogitNormal, Cauchy, Laplace, Gumbel, StudentT, Weibull, Pareto, Kumaraswamy, ContinuousBernoulli, HalfNormal, HalfCauchy, InverseGamma, TruncatedNormal, Uniform, Bernoulli | 0 | Scalar; every codomain axis is iid by default |
| MultivariateNormal, LowRankMVN, Dirichlet, OneHotCategorical, RelaxedOneHotCategorical, LogisticNormal, GP, Horseshoe | 1 | Vector; one named event axis carries the joint distribution |
| Wishart, InverseWishart, MatrixNormal, LKJCholesky | 2 | Matrix; two named event axes carry the joint distribution |

The DSL surface ~ Family over <axes> requires the axis count to match the family's event rank exactly; mismatch is a compile-time error. In particular, a flat MVN over \(\dim(A) \cdot \dim(B)\) (dense covariance, event_rank 1 with a single named axis whose dim equals the product) is categorically distinct from a MatrixNormal over (A, B) (Kronecker structure \(V \otimes U\), event_rank 2); the surface keeps the two distinguishable rather than auto-substituting.
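The rank rule can be pictured with a small stand-alone check. This is only a sketch of the idea, not the quivers compiler: the `EVENT_RANK` table copies two rows from the table above, and `check_over` is a hypothetical helper introduced purely for illustration.

```python
# Hypothetical illustration of the axis-count rule; not the quivers compiler.
EVENT_RANK = {
    "MultivariateNormal": 1,  # vector-valued: one named event axis
    "MatrixNormal": 2,        # matrix-valued: two named event axes
}


def check_over(family: str, axes: tuple) -> int:
    """Return the family's event rank, or raise if the axis count differs."""
    rank = EVENT_RANK[family]
    if len(axes) != rank:
        raise ValueError(
            f"{family} has event rank {rank} but was given {len(axes)} axes: {axes}"
        )
    return rank


dims = {"A": 3, "B": 4}
dims["AB"] = dims["A"] * dims["B"]  # one flattened axis of size dim(A) * dim(B)

# A flat MVN uses the single flattened axis (dense covariance).
flat_rank = check_over("MultivariateNormal", ("AB",))
# A MatrixNormal keeps the two axes separate (Kronecker structure).
matrix_rank = check_over("MatrixNormal", ("A", "B"))

# Handing the flattened axis to MatrixNormal is rejected, never auto-substituted.
try:
    check_over("MatrixNormal", ("AB",))
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
```

The point of the sketch is that the two models are selected by how the axes are named, so a mismatch surfaces as an error instead of a silent change of covariance structure.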

Structured priors over weight matrices

MatrixNormal: Kronecker-covariance matrix prior

from quivers.continuous.families import ConditionalMatrixNormal
from quivers.continuous.spaces import Euclidean

domain = Euclidean(name="X", dim=5)

# Matrix-valued kernel: domain -> R^(rows*cols), with samples
# reshaped to (rows, cols) and Kronecker covariance Sigma = V (x) U.
mn = ConditionalMatrixNormal(
    domain, Euclidean(name="W", dim=4 * 8), rows=4, cols=8
)

The matrix-Normal MN(M, U, V) is the natural prior for a weight matrix \(W : \mathbb{R}^d \to \mathbb{R}^k\) whose row and column correlations factor separately. When used as a latent-morphism prior in the DSL, as in morphism W : Euclidean(D) -> Euclidean(K) [role=latent, over=[dom, cod]] ~ MatrixNormal(loc, row_scale, col_scale), the first axis listed in over binds the row covariance and the second the column covariance.
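To make the Kronecker structure concrete, here is a minimal pure-Python sketch that does not use quivers. It assumes the common convention \(\operatorname{vec}(W) \sim \mathcal{N}(\operatorname{vec}(M), V \otimes U)\) with column-major \(\operatorname{vec}\), and draws \(W = M + A Z B^\top\) where \(U = A A^\top\) and \(V = B B^\top\). The helper names (`matmul`, `kron`, `sample_matrix_normal`) and the factor values are illustrative assumptions.

```python
import random


def matmul(P, Q):
    """Plain nested-list matrix product."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*Q)] for row in P]


def transpose(P):
    return [list(col) for col in zip(*P)]


def kron(V, U):
    """Kronecker product V (x) U; entry [j*r + i][l*r + k] equals V[j][l] * U[i][k]."""
    return [[v * u for v in v_row for u in u_row] for v_row in V for u_row in U]


def sample_matrix_normal(M, A, B, rng):
    """Draw W = M + A Z B^T with Z having iid standard-normal entries."""
    r, c = len(M), len(M[0])
    Z = [[rng.gauss(0.0, 1.0) for _ in range(c)] for _ in range(r)]
    noise = matmul(matmul(A, Z), transpose(B))
    return [[M[i][j] + noise[i][j] for j in range(c)] for i in range(r)]


r, c = 2, 3
A = [[1.0, 0.0], [0.5, 1.0]]                         # row factor, U = A A^T
B = [[1.0, 0.0, 0.0], [0.2, 1.0, 0.0], [0.0, 0.3, 1.0]]  # column factor, V = B B^T
U = matmul(A, transpose(A))
V = matmul(B, transpose(B))

M = [[0.0] * c for _ in range(r)]
W = sample_matrix_normal(M, A, B, random.Random(0))  # one (r, c) draw

Sigma = kron(V, U)            # implied covariance of vec(W), shape (r*c, r*c)
kron_params = r * r + c * c   # entries in U and V
dense_params = (r * c) ** 2   # entries in an unstructured covariance
```

With r = 2 and c = 3 the structured prior stores 13 covariance entries where a dense covariance stores 36, and the gap widens quickly as the matrix grows.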

InverseWishart: conjugate covariance prior

from quivers.continuous.families import ConditionalInverseWishart
from quivers.continuous.spaces import Euclidean

domain = Euclidean(name="X", dim=5)

# Conjugate prior on a d-dim covariance matrix.  Realized as the
# inversion of a Wishart sample with the correct symmetric-matrix
# change-of-variables Jacobian.
d = 4
iw = ConditionalInverseWishart(domain, Euclidean(name="Sigma", dim=d))

Conjugate prior for the covariance of a multivariate normal (Gelman, Carlin, Stern, Dunson, Vehtari, Rubin 2013, §3.6). The change-of-variables Jacobian for inverting a positive-definite symmetric matrix contributes a -(d + 1) log det(Sigma) term to the log-density.
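The origin of that term can be traced in a short derivation. Assume \(W \sim \mathrm{Wishart}_d(\nu, S)\) and \(\Sigma = W^{-1}\); the symbols \(\nu\), \(S\), and \(\Psi = S^{-1}\) are notation introduced here for illustration, not names from the library. On symmetric \(d \times d\) matrices the map \(\Sigma \mapsto \Sigma^{-1}\) has Jacobian determinant \(\lvert \det \Sigma \rvert^{-(d+1)}\), so

\[ \begin{aligned} p_\Sigma(\Sigma) &= p_W(\Sigma^{-1}) \, \lvert \det \Sigma \rvert^{-(d+1)} \\ &\propto \lvert \det \Sigma \rvert^{-(\nu - d - 1)/2} \exp\!\bigl(-\tfrac{1}{2} \operatorname{tr}(\Psi \Sigma^{-1})\bigr) \, \lvert \det \Sigma \rvert^{-(d+1)} \\ &= \lvert \det \Sigma \rvert^{-(\nu + d + 1)/2} \exp\!\bigl(-\tfrac{1}{2} \operatorname{tr}(\Psi \Sigma^{-1})\bigr). \end{aligned} \]

Taking logarithms, the Jacobian factor is exactly the \(-(d + 1) \log \det \Sigma\) term, and the last line is the usual inverse-Wishart kernel.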

LKJCholesky: correlation-matrix prior

from quivers.continuous.families import ConditionalLKJCholesky
from quivers.continuous.spaces import Euclidean

domain = Euclidean(name="X", dim=5)
d = 4

lkj = ConditionalLKJCholesky(
    domain, Euclidean(name="L", dim=d),
)

LKJ-distributed Cholesky factors on the manifold of correlation matrices. The matrix dimension comes from codomain.dim; the concentration parameter is learned from the input by the family's parameter network (concentration 1 gives the uniform distribution over correlation matrices).
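The role of the concentration can be seen from the unnormalized LKJ density over correlation matrices, which is proportional to \(\det(R)^{\eta - 1}\). The sketch below is plain Python, not the library's implementation: it evaluates that log-density from a Cholesky factor via \(\log \det R = 2 \sum_i \log L_{ii}\). It deliberately omits the normalizing constant and the extra Jacobian that a density over the factor \(L\) itself would carry; the function name is an illustrative assumption.

```python
import math


def lkj_unnormalized_log_density(L, eta):
    """(eta - 1) * log det(R) for R = L L^T, using log det(R) = 2 * sum(log L_ii)."""
    log_det_R = 2.0 * sum(math.log(L[i][i]) for i in range(len(L)))
    return (eta - 1.0) * log_det_R


# Cholesky factor of the 2 x 2 correlation matrix [[1, rho], [rho, 1]].
rho = 0.5
L_corr = [[1.0, 0.0], [rho, math.sqrt(1.0 - rho**2)]]
L_identity = [[1.0, 0.0], [0.0, 1.0]]

# eta = 1: the density does not depend on R, i.e. it is uniform.
uniform_corr = lkj_unnormalized_log_density(L_corr, eta=1.0)
uniform_identity = lkj_unnormalized_log_density(L_identity, eta=1.0)

# eta > 1: correlated matrices are penalized relative to the identity.
concentrated_corr = lkj_unnormalized_log_density(L_corr, eta=2.0)
concentrated_identity = lkj_unnormalized_log_density(L_identity, eta=2.0)
```

Concentrations above 1 pull mass toward the identity correlation matrix, while values below 1 favor stronger correlations.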

Wishart

from quivers.continuous.families import ConditionalWishart
from quivers.continuous.spaces import Euclidean

domain = Euclidean(name="X", dim=5)
d = 4

wishart = ConditionalWishart(
    domain, Euclidean(name="Sigma", dim=d),
)

Non-parametric and structured families

ConditionalGaussianProcess

A Gaussian process on \(\mathbb{R}^D\) is a Markov kernel \(X^N \to \mathcal{G}(\mathbb{R}^N)\) whose value at any finite collection of input locations \(x_1, \dots, x_N \in \mathbb{R}^D\) is a joint multivariate Normal with covariance \(K(x_i, x_j)\). Categorically, the family departs from the other parametric families: its "parameters" at each evaluation are the input locations themselves, not a small neural-net summary of them. ConditionalGaussianProcess exposes three covariance kernels (RBF, Matern 5/2, linear) with learnable length scale and amplitude.

from quivers.continuous.families import ConditionalGaussianProcess
from quivers.continuous.spaces import Euclidean
import torch

D, N = 2, 8
gp = ConditionalGaussianProcess(
    Euclidean(name="X", dim=D),
    Euclidean(name="Y", dim=N),
    kernel="rbf",
    length_scale=0.5,
    amplitude=1.0,
)
x = torch.randn(N, D)            # N input locations in R^D
f = gp.rsample(x)                # one GP draw at those inputs, shape (N,)
log_p = gp.log_prob(x, f)        # MVN(0, K(x, x) + jitter * I) density

The DSL registers the family as GP with event rank 1; one named axis (the input-location count \(N\)) carries the joint distribution. Reference: Rasmussen and Williams (2006).
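For intuition about what a GP draw and its density involve, here is a self-contained sketch in plain Python that does not call quivers. It assumes the RBF convention \(K(x, x') = a^2 \exp(-\lVert x - x' \rVert^2 / (2 \ell^2))\), a zero mean, and a jitter of 1e-6 on the diagonal; the library's exact conventions may differ, and the helper names are illustrative.

```python
import math
import random


def rbf_kernel(xs, length_scale, amplitude):
    """K[i][j] = amplitude^2 * exp(-||x_i - x_j||^2 / (2 * length_scale^2))."""
    def k(a, b):
        sq = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
        return amplitude**2 * math.exp(-0.5 * sq / length_scale**2)
    return [[k(a, b) for b in xs] for a in xs]


def cholesky(K):
    """Lower-triangular L with L L^T = K (K must be positive definite)."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = K[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def gp_log_prob(f, L):
    """Log-density of f under N(0, L L^T), via forward substitution L z = f."""
    n = len(f)
    z = []
    for i in range(n):
        z.append((f[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i])
    log_det = 2.0 * sum(math.log(L[i][i]) for i in range(n))
    return -0.5 * (sum(v * v for v in z) + log_det + n * math.log(2.0 * math.pi))


N = 6
xs = [[0.5 * i, 0.25 * i] for i in range(N)]  # N input locations in R^2
jitter = 1e-6

K = rbf_kernel(xs, length_scale=0.5, amplitude=1.0)
for i in range(N):
    K[i][i] += jitter                          # numerical stabilizer
L = cholesky(K)

rng = random.Random(0)
z = [rng.gauss(0.0, 1.0) for _ in range(N)]
f = [sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(N)]  # f = L z
log_p = gp_log_prob(f, L)
```

The input locations enter only through the kernel matrix, which is the sense in which the locations themselves play the role of parameters.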

ConditionalHorseshoe

The horseshoe prior is a sparse-signal prior whose hierarchical form is

\[ \begin{aligned} \tau &\sim \mathrm{HalfCauchy}(\text{scale}) \\ \lambda_d &\sim \mathrm{HalfCauchy}(1) \quad \text{for } d = 1, \dots, D \\ \beta_d \mid \tau, \lambda_d &\sim \mathcal{N}\!\bigl(0, (\tau \lambda_d)^2\bigr). \end{aligned} \]

The marginal density of \(\beta_d\) after integrating out the local scale \(\lambda_d\) has no closed form; the Carvalho, Polson, Scott (2010) bound is improper. ConditionalHorseshoe therefore evaluates the marginal itself, rather than a bound, numerically: it applies 16-point Gauss-Legendre quadrature after the change of variables \(\lambda = \tan(\pi t / 2)\), which carries the integration range \(\lambda \in (0, \infty)\) to \(t \in (0, 1)\) with Jacobian \(d\lambda / dt = (\pi / 2) \sec^2(\pi t / 2)\).

import torch
from quivers.continuous.families import ConditionalHorseshoe
from quivers.continuous.spaces import Euclidean
from quivers.core.objects import FinSet

hs = ConditionalHorseshoe(
    FinSet(name="A", cardinality=4),
    Euclidean(name="beta", dim=10),
    scale=0.1,
)
x = torch.tensor([0, 1, 2, 3])
beta = hs.rsample(x)             # shape (4, 10)
log_p = hs.log_prob(x, beta)     # shape (4,), summed over coordinates

The DSL registers the family as Horseshoe with event rank 1; the one named axis carries the per-coordinate marginal product.
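The quadrature behind the marginal can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: it conditions on a fixed global scale \(\tau\), and the function names are hypothetical. After substituting \(\lambda = \tan(\pi t / 2)\), the HalfCauchy(1) density \(2 / (\pi (1 + \lambda^2))\) cancels the Jacobian \((\pi / 2) \sec^2(\pi t / 2)\) exactly, leaving \(p(\beta \mid \tau) = \int_0^1 \mathcal{N}\bigl(\beta; 0, (\tau \tan(\pi t / 2))^2\bigr) \, dt\).

```python
import math


def gauss_legendre_unit(n):
    """Gauss-Legendre nodes and weights rescaled from (-1, 1) to (0, 1)."""
    rule = []
    for i in range(1, n + 1):
        x = math.cos(math.pi * (i - 0.25) / (n + 0.5))  # initial guess for a root
        for _ in range(100):
            p0, p1 = 1.0, x                             # P_0(x), P_1(x)
            for k in range(2, n + 1):                   # three-term recurrence
                p0, p1 = p1, ((2 * k - 1) * x * p1 - (k - 1) * p0) / k
            dp = n * (x * p1 - p0) / (x * x - 1.0)      # derivative P_n'(x)
            dx = p1 / dp
            x -= dx                                     # Newton step
            if abs(dx) < 1e-14:
                break
        w = 2.0 / ((1.0 - x * x) * dp * dp)
        rule.append((0.5 * (x + 1.0), 0.5 * w))
    return rule


def integrand(t, beta, tau):
    """N(beta; 0, (tau * lambda)^2) at lambda = tan(pi t / 2).

    The HalfCauchy(1) density and the Jacobian cancel, so no extra factor appears.
    """
    sd = tau * math.tan(0.5 * math.pi * t)
    return math.exp(-0.5 * (beta / sd) ** 2) / (sd * math.sqrt(2.0 * math.pi))


def horseshoe_marginal(beta, tau=1.0, n=16):
    """Approximate p(beta | tau) with an n-point rule on t in (0, 1)."""
    return sum(w * integrand(t, beta, tau) for t, w in gauss_legendre_unit(n))


rule = gauss_legendre_unit(16)
density_at_one = horseshoe_marginal(1.0)
```

Because the transformed integrand is a smooth bump on the unit interval, a fixed low-order rule is a reasonable choice for moderate values of \(\lvert \beta \rvert / \tau\).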

Choosing among multivariate priors

| Goal | Family | Why |
| --- | --- | --- |
| Generic multivariate posterior with dense covariance | ConditionalMultivariateNormal | Standard; learnable Cholesky factor. |
| Weight matrix with separable row / column correlations | ConditionalMatrixNormal | Kronecker structure \(V \otimes U\) has \(r^2 + c^2\) params instead of \((rc)^2\). |
| Conjugate covariance for a downstream MVN | ConditionalInverseWishart | Pairs with ConditionalMultivariateNormal for closed-form posterior updates. |
| Correlation matrix in a hierarchical model | ConditionalLKJCholesky | Prior on correlations alone, leaving scales free. |
| Non-parametric regression with smoothness control | ConditionalGaussianProcess | Prior over functions; bandwidth from kernel length-scale. |
| Sparse signal selection (e.g. variable selection in regression) | ConditionalHorseshoe | Strong shrinkage toward zero with heavy tails for large signals. |

References

  • Andrew Gelman, John B. Carlin, Hal S. Stern, David B. Dunson, Aki Vehtari, and Donald B. Rubin. 2013. Bayesian Data Analysis, 3rd edition. Chapman & Hall/CRC.
  • Carlos M. Carvalho, Nicholas G. Polson, and James G. Scott. 2010. The horseshoe estimator for sparse signals. Biometrika, 97(2):465–480.