Continuous Families

Families of continuous-valued distributions.

families

Parameterized distribution families as continuous morphisms.

Each family is a ContinuousMorphism whose codomain is a continuous space and whose conditional distribution p(y | x) belongs to a specific parametric family. The parameters are learnable functions of x:

  • For discrete domains (FinSet): parameters are looked up from a table.
  • For continuous domains (ContinuousSpace): parameters are produced by a small neural network.
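The two kinds of parameter source can be sketched in plain PyTorch. TableSource and MLPSource are hypothetical stand-ins for whatever `_make_source(domain, param_dim, hidden_dim)` returns. The single hidden layer and the ReLU activation are assumptions, not the library's documented architecture.

```python
import torch
import torch.nn as nn


class TableSource(nn.Module):
    """Hypothetical lookup-table parameter source for a finite domain."""

    def __init__(self, cardinality: int, param_dim: int) -> None:
        super().__init__()
        # one learnable row of parameters per element of the FinSet
        self.table = nn.Embedding(cardinality, param_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x holds integer element indices, shape (batch,)
        return self.table(x.long())


class MLPSource(nn.Module):
    """Hypothetical small-MLP parameter source for a continuous domain (assumed architecture)."""

    def __init__(self, in_dim: int, param_dim: int, hidden_dim: int = 64) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, param_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x holds feature vectors, shape (batch, in_dim)
        return self.net(x)


table = TableSource(cardinality=5, param_dim=6)
mlp = MLPSource(in_dim=3, param_dim=6)
print(table(torch.tensor([0, 1, 2])).shape)  # torch.Size([3, 6])
print(mlp(torch.randn(4, 3)).shape)          # torch.Size([4, 6])
```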

This module wraps every reparameterizable distribution in torch.distributions as a conditional morphism, plus custom families (TruncatedNormal, MultivariateNormal, etc.).

Architecture

Most per-dimension-independent distributions are built on a shared generic base _IndependentConditional that handles the parameter source, transform, and torch.distributions plumbing. The _make_family class factory generates named classes from a specification. Distributions that need special handling (MultivariateNormal, Dirichlet, TruncatedNormal, etc.) are implemented as standalone classes.

ConditionalNormal

ConditionalNormal(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional normal (Gaussian) distribution.

For each input x, produces an independent normal distribution on each dimension of the codomain:

y_i ~ Normal(mu_i(x), sigma_i(x))

Parameters are learnable: mu and log(sigma) are functions of x, implemented as lookup tables (discrete domain) or neural networks (continuous domain).
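The sketch below shows how a packed parameter vector could become the per-dimension Normal. It assumes the source emits [mu | log_sigma] in that order. The param_dim comment in `__init__` below suggests this order, but it is not confirmed.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
d = 3
# hypothetical packed output of a parameter source for 4 inputs: [mu | log_sigma]
params = torch.randn(4, 2 * d, requires_grad=True)
mu, log_sigma = params.split(d, dim=-1)
dist = D.Normal(mu, log_sigma.exp())  # exp keeps sigma strictly positive

y = dist.rsample()                 # reparameterized draw, shape (4, 3)
log_p = dist.log_prob(y).sum(-1)   # independent dims: sum per-dim log-densities, shape (4,)

# gradients reach the parameters through the sample itself
y.sum().backward()
```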

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space.

TYPE: Euclidean

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Examples:

>>> import torch
>>> from quivers import FinSet
>>> from quivers.continuous.spaces import Euclidean
>>> from quivers.continuous.families import ConditionalNormal
>>> A = FinSet(name="context", cardinality=5)
>>> Y = Euclidean(name="response", dim=3)
>>> f = ConditionalNormal(A, Y)
>>> x = torch.tensor([0, 1, 2])
>>> samples = f.rsample(x)  # shape (3, 3)
Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim

    # param_dim = d (mu) + d (log_sigma)
    self.param_source = _make_source(domain, 2 * d, hidden_dim)
    self._d = d

ConditionalLogitNormal

ConditionalLogitNormal(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional logit-normal distribution on (0, 1)^d.

If z ~ Normal(mu(x), sigma(x)), then y = sigmoid(z) ~ LogitNormal. Useful for modeling probabilities and bounded quantities.
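The logit-normal can be built from standard torch.distributions parts. This sketch is not the library's code. It checks the change-of-variables density against a hand-written formula.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
mu = torch.tensor([0.0, 1.0])
sigma = torch.tensor([1.0, 0.5])
logit_normal = D.TransformedDistribution(
    D.Normal(mu, sigma), [D.transforms.SigmoidTransform()]
)

y = logit_normal.rsample((1000,))  # values strictly inside (0, 1), shape (1000, 2)

# change of variables: log p(y) = log N(logit(y); mu, sigma) - log y - log(1 - y)
manual = (
    D.Normal(mu, sigma).log_prob(torch.logit(y)) - torch.log(y) - torch.log1p(-y)
)
```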

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space (should have bounds [0, 1]).

TYPE: Euclidean

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    self.param_source = _make_source(domain, 2 * d, hidden_dim)
    self._d = d

ConditionalBeta

ConditionalBeta(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional beta distribution on (0, 1)^d.

For each input x, produces an independent Beta(alpha_i(x), beta_i(x)) on each dimension of the codomain.

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space (should have bounds [0, 1]).

TYPE: Euclidean

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    self.param_source = _make_source(domain, 2 * d, hidden_dim)
    self._d = d

ConditionalTruncatedNormal

ConditionalTruncatedNormal(domain: AnySpace, codomain: Euclidean, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional truncated normal on [low, high]^d.

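A normal distribution restricted to a bounded interval. Sampling is rejection-free and uses the inverse-CDF method. A uniform draw is mapped into [Phi(a), Phi(b)], where a and b are the standardized bounds and Phi is the standard normal CDF, and the result is then pushed through Phi^{-1}.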
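The inverse-CDF idea can be sketched as follows. truncated_normal_rsample is a hypothetical helper and is not part of quivers. The final clamp is an assumed guard against floating-point overshoot at the bounds.

```python
import torch
from torch.special import ndtr, ndtri


def truncated_normal_rsample(mu, sigma, low, high, sample_shape=()):
    """Hypothetical inverse-CDF sampler for Normal(mu, sigma) restricted to [low, high]."""
    a_cdf = ndtr((low - mu) / sigma)    # Phi at the standardized lower bound
    b_cdf = ndtr((high - mu) / sigma)   # Phi at the standardized upper bound
    u = torch.rand(torch.Size(sample_shape) + mu.shape)
    # spread u uniformly over [Phi(a), Phi(b)], then invert Phi
    z = ndtri(a_cdf + u * (b_cdf - a_cdf))
    # assumed guard: clamp away floating-point overshoot at the bounds
    return (mu + sigma * z).clamp(min=low, max=high)


torch.manual_seed(0)
mu = torch.tensor([0.0, 0.5], requires_grad=True)
sigma = torch.tensor([1.0, 0.5])
y = truncated_normal_rsample(mu, sigma, low=-1.0, high=1.0, sample_shape=(500,))
y.mean().backward()   # no rejection loop, so gradients reach mu
```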

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space (must be bounded).

TYPE: Euclidean

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: Euclidean,
    hidden_dim: int = 64,
) -> None:
    if codomain.low is None or codomain.high is None:
        raise ValueError("ConditionalTruncatedNormal requires a bounded codomain")

    super().__init__(domain, codomain)
    d = codomain.dim
    self.param_source = _make_source(domain, 2 * d, hidden_dim)
    self._d = d
    self._low = codomain.low
    self._high = codomain.high

ConditionalDirichlet

ConditionalDirichlet(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional Dirichlet distribution on a probability simplex.

For each input x, produces a Dirichlet(alpha(x)) distribution on the simplex.

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target simplex.

TYPE: Simplex

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    self.param_source = _make_source(domain, d, hidden_dim)
    self._d = d

ConditionalUniform

ConditionalUniform(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional uniform distribution on a learnable interval.

Parameterized as Uniform(loc - width/2, loc + width/2) where loc is unconstrained and width is positive. This ensures low < high is always satisfied.
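Here is a sketch of the loc/width parameterisation. The [loc | raw_width] ordering follows the comment in `__init__` below. The softplus positivity map and the 1e-6 floor are assumptions.

```python
import torch
import torch.nn.functional as F
import torch.distributions as D

torch.manual_seed(0)
d = 2
params = torch.randn(5, 2 * d, requires_grad=True)  # hypothetical [loc | raw_width]
loc, raw_width = params.split(d, dim=-1)
width = F.softplus(raw_width) + 1e-6   # assumed positivity map, so high > low always
dist = D.Uniform(loc - width / 2, loc + width / 2)

y = dist.rsample()   # reparameterized: low + U(0, 1) * (high - low)
```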

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space.

TYPE: ContinuousSpace

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    # param_dim = d (loc) + d (raw_width)
    self.param_source = _make_source(domain, 2 * d, hidden_dim)
    self._d = d
    # The bounds are data-dependent, so we cannot pin a single
    # interval at construction time; advertise the codomain's
    # declared bounds when available, otherwise fall back to the
    # real line.
    low, high = getattr(codomain, "low", None), getattr(codomain, "high", None)
    if low is not None and high is not None:
        self._support_cache: _constraints.Constraint = _constraints.interval(
            float(low), float(high)
        )
    else:
        self._support_cache = _constraints.real

ConditionalMultivariateNormal

ConditionalMultivariateNormal(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional multivariate normal with full covariance.

Parameterized via Cholesky factor: the parameter source outputs loc (d values) and the lower-triangular entries of L (d*(d+1)/2 values), where Sigma = L @ L^T.
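The sketch below turns the flat d + d(d+1)/2 parameter vector into loc and a valid Cholesky factor. Applying softplus to the diagonal is an assumed choice, not necessarily what quivers does. Any strictly positive map would work.

```python
import torch
import torch.nn.functional as F
import torch.distributions as D

torch.manual_seed(0)
d = 3
n_tril = d * (d + 1) // 2
params = torch.randn(4, d + n_tril)   # hypothetical [loc | tril entries] for 4 inputs
loc, tril_flat = params.split([d, n_tril], dim=-1)

rows, cols = torch.tril_indices(d, d)
L = torch.zeros(4, d, d)
L[:, rows, cols] = tril_flat
# assumed: softplus on the diagonal so L is a valid Cholesky factor
diag = F.softplus(torch.diagonal(L, dim1=-2, dim2=-1)) + 1e-5
L = torch.tril(L, diagonal=-1) + torch.diag_embed(diag)

dist = D.MultivariateNormal(loc, scale_tril=L)   # Sigma = L @ L^T
y = dist.rsample()                               # shape (4, 3)
```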

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space (d-dimensional).

TYPE: ContinuousSpace

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    n_tril = d * (d + 1) // 2
    self.param_source = _make_source(domain, d + n_tril, hidden_dim)
    self._d = d
    self._n_tril = n_tril

ConditionalLowRankMVN

ConditionalLowRankMVN(domain: AnySpace, codomain: ContinuousSpace, rank: int = 2, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional low-rank multivariate normal.

Parameterized as loc plus a low-rank factor plus a diagonal: Sigma = W @ W^T + diag(D), where W is a d x rank matrix and D is a positive d-vector.

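This is more parameter-efficient than the full MVN in high dimensions. It needs d(rank + 2) outputs per input instead of d + d(d+1)/2. For d = 100 and rank = 2, that is 400 versus 5150.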
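The same packing for the low-rank family can be sketched with torch.distributions.LowRankMultivariateNormal. The [loc | factor | diag] ordering follows the comment in `__init__` below. The softplus on the diagonal is an assumption.

```python
import torch
import torch.nn.functional as F
import torch.distributions as D

torch.manual_seed(0)
d, rank, batch = 5, 2, 3
params = torch.randn(batch, d + d * rank + d)   # hypothetical [loc | factor | diag]
loc, factor_flat, raw_diag = params.split([d, d * rank, d], dim=-1)
W = factor_flat.reshape(batch, d, rank)
cov_diag = F.softplus(raw_diag) + 1e-5          # assumed positivity map

dist = D.LowRankMultivariateNormal(loc, cov_factor=W, cov_diag=cov_diag)
y = dist.rsample()                              # shape (3, 5)
```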

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space (d-dimensional).

TYPE: ContinuousSpace

rank

Rank of the low-rank factor W.

TYPE: int DEFAULT: 2

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    rank: int = 2,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    self._d = d
    self._rank = rank

    # loc (d) + factor (d * rank) + diag (d)
    total = d + d * rank + d
    self.param_source = _make_source(domain, total, hidden_dim)

ConditionalRelaxedBernoulli

ConditionalRelaxedBernoulli(domain: AnySpace, codomain: ContinuousSpace, temperature: float = 0.5, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional relaxed Bernoulli (concrete) distribution.

Outputs continuous values in (0, 1) that approximate Bernoulli samples. The temperature controls the relaxation: lower temperatures push samples toward 0 or 1, so the outputs are closer to discrete.
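The temperature effect is easy to see with torch.distributions.RelaxedBernoulli. The helper mean_distance_to_binary is hypothetical. It only measures how far samples sit from the nearest of {0, 1}.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
logits = torch.zeros(10_000)   # p = 0.5 for every component


def mean_distance_to_binary(temperature: float) -> float:
    dist = D.RelaxedBernoulli(torch.tensor(temperature), logits=logits)
    y = dist.rsample()
    # distance of each relaxed sample from the nearest of {0, 1}
    return torch.minimum(y, 1 - y).mean().item()


cold = mean_distance_to_binary(0.1)   # near-binary samples
hot = mean_distance_to_binary(5.0)    # samples bunched around 0.5
```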

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space; each of its dim dimensions is one relaxed Bernoulli component.

TYPE: ContinuousSpace

temperature

Relaxation temperature.

TYPE: float DEFAULT: 0.5

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    temperature: float = 0.5,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    self.param_source = _make_source(domain, d, hidden_dim)
    self._d = d
    self._temperature = temperature

ConditionalRelaxedOneHotCategorical

ConditionalRelaxedOneHotCategorical(domain: AnySpace, codomain: ContinuousSpace, temperature: float = 0.5, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional relaxed one-hot categorical (Gumbel-Softmax).

Outputs continuous vectors on the simplex that approximate one-hot categorical samples.
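The Gumbel-Softmax construction behind this family is shown below in two ways: through torch.distributions.RelaxedOneHotCategorical, and by hand as the softmax of Gumbel-perturbed logits divided by the temperature. This is an illustration, not the quivers implementation.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
temperature = torch.tensor(0.5)

# library version: a reparameterized point on the simplex
y = D.RelaxedOneHotCategorical(temperature, logits=logits).rsample()

# the same construction by hand: softmax((logits + Gumbel noise) / temperature)
gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
y_manual = torch.softmax((logits + gumbel) / temperature, dim=-1)

y[0, 0].backward()   # the relaxed sample is differentiable in the logits
```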

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space (simplex or d-dimensional).

TYPE: ContinuousSpace

temperature

Relaxation temperature.

TYPE: float DEFAULT: 0.5

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    temperature: float = 0.5,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    self.param_source = _make_source(domain, d, hidden_dim)
    self._d = d
    self._temperature = temperature

ConditionalWishart

ConditionalWishart(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional Wishart distribution over positive-definite matrices.

Produces random d x d positive-definite matrices. Parameterized by degrees of freedom df(x) and a scale matrix V(x).

The codomain dimension is interpreted as d, and outputs are d x d matrices flattened to d^2.
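Below is a sketch of how the 1 + d(d+1)/2 outputs could parameterise torch.distributions.Wishart. PyTorch only requires df > d - 1. Shifting df to at least d + 1 and the softplus maps are assumptions, made to keep the example numerically well-behaved.

```python
import torch
import torch.nn.functional as F
import torch.distributions as D

torch.manual_seed(0)
d, batch = 3, 2
n_tril = d * (d + 1) // 2
params = torch.randn(batch, 1 + n_tril)   # hypothetical [raw_df | tril entries]
raw_df, tril_flat = params.split([1, n_tril], dim=-1)

# PyTorch requires df > d - 1; the margin up to d + 1 is an assumed choice
df = (d + 1) + F.softplus(raw_df.squeeze(-1))

rows, cols = torch.tril_indices(d, d)
L = torch.zeros(batch, d, d)
L[:, rows, cols] = tril_flat
diag = F.softplus(torch.diagonal(L, dim1=-2, dim2=-1)) + 1e-5
L = torch.tril(L, diagonal=-1) + torch.diag_embed(diag)

S = D.Wishart(df, scale_tril=L).sample()  # (batch, d, d) positive-definite draws
flat = S.reshape(batch, d * d)            # flattened to d^2 per input
```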

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space. dim is the matrix size d (output is d x d).

TYPE: ContinuousSpace

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    n_tril = d * (d + 1) // 2
    # df (1) + lower-triangular scale (n_tril)
    self.param_source = _make_source(domain, 1 + n_tril, hidden_dim)
    self._d = d
    self._n_tril = n_tril

ConditionalMatrixNormal

ConditionalMatrixNormal(domain: AnySpace, codomain: ContinuousSpace, rows: int, cols: int, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional Matrix-Normal $\mathcal{MN}(M, U, V)$.

The matrix-normal distribution on $\mathbb{R}^{n \times p}$ factorises with a Kronecker-product covariance. If $X \sim \mathcal{MN}(M, U, V)$, then $\mathrm{vec}(X) \sim \mathcal{N}(\mathrm{vec}(M), V \otimes U)$, where $U \in \mathbb{R}^{n \times n}$ is the row covariance and $V \in \mathbb{R}^{p \times p}$ is the column covariance.

Categorically, the Kronecker structure is the tensor product of two Gaussians. It is strictly more constrained than the flat $np$-dimensional MVN that the same parameter tensor could carry, so the surface keeps the two families distinct (no auto-substitution). Use this when the prior should express independent row and column correlation structure.

The codomain's product factorisation supplies the row and column dimensions. The grammar surface ~ MatrixNormal(loc, row_scale, col_scale) over (dom, cod) binds the first axis listed in over to the row covariance and the second to the column covariance.
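The Kronecker identity can be checked numerically with the standard construction X = M + L_U Z L_V^T. Here L_U and L_V are Cholesky factors of U and V, and Z has i.i.d. standard-normal entries. The sketch assumes column-stacking vec, as in the formula above. It is not taken from the quivers source.

```python
import torch

torch.manual_seed(0)
n, p = 3, 2


def vec(A: torch.Tensor) -> torch.Tensor:
    # column-stacking vectorisation, the convention behind V kron U
    return A.transpose(-1, -2).reshape(-1)


U = torch.eye(n) + 0.5 * torch.ones(n, n)      # row covariance (n x n)
V = torch.tensor([[2.0, 0.3], [0.3, 1.0]])     # column covariance (p x p)
L_U, L_V = torch.linalg.cholesky(U), torch.linalg.cholesky(V)

M = torch.randn(n, p)
Z = torch.randn(n, p)
X = M + L_U @ Z @ L_V.T                        # one matrix-normal draw

# vec(L_U Z L_V^T) = (L_V kron L_U) vec(Z), and that factor squares to V kron U
factor = torch.kron(L_V, L_U)
```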

PARAMETER DESCRIPTION
domain

Source space.

TYPE: AnySpace

codomain

Target space whose factorisation supplies (n, p). Must carry a product structure of two factors; the first is the row axis, the second the column.

TYPE: ContinuousSpace

rows

Row dimension $n$.

TYPE: int

cols

Column dimension $p$.

TYPE: int

hidden_dim

Hidden layer width for the parameter network.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    rows: int,
    cols: int,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    self._rows = int(rows)
    self._cols = int(cols)
    n_loc = self._rows * self._cols
    n_row_tril = self._rows * (self._rows + 1) // 2
    n_col_tril = self._cols * (self._cols + 1) // 2
    self._n_loc = n_loc
    self._n_row_tril = n_row_tril
    self._n_col_tril = n_col_tril
    self.param_source = _make_source(
        domain, n_loc + n_row_tril + n_col_tril, hidden_dim
    )

ConditionalInverseWishart

ConditionalInverseWishart(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional Inverse-Wishart over positive-definite matrices.

Conjugate prior for the covariance of a multivariate normal. It is realised as a deterministic inversion of a Wishart sample, so autograd flows through it. This is equivalent in distribution to drawing $\Sigma^{-1} \sim \mathcal{W}(\nu, V^{-1})$ and inverting. See Gelman et al. (2013) §3.6 for the conjugacy statement.
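The invert-a-Wishart-draw construction is sketched below with torch.distributions.Wishart. Here nu and V are fixed illustrative values, not outputs of a parameter source. Passing V as `precision_matrix` gives the Wishart scale V^{-1}.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
d = 3
nu = torch.tensor(6.0)
V = torch.tensor([[2.0, 0.5, 0.0], [0.5, 1.0, 0.2], [0.0, 0.2, 1.5]])  # scale matrix

# draw Sigma^{-1} ~ Wishart(nu, V^{-1}); precision_matrix=V sets the scale to V^{-1}
precision = D.Wishart(nu, precision_matrix=V).rsample()
# invert via the Cholesky factor, keeping the operation differentiable
sigma = torch.cholesky_inverse(torch.linalg.cholesky(precision))
```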

PARAMETER DESCRIPTION
domain

Source space.

TYPE: AnySpace

codomain

Target space whose dim is the matrix size $d$.

TYPE: ContinuousSpace

hidden_dim

Hidden layer width for the parameter network.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    n_tril = d * (d + 1) // 2
    self._d = d
    self._n_tril = n_tril
    self.param_source = _make_source(domain, 1 + n_tril, hidden_dim)

ConditionalBernoulli

ConditionalBernoulli(domain: AnySpace, codomain: AnySpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional Bernoulli: continuous probability -> discrete truth value.

Takes a continuous input x and produces learnable logits that parameterize a Bernoulli distribution. The output is a discrete sample in {0, 1}, returned as a LongTensor.

This is the key bridge used in PDS (Grove & White) for the Bern x pattern, where a LogitNormal draw x in (0,1) parameterizes a Bernoulli over truth values.

The codomain must be a FinSet of size 2 (representing {False, True} or {0, 1}).

Note

Sampling from Bernoulli is NOT reparameterizable. Gradients do not flow through the discrete samples. Use score function estimators (REINFORCE) or the Gumbel-Softmax trick if differentiable samples are needed.
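Here is a minimal REINFORCE sketch that makes the score-function alternative concrete. It uses a single Bernoulli with reward f(y) = y, where the exact gradient p(1 - p) is known. It is independent of quivers.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
logit = torch.tensor(0.3, requires_grad=True)
dist = D.Bernoulli(logits=logit)

# reward f(y) = y, so E[f] = sigmoid(logit) and dE/dlogit = p * (1 - p)
y = dist.sample((20_000,))                  # no gradient flows through y itself
surrogate = (y * dist.log_prob(y)).mean()   # score-function (REINFORCE) surrogate
surrogate.backward()                        # logit.grad now estimates dE/dlogit

p = torch.sigmoid(logit).item()
exact = p * (1 - p)
```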

PARAMETER DESCRIPTION
domain

Source space (typically UnitInterval or a FinSet).

TYPE: SetObject or ContinuousSpace

codomain

Target FinSet of size 2.

TYPE: SetObject

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: AnySpace,
    hidden_dim: int = 64,
) -> None:
    from quivers.core.objects import SetObject

    if not isinstance(codomain, SetObject) or codomain.size != 2:
        raise ValueError(
            f"ConditionalBernoulli requires a FinSet(2) codomain, got {codomain!r}"
        )

    super().__init__(domain, codomain)

    # one logit per input
    self.param_source = _make_source(domain, 1, hidden_dim)

log_prob

log_prob(x: Tensor, y: Tensor) -> Tensor

Log-probability of discrete output y given input x.

PARAMETER DESCRIPTION
x

Input tensor.

TYPE: Tensor

y

Discrete output in {0, 1}. Shape (batch,).

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Log-probabilities. Shape (batch,).

Source code in src/quivers/continuous/families.py
def log_prob(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Log-probability of discrete output y given input x.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    y : torch.Tensor
        Discrete output in {0, 1}. Shape (batch,).

    Returns
    -------
    torch.Tensor
        Log-probabilities. Shape (batch,).
    """
    probs = self._get_probs(x)
    dist = D.Bernoulli(probs=probs)
    return dist.log_prob(y.float())

rsample

rsample(x: Tensor, sample_shape: Size = Size()) -> Tensor

Sample from Bernoulli (not reparameterizable).

PARAMETER DESCRIPTION
x

Input tensor.

TYPE: Tensor

sample_shape

Additional leading sample dimensions.

TYPE: Size DEFAULT: Size()

RETURNS DESCRIPTION
Tensor

Discrete samples in {0, 1}. Shape (*sample_shape, batch).

Source code in src/quivers/continuous/families.py
def rsample(
    self,
    x: torch.Tensor,
    sample_shape: torch.Size = torch.Size(),
) -> torch.Tensor:
    """Sample from Bernoulli (not reparameterizable).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    sample_shape : torch.Size
        Additional leading sample dimensions.

    Returns
    -------
    torch.Tensor
        Discrete samples in {0, 1}. Shape (*sample_shape, batch).
    """
    probs = self._get_probs(x)
    dist = D.Bernoulli(probs=probs)
    return dist.sample(sample_shape).long()

ConditionalCategorical

ConditionalCategorical(domain: AnySpace, codomain: AnySpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional Categorical: continuous input -> discrete category.

Generalizes ConditionalBernoulli to k > 2 categories. Takes a continuous input and produces learnable logits over k categories. The output is a discrete sample in {0, ..., k-1}.

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target FinSet of size k.

TYPE: SetObject

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: AnySpace,
    hidden_dim: int = 64,
) -> None:
    from quivers.core.objects import SetObject

    if not isinstance(codomain, SetObject):
        raise ValueError(
            f"ConditionalCategorical requires a FinSet codomain, got {codomain!r}"
        )

    super().__init__(domain, codomain)
    self._k = codomain.size
    self.param_source = _make_source(domain, self._k, hidden_dim)

support (property)

support

Discrete-integer support over {0, …, k-1}.

log_prob

log_prob(x: Tensor, y: Tensor) -> Tensor

Log-probability of discrete output y given input x.

PARAMETER DESCRIPTION
x

Input tensor.

TYPE: Tensor

y

Discrete output in {0, ..., k-1}. Shape (batch,).

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Log-probabilities. Shape (batch,).

Source code in src/quivers/continuous/families.py
def log_prob(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Log-probability of discrete output y given input x.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    y : torch.Tensor
        Discrete output in {0, ..., k-1}. Shape (batch,).

    Returns
    -------
    torch.Tensor
        Log-probabilities. Shape (batch,).
    """
    logits = self._get_logits(x)
    dist = D.Categorical(logits=logits)
    return dist.log_prob(y.long())

rsample

rsample(x: Tensor, sample_shape: Size = Size()) -> Tensor

Sample from Categorical (not reparameterizable).

PARAMETER DESCRIPTION
x

Input tensor.

TYPE: Tensor

sample_shape

Additional leading sample dimensions.

TYPE: Size DEFAULT: Size()

RETURNS DESCRIPTION
Tensor

Discrete samples in {0, ..., k-1}. Shape (*sample_shape, batch).

Source code in src/quivers/continuous/families.py
def rsample(
    self,
    x: torch.Tensor,
    sample_shape: torch.Size = torch.Size(),
) -> torch.Tensor:
    """Sample from Categorical (not reparameterizable).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    sample_shape : torch.Size
        Additional leading sample dimensions.

    Returns
    -------
    torch.Tensor
        Discrete samples in {0, ..., k-1}. Shape (*sample_shape, batch).
    """
    logits = self._get_logits(x)
    dist = D.Categorical(logits=logits)
    return dist.sample(sample_shape).long()

ConditionalBinomial

ConditionalBinomial(domain: AnySpace, codomain: ContinuousSpace, total_count: int = 1, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional Binomial(total_count, probs(x)).

The total_count (number of trials) is a fixed hyperparameter set at construction time, as is typical for binomial likelihoods where n is known per observation. Only the probs parameter is learnable.

Outputs integer counts in {0, 1, ..., total_count}.
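The sketch below shows the corresponding torch.distributions.Binomial with a fixed total_count and per-dimension logits. Parameterising by logits rather than probs is an assumption.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
total_count = 10
logits = torch.tensor([[-1.0, 0.0, 2.0]])   # hypothetical per-dimension logits for one input
dist = D.Binomial(total_count=total_count, logits=logits)

counts = dist.sample((1000,))               # integer-valued floats in {0, ..., 10}, shape (1000, 1, 3)
log_p = dist.log_prob(torch.tensor([[3.0, 5.0, 9.0]]))   # shape (1, 3)
```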

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space.

TYPE: ContinuousSpace

total_count

Number of Bernoulli trials per observation.

TYPE: int DEFAULT: 1

hidden_dim

Hidden layer width for the parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    total_count: int = 1,
    hidden_dim: int = 64,
) -> None:
    if total_count < 1:
        raise ValueError(
            f"ConditionalBinomial: total_count must be >= 1, got {total_count}"
        )
    super().__init__(domain, codomain)
    d = codomain.dim
    self._d = d
    self._total_count = int(total_count)
    self.param_source = _make_source(domain, d, hidden_dim)

ConditionalLogisticNormal

ConditionalLogisticNormal(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional LogisticNormal on the simplex.

Pushes a Normal(loc(x), scale(x)) draw through the stick-breaking transform, as torch.distributions.LogisticNormal does, to produce a simplex-valued sample. Multivariate analogue of ConditionalLogitNormal. Useful as an alternative to ConditionalDirichlet when the simplex distribution should be Gaussian in logit space rather than Dirichlet-shaped.
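The dimension bookkeeping in `__init__` below uses a (d - 1)-dimensional base for a d-dimensional simplex. This matches torch.distributions.LogisticNormal, sketched here with fixed loc and scale in place of the parameter-source outputs.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
d = 4                                  # simplex dimension: samples have d entries
loc = torch.zeros(2, d - 1)            # base Normal lives in d - 1 dimensions
scale = torch.ones(2, d - 1)
dist = D.LogisticNormal(loc, scale)

y = dist.rsample()                     # shape (2, 4), each row on the simplex
log_p = dist.log_prob(y)               # one log-density per row, shape (2,)
```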

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space; codomain.dim is the simplex dimension.

TYPE: ContinuousSpace

hidden_dim

Hidden layer width for the parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    # We use a Normal in (d-1)-dim space and the
    # StickBreakingTransform to land on the d-simplex.
    # torch.distributions.LogisticNormal handles this.
    self.param_source = _make_source(domain, 2 * (d - 1), hidden_dim)
    self._d = d

ConditionalOneHotCategorical

ConditionalOneHotCategorical(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional OneHotCategorical(probs(x)).

Generalises ConditionalCategorical to one-hot encoded outputs (vector of zeros with a single one). Useful as a discrete-output observation kernel where downstream code wants a vector rather than an integer index.

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space; codomain.dim is the number of categories.

TYPE: ContinuousSpace

hidden_dim

Hidden layer width for the parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    self._d = d
    self.param_source = _make_source(domain, d, hidden_dim)

ConditionalMixture

ConditionalMixture(domain: AnySpace, codomain: ContinuousSpace, component_class: type, num_components: int = 4, hidden_dim: int = 64)

Bases: ContinuousMorphism

K-component mixture of a conditional family.

Wraps a single conditional family class (one of the registered ConditionalX types) and gives it K independent parameterizations plus learnable mixture logits, producing

p(y | x) = sum_k pi_k(x) * f_k(y | x)

where each f_k is an instance of the component class and pi(x) is the softmax of K logits produced by a learnable parameter source on x.

Sampling is via ancestral simulation (Categorical pick + the chosen component's rsample). log_prob evaluates the log-sum-exp of the per-component log-densities. The Categorical pick is non-reparameterizable; gradient flow through the weights uses the score-function path (which higher-level objectives like IWAE can route through).
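The log-sum-exp density can be sketched with a Normal component family and cross-checked against torch.distributions.MixtureSameFamily. The fixed logits, locs, and scales stand in for the outputs of the parameter sources.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
K, d = 3, 2
mix_logits = torch.randn(K)                      # hypothetical mixture logits for one input
locs, scales = torch.randn(K, d), torch.rand(K, d) + 0.5
components = D.Independent(D.Normal(locs, scales), 1)   # batch (K,), event (d,)
y = torch.randn(5, d)

# per-component log-densities log f_k(y | x), shape (5, K)
comp_log_p = components.log_prob(y.unsqueeze(-2))
# log p(y | x) = logsumexp_k [log pi_k(x) + log f_k(y | x)]
log_p = torch.logsumexp(torch.log_softmax(mix_logits, dim=-1) + comp_log_p, dim=-1)

reference = D.MixtureSameFamily(D.Categorical(logits=mix_logits), components).log_prob(y)
```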

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space; matches the component family's codomain.

TYPE: ContinuousSpace

component_class

A ConditionalX class accepting (domain, codomain, hidden_dim) constructor args.

TYPE: type

num_components

Number of mixture components.

TYPE: int DEFAULT: 4

hidden_dim

Hidden width for both the mixture-logit MLP and each component's parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    component_class: type,
    num_components: int = 4,
    hidden_dim: int = 64,
) -> None:
    if num_components < 2:
        raise ValueError(
            f"ConditionalMixture: num_components must be >= 2, got {num_components}"
        )
    super().__init__(domain, codomain)
    self._K = int(num_components)
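    # pass hidden_dim by keyword: some families (e.g. ConditionalLowRankMVN,
    # ConditionalRelaxedBernoulli, ConditionalBinomial) take a different
    # third positional argument (rank, temperature, total_count)
    self._components = torch.nn.ModuleList(
        [
            component_class(domain, codomain, hidden_dim=hidden_dim)
            for _ in range(self._K)
        ]
    )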
    self.mixture_logits = _make_source(domain, self._K, hidden_dim)

ConditionalIndependent

ConditionalIndependent(base: ContinuousMorphism)

Bases: ContinuousMorphism

Reinterpret the trailing batch dimension of a base conditional family as an event dimension.

Equivalent to wrapping the base distribution in torch.distributions.Independent with reinterpreted_batch_ndims = 1. Used to make per-element independence explicit when a downstream guide wants to score a vector-valued observation as a single event.
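In plain torch.distributions terms, the reinterpretation looks like the sketch below. The check confirms that the event log-probability is the sum of the per-element ones.

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
base = D.Normal(torch.zeros(4, 3), torch.ones(4, 3))       # batch_shape (4, 3)
joint = D.Independent(base, reinterpreted_batch_ndims=1)   # batch (4,), event (3,)

y = joint.rsample()   # shape (4, 3)
```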

PARAMETER DESCRIPTION
base

The base conditional family. The wrapped distribution sums the base's per-element log-probabilities along the last axis to score a vector-valued observation.

TYPE: ContinuousMorphism

Source code in src/quivers/continuous/families.py
def __init__(self, base: ContinuousMorphism) -> None:
    super().__init__(base.domain, base.codomain)
    self._base = base

ConditionalTransformed

ConditionalTransformed(base: ContinuousMorphism, transforms: list)

Bases: ContinuousMorphism

A base conditional family composed with a chain of bijectors.

Equivalent to torch.distributions.TransformedDistribution lifted to ContinuousMorphism. The transforms are applied in forward order to rsample outputs; log_prob includes the log-determinant Jacobian correction.
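Here is a small sketch of forward-order composition with torch.distributions.TransformedDistribution. An affine map followed by exp turns a standard Normal into LogNormal(1, 2), so its Jacobian-corrected log_prob can be compared with the closed form.

```python
import torch
import torch.distributions as D
from torch.distributions.transforms import AffineTransform, ExpTransform

torch.manual_seed(0)
base = D.Normal(torch.tensor(0.0), torch.tensor(1.0))
# forward order: first 1 + 2 * z, then exp(.)
chained = D.TransformedDistribution(base, [AffineTransform(1.0, 2.0), ExpTransform()])

y = chained.rsample((5,))
# the chain is exactly LogNormal(1, 2), so the log-det-Jacobian-corrected log_prob must agree
reference = D.LogNormal(torch.tensor(1.0), torch.tensor(2.0)).log_prob(y)
```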

PARAMETER DESCRIPTION
base

Base conditional family.

TYPE: ContinuousMorphism

transforms

Bijectors applied in forward order. Each must implement __call__, inv, and log_abs_det_jacobian.

TYPE: list of torch.distributions.Transform

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    base: ContinuousMorphism,
    transforms: list,
) -> None:
    super().__init__(base.domain, base.codomain)
    self._base = base
    self._transforms = list(transforms)

ConditionalLKJCholesky

ConditionalLKJCholesky(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional LKJCholesky(dim, concentration(x)).

Produces lower-triangular Cholesky factors of correlation matrices on the LKJ distribution (Lewandowski-Kurowicka-Joe 2009, doi:10.1016/j.jmva.2009.04.008). The matrix dimension is taken from codomain.dim; only the concentration parameter is learnable.
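The sketch below uses torch.distributions.LKJCholesky with a batch of concentrations. The softplus map from raw outputs to positive concentrations is an assumption.

```python
import torch
import torch.nn.functional as F
import torch.distributions as D

torch.manual_seed(0)
raw = torch.tensor([0.2, -1.0])                # hypothetical raw outputs, one per input
concentration = F.softplus(raw) + 1e-3         # assumed positivity map
dist = D.LKJCholesky(3, concentration=concentration)

L = dist.sample()                    # shape (2, 3, 3), lower-triangular factors
corr = L @ L.transpose(-1, -2)       # correlation matrices with unit diagonal
```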

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space; codomain.dim is the correlation-matrix size.

TYPE: ContinuousSpace

hidden_dim

Hidden layer width for the parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    self._matrix_dim = codomain.dim
    self.param_source = _make_source(domain, 1, hidden_dim)

ConditionalGaussianProcess

ConditionalGaussianProcess(domain: AnySpace, codomain: ContinuousSpace, kernel: str = 'rbf', length_scale: float = 1.0, amplitude: float = 1.0, jitter: float = 1e-06)

Bases: ContinuousMorphism

Gaussian process prior with mean zero and a chosen covariance kernel.

A Gaussian process is a Markov kernel X^N -> G(R^N) whose value at a finite set of input locations x_1, ..., x_N follows a multivariate Normal with covariance matrix K(x_i, x_j). Unlike the parametric families that derive their distribution parameters from a neural network on the input, the GP's "parameters" are the input locations themselves: the kernel function evaluated on the inputs produces the covariance directly.

Reference: Rasmussen & Williams (2006), Gaussian Processes for Machine Learning.
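The pipeline described above (evaluate the kernel on the N input locations, add jitter to the diagonal, factorise, map standard normals through the factor) can be sketched in plain Python. This is an illustration under assumptions, not the library's implementation: the RBF form amplitude² · exp(-‖x - x′‖² / (2 · length_scale²)) is the textbook convention, and rbf, gp_covariance, cholesky and sample_gp are hypothetical names.

```python
import math
import random


def rbf(x1, x2, length_scale=1.0, amplitude=1.0):
    """Squared-exponential kernel (assumed form) between two D-vectors."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x1, x2))
    return amplitude ** 2 * math.exp(-0.5 * sq_dist / length_scale ** 2)


def gp_covariance(xs, jitter=1e-6):
    """N x N matrix K(x_i, x_j) with jitter added to the diagonal."""
    n = len(xs)
    return [[rbf(xs[i], xs[j]) + (jitter if i == j else 0.0) for j in range(n)]
            for i in range(n)]


def cholesky(K):
    """Lower-triangular L with L @ L^T = K (K symmetric positive definite)."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = K[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][i] = math.sqrt(s)   # fails if the pivot is negative
            else:
                L[i][j] = s / L[j][j]
    return L


def sample_gp(L, rng):
    """Map z ~ Normal(0, I) to f = L z, so that f ~ MVN(0, L L^T)."""
    z = [rng.gauss(0.0, 1.0) for _ in range(len(L))]
    return [sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(len(L))]


# Two coincident locations make the raw kernel matrix singular.
xs = [[0.0], [0.5], [0.5], [2.0]]
L = cholesky(gp_covariance(xs))
f = sample_gp(L, random.Random(0))
```

With jitter=0.0 the third pivot in this example is zero up to round-off, and a slightly negative value makes math.sqrt raise; the small diagonal term is what keeps the factorisation well defined.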

PARAMETER DESCRIPTION
domain

Source space. Its dim is the per-location feature dimensionality D of the inputs.

TYPE: SetObject or ContinuousSpace

codomain

Target space. Its dim is the number of input locations N at which the GP is evaluated.

TYPE: ContinuousSpace

kernel

Covariance kernel. "rbf" is the squared-exponential kernel; "matern52" is the Matérn kernel with smoothness nu = 5/2; "linear" is the inner-product kernel.

TYPE: ('rbf', 'matern52', 'linear') DEFAULT: "rbf"

length_scale

Initial length scale of the kernel (positive; learnable). Ignored by the linear kernel.

TYPE: float DEFAULT: 1.0

amplitude

Initial amplitude (positive; learnable). Multiplies the kernel by amplitude^2.

TYPE: float DEFAULT: 1.0

jitter

Jitter added to the diagonal of K so that it stays numerically positive definite for the Cholesky factorisation.

TYPE: float DEFAULT: 1e-06
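For reference, the three kernel choices in their textbook forms (Rasmussen & Williams, ch. 4), including the amplitude² factor described above. The exact parameterisation inside quivers is an assumption here: the 1/2 in the RBF exponent and the √5 scaling in the Matérn 5/2 kernel follow the common convention, and gp_kernel is a hypothetical name.

```python
import math


def gp_kernel(x1, x2, kernel="rbf", length_scale=1.0, amplitude=1.0):
    """Covariance k(x1, x2) between two D-dimensional inputs (assumed forms)."""
    if kernel == "linear":
        # Inner-product kernel; length_scale is ignored.
        return amplitude ** 2 * sum(a * b for a, b in zip(x1, x2))
    r = math.sqrt(sum((a - b) ** 2 for a, b in zip(x1, x2)))
    if kernel == "rbf":
        # Squared exponential: exp(-r^2 / (2 * length_scale^2)).
        return amplitude ** 2 * math.exp(-0.5 * (r / length_scale) ** 2)
    if kernel == "matern52":
        # Matern nu = 5/2: (1 + s + s^2 / 3) * exp(-s), s = sqrt(5) * r / length_scale.
        s = math.sqrt(5.0) * r / length_scale
        return amplitude ** 2 * (1.0 + s + s * s / 3.0) * math.exp(-s)
    raise ValueError(f"unknown kernel {kernel!r}")


print(gp_kernel([0.0], [1.0], "rbf"))   # exp(-0.5)
```

Both stationary kernels equal amplitude² at zero distance and decay with r; the linear kernel grows with the inputs themselves and is not stationary.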

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    kernel: str = "rbf",
    length_scale: float = 1.0,
    amplitude: float = 1.0,
    jitter: float = 1e-6,
) -> None:
    super().__init__(domain, codomain)
    if kernel not in _GP_KERNEL_CHOICES:
        raise ValueError(
            f"ConditionalGaussianProcess: unknown kernel {kernel!r}; "
            f"valid choices: {_GP_KERNEL_CHOICES}"
        )
    if length_scale <= 0.0:
        raise ValueError(
            f"ConditionalGaussianProcess: length_scale must be > 0, got {length_scale!r}"
        )
    if amplitude <= 0.0:
        raise ValueError(
            f"ConditionalGaussianProcess: amplitude must be > 0, got {amplitude!r}"
        )
    self._kernel = kernel
    self._jitter = jitter
    self._n = codomain.dim
    self._d = getattr(domain, "dim", None)
    # Store raw (pre-softplus) parameters so the transformed value
    # is strictly positive and the optimiser sees unconstrained
    # variables.
    inv_softplus_ls = math.log(math.expm1(length_scale))
    inv_softplus_amp = math.log(math.expm1(amplitude))
    self._raw_length_scale = torch.nn.Parameter(
        torch.tensor(inv_softplus_ls, dtype=torch.get_default_dtype())
    )
    self._raw_amplitude = torch.nn.Parameter(
        torch.tensor(inv_softplus_amp, dtype=torch.get_default_dtype())
    )

length_scale property

length_scale: Tensor

Current (positive) length scale of the kernel.

amplitude property

amplitude: Tensor

Current (positive) amplitude of the kernel.
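The constructor stores inv_softplus(value) = log(expm1(value)) as the raw parameter, so applying softplus recovers the initial value, and any real raw value the optimiser reaches still maps to a strictly positive one. The round trip is checked below in plain Python. That the length_scale and amplitude properties return softplus of the raw parameter is an assumption consistent with the comment in __init__.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x)); always > 0."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def inv_softplus(y: float) -> float:
    """Inverse used in __init__: log(exp(y) - 1), defined for y > 0."""
    return math.log(math.expm1(y))


raw = inv_softplus(1.0)   # what __init__ stores for length_scale=1.0
value = softplus(raw)     # recovers 1.0 up to round-off
```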

ConditionalHorseshoe

ConditionalHorseshoe(domain: AnySpace, codomain: ContinuousSpace, scale: float = 1.0)

Bases: ContinuousMorphism

Carvalho-Polson-Scott horseshoe prior.

The horseshoe prior places a global-local shrinkage structure on each coordinate:

    tau ~ HalfCauchy(scale)
    lambda_d ~ HalfCauchy(1)                                for d = 1, ..., D
    beta_d | tau, lambda_d ~ Normal(0, (tau * lambda_d)^2)
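Read generatively, the hierarchy above can be simulated directly. The plain-Python sketch below uses illustrative names and is not the library's rsample; a HalfCauchy(s) draw is taken as s · |tan(π(U - 1/2))| for U ~ Uniform(0, 1), i.e. the absolute value of a Cauchy draw obtained by inverting its CDF.

```python
import math
import random


def half_cauchy(scale: float, rng: random.Random) -> float:
    """|Cauchy(0, scale)| via the inverse CDF of the Cauchy distribution."""
    return scale * abs(math.tan(math.pi * (rng.random() - 0.5)))


def sample_horseshoe(D: int, scale: float = 1.0, rng=None):
    """One draw of (beta, tau, lambdas) from the horseshoe hierarchy."""
    if rng is None:
        rng = random.Random()
    tau = half_cauchy(scale, rng)                          # global shrinkage
    lambdas = [half_cauchy(1.0, rng) for _ in range(D)]    # local scales
    beta = [rng.gauss(0.0, tau * lam) for lam in lambdas]  # sd = tau * lambda_d
    return beta, tau, lambdas


beta, tau, lambdas = sample_horseshoe(D=5, scale=0.5, rng=random.Random(1))
```

The heavy HalfCauchy tails on lambda_d let a few coordinates escape shrinkage while tau pulls the rest toward zero.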

The marginal density of beta_d after integrating out the local scale lambda_d has no closed form; this implementation uses a 16-point Gauss-Legendre quadrature after mapping the half-line lambda in (0, inf) to t in (0, 1) via the change of variables lambda = tan(pi * t / 2), whose Jacobian is (pi / 2) * sec^2(pi * t / 2).

Reference: Carvalho, Polson & Scott (2010), The horseshoe estimator for sparse signals.
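The quadrature step can be sketched in plain Python. Under the substitution lambda = tan(πt/2), the HalfCauchy(1) density 2 / (π(1 + lambda²)) multiplied by the Jacobian (π/2) · sec²(πt/2) is exactly 1, because sec² = 1 + tan². The integrand therefore reduces to the Normal density of beta with standard deviation tau · tan(πt/2). Here the Gauss-Legendre nodes are computed by Newton iteration on the Legendre polynomial; the function names are illustrative and this is not the library's code.

```python
import math


def gauss_legendre(n: int):
    """Nodes and weights of the n-point Gauss-Legendre rule on [-1, 1]."""
    nodes, weights = [], []
    for i in range(1, n + 1):
        x = math.cos(math.pi * (i - 0.25) / (n + 0.5))  # initial guess for root i
        for _ in range(100):
            p0, p1 = 1.0, x                      # P_0(x), P_1(x)
            for k in range(2, n + 1):            # three-term recurrence up to P_n
                p0, p1 = p1, ((2 * k - 1) * x * p1 - (k - 1) * p0) / k
            dp = n * (x * p1 - p0) / (x * x - 1.0)  # P_n'(x)
            dx = p1 / dp
            x -= dx                              # Newton step toward a root of P_n
            if abs(dx) < 1e-15:
                break
        nodes.append(x)
        weights.append(2.0 / ((1.0 - x * x) * dp * dp))
    return nodes, weights


GL_NODES, GL_WEIGHTS = gauss_legendre(16)


def horseshoe_marginal_pdf(beta: float, tau: float = 1.0) -> float:
    """p(beta | tau) with lambda ~ HalfCauchy(1) integrated out (beta != 0)."""
    total = 0.0
    for x, w in zip(GL_NODES, GL_WEIGHTS):
        t = 0.5 * (x + 1.0)                      # map [-1, 1] onto [0, 1]
        sd = tau * math.tan(0.5 * math.pi * t)   # sd = tau * lambda
        normal = math.exp(-0.5 * (beta / sd) ** 2) / (sd * math.sqrt(2.0 * math.pi))
        total += 0.5 * w * normal                # 0.5 = dt/dx
    return total
```

The point beta = 0 is excluded: the horseshoe marginal has an infinite spike at the origin, so any finite quadrature value there would be meaningless.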

PARAMETER DESCRIPTION
domain

Source space. The prior does not depend on x; x only determines the batch shape.

TYPE: SetObject or ContinuousSpace

codomain

Target space. Its dim is the coordinate count D.

TYPE: ContinuousSpace

scale

Initial global shrinkage tau (positive; learnable).

TYPE: float DEFAULT: 1.0

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    scale: float = 1.0,
) -> None:
    super().__init__(domain, codomain)
    if scale <= 0.0:
        raise ValueError(f"ConditionalHorseshoe: scale must be > 0, got {scale!r}")
    self._d = codomain.dim
    inv_softplus_scale = math.log(math.expm1(scale))
    self._raw_scale = torch.nn.Parameter(
        torch.tensor(inv_softplus_scale, dtype=torch.get_default_dtype())
    )

scale property

scale: Tensor

Current (positive) global shrinkage tau.

ConditionalGeneralizedPareto

ConditionalGeneralizedPareto(domain: AnySpace, codomain: ContinuousSpace, hidden_dim: int = 64)

Bases: ContinuousMorphism

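Conditional generalized Pareto distribution. Each codomain dimension has its own learnable location, scale, and concentration (shape), all functions of x.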

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space.

TYPE: ContinuousSpace

hidden_dim

Hidden layer width for neural parameter source.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    domain: AnySpace,
    codomain: ContinuousSpace,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim
    # loc + scale + concentration
    self.param_source = _make_source(domain, 3 * d, hidden_dim)
    self._d = d
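For orientation, the textbook generalized Pareto log-density into which the three per-dimension parameters (location μ, scale σ, concentration ξ) plug is sketched below in plain Python. How the raw network outputs are constrained (for example, keeping σ positive) is not documented here, so the sketch takes already-valid parameters; gpd_log_prob is a hypothetical name.

```python
import math


def gpd_log_prob(y: float, loc: float, scale: float, concentration: float) -> float:
    """log p(y) for GeneralizedPareto(loc, scale > 0, concentration)."""
    z = (y - loc) / scale
    if z < 0.0:
        return -math.inf                        # support starts at loc
    if concentration == 0.0:
        return -math.log(scale) - z             # exponential limit as xi -> 0
    if 1.0 + concentration * z <= 0.0:
        return -math.inf                        # past the upper endpoint when xi < 0
    # (1/sigma) * (1 + xi * z) ** (-1/xi - 1), on the log scale
    return -math.log(scale) - (1.0 / concentration + 1.0) * math.log1p(concentration * z)


print(gpd_log_prob(2.0, loc=0.0, scale=1.0, concentration=0.5))   # -3 * log(2)
```

Positive ξ gives a heavy power-law tail, ξ = 0 the exponential, and negative ξ a bounded support ending at μ - σ/ξ.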

LKJCorrelationFactor

LKJCorrelationFactor(dim: int, eta: float, domain: AnySpace)

Bases: ContinuousMorphism

LKJ(K, η) prior over Cholesky factors of K × K correlation matrices; the codomain is CholeskyFactor(K).

Density on the Cholesky factor:

    p(L) \propto \prod_{k=2}^{K} L_{kk}^{K - k + 2(\eta - 1)}

A concentration η > 1 pulls mass toward the identity correlation matrix; η = 1 is uniform over correlation matrices. Sampling uses the onion method of Lewandowski-Kurowicka-Joe 2009: for each row of L, draw the squared norm of its off-diagonal part from a Beta distribution, pair it with a uniformly random direction, and set the diagonal entry so that the row has unit norm.
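A plain-Python sketch of the onion construction, for illustration only. Row 0 of L is e₁; for each later row i (zero-indexed, i = 1, ..., K-1), a squared off-diagonal norm y ~ Beta(i/2, η + (K - 1 - i)/2) is paired with a uniform direction on the unit sphere in ℝⁱ, and the diagonal entry √(1 - y) completes the row to unit norm. These Beta parameters follow the onion method as parameterised in torch.distributions.LKJCholesky; treat them as an assumption about this library rather than a transcription of its code. lkj_onion_sample is a hypothetical name.

```python
import math
import random


def lkj_onion_sample(K: int, eta: float, rng: random.Random):
    """Draw a K x K Cholesky factor of an LKJ(K, eta) correlation matrix."""
    L = [[0.0] * K for _ in range(K)]
    L[0][0] = 1.0                                   # first row is e_1
    for i in range(1, K):
        # Squared norm of the i off-diagonal entries in row i.
        y = rng.betavariate(i / 2.0, eta + (K - 1 - i) / 2.0)
        # Uniform direction on the unit sphere in R^i (normalised Gaussian).
        u = [rng.gauss(0.0, 1.0) for _ in range(i)]
        norm = math.sqrt(sum(v * v for v in u))
        for j in range(i):
            L[i][j] = math.sqrt(y) * u[j] / norm
        L[i][i] = math.sqrt(max(1.0 - y, 0.0))      # completes the row to unit norm
    return L


L = lkj_onion_sample(4, 2.0, random.Random(0))
```

For K = 2 this gives r² ~ Beta(1/2, η), which corresponds to the LKJ(2, η) density (1 - r²)^(η - 1) on the single correlation r.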

PARAMETER DESCRIPTION
dim

Correlation-matrix size K ≥ 2.

TYPE: int

eta

Concentration η > 0.

TYPE: float

domain

The morphism's source space, typically the program's input space. The LKJ prior does not depend on per-observation inputs, so rsample broadcasts the same prior across the batch dimension.

TYPE: AnySpace

Source code in src/quivers/continuous/families.py
def __init__(self, dim: int, eta: float, domain: AnySpace) -> None:
    if dim < 2:
        raise ValueError(f"LKJ requires dim >= 2; got {dim}")
    if eta <= 0:
        raise ValueError(f"LKJ requires eta > 0; got {eta}")
    codomain = CholeskyFactor(name=f"L({dim})", dim=dim)
    super().__init__(domain, codomain)
    self._dim = dim
    self._eta = float(eta)

log_prob

log_prob(x: Tensor, y: Tensor) -> Tensor

Log-density of the LKJ prior at the Cholesky factor y.

Up to a normalizing constant that does not depend on L, log p(L) = Σ_{k=2}^{K} (K - k + 2(η - 1)) log L_kk. The diagonal entries are extracted from the flattened representation.

Source code in src/quivers/continuous/families.py
def log_prob(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Log-density of the LKJ prior at the Cholesky factor ``y``.

    Up to a normalizing constant that doesn't depend on
    :math:`L`, :math:`\\log p(L) = \\sum_{k=2}^{K} (K-k+2(\\eta-1))
    \\log L_{kk}`. The diagonal entries are extracted from the
    flattened representation.
    """
    batch = y.shape[0]
    K = self._dim
    L = y.reshape(batch, K, K)
    diag = torch.diagonal(L, dim1=-2, dim2=-1)  # (batch, K)
    # Coefficients per diagonal entry (Stan's lkj_corr_cholesky_lpdf):
    # log_jac_term[k] = (K - k + 2*(eta - 1)) * log(L_kk)  for k = 2..K
    # Zero-indexed: powers[i] = (K-1-i) + 2*(eta-1) for i = 0..K-1, i.e. k = i + 1.
    # The first diagonal is fixed at 1 so log(1)=0 contributes nothing.
    ks = torch.arange(K, device=y.device, dtype=y.dtype)
    powers = (K - 1 - ks) + 2.0 * (self._eta - 1.0)
    log_diag = torch.log(diag.clamp(min=1e-30))
    return (powers * log_diag).sum(dim=-1)
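As a worked check of the exponents, take K = 2. The only non-trivial diagonal entry is L_22 = √(1 - r²), where r is the single correlation, so the formula gives log p(L) = 2(η - 1) log √(1 - r²) = (η - 1) log(1 - r²). As a density in r, that is the LKJ(2, η) density (1 - r²)^(η - 1), which is flat at η = 1. The plain-Python sketch below re-derives the zero-indexed powers used in log_prob above; it is an illustration with a hypothetical name, not the library's function.

```python
import math


def lkj_cholesky_log_prob_unnorm(L, eta: float) -> float:
    """Unnormalised LKJ log-density of a Cholesky factor L (list of rows)."""
    K = len(L)
    # Zero-indexed power for row i: (K - 1 - i) + 2 * (eta - 1).
    # Row 0 has L[0][0] = 1, so its term is log(1) * power = 0.
    return sum(((K - 1 - i) + 2.0 * (eta - 1.0)) * math.log(L[i][i]) for i in range(K))


r, eta = 0.6, 2.0
L2 = [[1.0, 0.0], [r, math.sqrt(1.0 - r * r)]]
print(lkj_cholesky_log_prob_unnorm(L2, eta))   # equals (eta - 1) * log(1 - r**2)
```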

Truncated

Truncated(base: ContinuousMorphism, lower: float | None = None, upper: float | None = None, max_rejection_iterations: int = 64)

Bases: ContinuousMorphism

Truncate a base family to an interval [a, b].

Categorical denotation: given a base family F : Θ → 𝒢(ℝ) and constants a, b ∈ [-∞, +∞] with a < b, the truncated family has density

    p_{F_{|[a,b]}}(x) = \frac{p_F(x)}{F_{\text{cdf}}(b) - F_{\text{cdf}}(a)} \cdot \mathbb{1}_{[a,b]}(x)

The truncated family is the morphism F|[a,b] : Θ → 𝒢([a, b]). Sampling uses the inverse CDF when the base supports it and falls back to rejection sampling otherwise.
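For a base with a tractable CDF and inverse CDF, inverse-CDF sampling draws u uniformly from (F(a), F(b)) and returns F⁻¹(u), while the log-density subtracts log(F(b) - F(a)) inside [a, b]. The plain-Python sketch below uses statistics.NormalDist as a stand-in base family; the function names are hypothetical and this is not the library's code path.

```python
import math
import random
from statistics import NormalDist


def truncated_bounds(base: NormalDist, lower=None, upper=None):
    """CDF values (F(a), F(b)); None means an infinite bound."""
    lo = 0.0 if lower is None else base.cdf(lower)
    hi = 1.0 if upper is None else base.cdf(upper)
    return lo, hi


def truncated_rsample(base: NormalDist, lower, upper, rng: random.Random) -> float:
    lo, hi = truncated_bounds(base, lower, upper)
    u = lo + (hi - lo) * rng.random()
    u = min(max(u, 1e-15), 1.0 - 1e-15)   # inv_cdf requires 0 < u < 1
    x = base.inv_cdf(u)
    # Guard against round-off pushing x just outside [a, b].
    if lower is not None:
        x = max(x, lower)
    if upper is not None:
        x = min(x, upper)
    return x


def truncated_log_prob(base: NormalDist, lower, upper, x: float) -> float:
    if (lower is not None and x < lower) or (upper is not None and x > upper):
        return -math.inf
    lo, hi = truncated_bounds(base, lower, upper)
    return math.log(base.pdf(x)) - math.log(hi - lo)


base = NormalDist(0.0, 1.0)
rng = random.Random(0)
draws = [truncated_rsample(base, 0.5, 2.0, rng) for _ in range(1000)]
```

Every base draw is used, so the cost does not grow as the interval's probability mass shrinks, which is why inverse-CDF sampling is preferred when available.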

PARAMETER DESCRIPTION
base

The base distribution-family morphism. Must expose log_prob and rsample. Inverse-CDF sampling additionally requires an icdf method or a base_distribution attribute holding a torch Distribution; without either, sampling falls back to rejection sampling.

TYPE: ContinuousMorphism

lower

Lower bound a. None means -∞.

TYPE: float or None DEFAULT: None

upper

Upper bound b. None means +∞.

TYPE: float or None DEFAULT: None

max_rejection_iterations

Cap on rejection-sampling attempts before raising.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/families.py
def __init__(
    self,
    base: ContinuousMorphism,
    lower: float | None = None,
    upper: float | None = None,
    max_rejection_iterations: int = 64,
) -> None:
    super().__init__(base.domain, base.codomain)
    if lower is None and upper is None:
        raise ValueError(
            "Truncated requires at least one of lower / upper to be finite; "
            "without truncation, use the base family directly"
        )
    if lower is not None and upper is not None and not (lower < upper):
        raise ValueError(
            f"Truncated requires lower < upper; got lower={lower}, upper={upper}"
        )
    self._base = base
    self._lower = lower
    self._upper = upper
    self._max_iters = max_rejection_iterations
    # Attach so the parent nn.Module tracks parameters.
    self._base_mod = base
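When no inverse CDF is available, the fallback is rejection sampling: draw from the base and keep the draw only if it lands in [a, b], giving up after max_rejection_iterations attempts. The scalar sketch below is plain Python; base_sampler is a hypothetical stand-in for the base family's rsample, and the choice of RuntimeError is an assumption.

```python
import math
import random


def rejection_sample(base_sampler, lower=None, upper=None, max_rejection_iterations=64):
    """Redraw from base_sampler until a value lands in [lower, upper]."""
    lo = -math.inf if lower is None else lower
    hi = math.inf if upper is None else upper
    for _ in range(max_rejection_iterations):
        x = base_sampler()
        if lo <= x <= hi:
            return x
    raise RuntimeError(
        f"no draw landed in [{lo}, {hi}] after {max_rejection_iterations} attempts"
    )


rng = random.Random(0)
draw = rejection_sample(lambda: rng.gauss(0.0, 1.0), lower=-1.0, upper=1.0)
```

The expected number of attempts is 1 / (F(b) - F(a)), so intervals far in the tails of the base make rejection impractical; the iteration cap turns that case into an error instead of an unbounded loop.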