Continuous Morphisms

Continuous mappings between topological spaces.

morphisms

Continuous morphisms: Markov kernels on continuous and mixed spaces.

A ContinuousMorphism represents a conditional probability distribution p(y | x) where x and y may live in either discrete (FinSet) or continuous (ContinuousSpace) spaces. The morphism is defined by two operations:

log_prob(x, y) — log-density/probability of y given x
rsample(x)     — reparameterized samples from p(· | x)

Composition uses ancestral sampling:

(g . f)(x, z) = integral f(x, y) g(y, z) dy
              = E_{y~f(x,.)}[g(y, z)]
              ~ (1/N) sum_i g(y_i, z),   y_i ~ f(x, .)
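
As a rough illustration of the estimate above, the following stdlib-only sketch composes two Gaussian kernels by ancestral sampling and compares the result with the closed-form density. The kernels f(x, .) = N(x, 1) and g(y, .) = N(y, 1), and the helper names normal_pdf and mc_composition_density, are assumptions made for this example and are not part of quivers.

```python
import math
import random


def normal_pdf(value: float, mean: float, std: float) -> float:
    """Density of N(mean, std^2) evaluated at value."""
    z = (value - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def mc_composition_density(x: float, z: float, n: int, rng: random.Random) -> float:
    """Estimate (g . f)(x, z) by averaging g(y_i, z) over y_i ~ f(x, .)."""
    total = 0.0
    for _ in range(n):
        y = rng.gauss(x, 1.0)           # y_i ~ f(x, .) = N(x, 1)
        total += normal_pdf(z, y, 1.0)  # g(y_i, z) with g(y, .) = N(y, 1)
    return total / n


rng = random.Random(0)
estimate = mc_composition_density(x=0.0, z=1.0, n=20_000, rng=rng)
# For these two kernels the composite is N(x, 2) in closed form.
exact = normal_pdf(1.0, 0.0, math.sqrt(2.0))
print(f"estimate={estimate:.4f} exact={exact:.4f}")
```

The estimate tightens as the number of intermediate samples grows; its standard error shrinks like 1/sqrt(N).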

This module provides:

ContinuousMorphism         — abstract base with >> and @ operators
SampledComposition         — f >> g via ancestral sampling
ProductContinuousMorphism  — f @ g (independent product)
DiscreteAsContinuous       — wrap a discrete Morphism as continuous

Convention for input shapes:
  • Discrete domain (SetObject): x is LongTensor of shape (batch,)
  • Continuous domain (ContinuousSpace): x is FloatTensor of shape (batch, dim)
  • Discrete codomain: y is LongTensor of shape (batch,)
  • Continuous codomain: y is FloatTensor of shape (batch, dim)

ContinuousMorphism

ContinuousMorphism(domain: AnySpace, codomain: AnySpace)

Bases: Module, ABC

Abstract base for morphisms involving continuous spaces.

Subclasses must implement log_prob and rsample. The composition operator >> and product operator @ are provided and dispatch to SampledComposition and ProductContinuousMorphism respectively.

Unlike a discrete Morphism, which materializes a full tensor, a ContinuousMorphism is defined operationally: it can evaluate log-densities and generate reparameterized samples.

PARAMETER DESCRIPTION
domain

Source space.

TYPE: SetObject or ContinuousSpace

codomain

Target space.

TYPE: SetObject or ContinuousSpace

Source code in src/quivers/continuous/morphisms.py
def __init__(self, domain: AnySpace, codomain: AnySpace) -> None:
    super().__init__()
    self._domain = domain
    self._codomain = codomain

domain property

domain: AnySpace

Source space.

codomain property

codomain: AnySpace

Target space.

support property

support: Constraint

The support constraint of the distribution this morphism samples from, in the form of a torch.distributions.constraints.Constraint.

Used by variational guides (quivers.inference.AutoNormalGuide, quivers.inference.AutoDeltaGuide) to choose the bijector that maps an unconstrained variational approximation back onto the constrained support of the prior. Samples used to evaluate the prior's log_prob then lie inside its support, which avoids "Expected value to be within the support of the distribution" errors.

Subclasses representing a constrained distribution family (HalfNormal, Beta, Uniform, Dirichlet, LogitNormal, Wishart, …) should override this property to return the appropriate constraint. The default is torch.distributions.constraints.real, which is correct for unconstrained families such as Normal and for discrete codomains (where the guide skips the site anyway).
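
To make the role of the bijector concrete, here is a minimal stdlib-only sketch of the problem it solves. It assumes a HalfNormal prior and uses exp as the map from the real line onto the positive reals; the function half_normal_log_prob is a hypothetical stand-in, not the quivers or torch implementation, and a real guide would also account for the change-of-variables term.

```python
import math
import random


def half_normal_log_prob(y: float, scale: float = 1.0) -> float:
    """Log-density of HalfNormal(scale); only defined on the support y >= 0."""
    if y < 0.0:
        raise ValueError("value outside the support of HalfNormal")
    return 0.5 * math.log(2.0 / math.pi) - math.log(scale) - 0.5 * (y / scale) ** 2


rng = random.Random(0)
# Draws from an unconstrained (real-valued) variational approximation.
unconstrained = [rng.gauss(0.0, 1.0) for _ in range(1000)]

# Negative draws fall outside the prior's support and could not be scored.
n_invalid = sum(1 for u in unconstrained if u < 0.0)

# Pushing every draw through exp (a bijection from R onto (0, inf))
# always lands inside the support, so log_prob is finite everywhere.
log_probs = [half_normal_log_prob(math.exp(u)) for u in unconstrained]
print(n_invalid, all(math.isfinite(lp) for lp in log_probs))
```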

log_prob abstractmethod

log_prob(x: Tensor, y: Tensor) -> Tensor

Log-probability (density) of y given x.

PARAMETER DESCRIPTION
x

Inputs. Shape (batch,) for discrete domain or (batch, domain_dim) for continuous domain.

TYPE: Tensor

y

Outputs. Shape (batch,) for discrete codomain or (batch, codomain_dim) for continuous codomain.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Log-probabilities/densities. Shape (batch,).

Source code in src/quivers/continuous/morphisms.py
@abstractmethod
def log_prob(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Log-probability (density) of y given x.

    Parameters
    ----------
    x : torch.Tensor
        Inputs. Shape (batch,) for discrete domain or
        (batch, domain_dim) for continuous domain.
    y : torch.Tensor
        Outputs. Shape (batch,) for discrete codomain or
        (batch, codomain_dim) for continuous codomain.

    Returns
    -------
    torch.Tensor
        Log-probabilities/densities. Shape (batch,).
    """
    ...

rsample abstractmethod

rsample(x: Tensor, sample_shape: Size = Size()) -> Tensor

Reparameterized samples from p(. | x).

Gradients flow through the returned samples back to the parameters of this morphism (and to x if the domain is continuous).

PARAMETER DESCRIPTION
x

Inputs. Shape (batch,) or (batch, domain_dim).

TYPE: Tensor

sample_shape

Additional leading sample dimensions.

TYPE: Size DEFAULT: Size()

RETURNS DESCRIPTION
Tensor

Samples. Shape (*sample_shape, batch, codomain_dim) for continuous codomain, or (*sample_shape, batch) for discrete.

Source code in src/quivers/continuous/morphisms.py
@abstractmethod
def rsample(
    self, x: torch.Tensor, sample_shape: torch.Size = torch.Size()
) -> torch.Tensor:
    """Reparameterized samples from p(. | x).

    Gradients flow through the returned samples back to the
    parameters of this morphism (and to x if the domain is
    continuous).

    Parameters
    ----------
    x : torch.Tensor
        Inputs. Shape (batch,) or (batch, domain_dim).
    sample_shape : torch.Size
        Additional leading sample dimensions.

    Returns
    -------
    torch.Tensor
        Samples. Shape (*sample_shape, batch, codomain_dim) for
        continuous codomain, or (*sample_shape, batch) for discrete.
    """
    ...
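
The gradient-flow property of rsample rests on the reparameterization trick: a sample is written as a deterministic function of the parameters and of parameter-free noise. The sketch below assumes a Gaussian kernel and checks the pathwise derivatives by finite differences instead of autograd; rsample_normal is a hypothetical helper, not a quivers function.

```python
import random


def rsample_normal(mu: float, sigma: float, eps: float) -> float:
    """Reparameterized draw: deterministic in (mu, sigma) once the noise eps is fixed."""
    return mu + sigma * eps


rng = random.Random(0)
eps = rng.gauss(0.0, 1.0)  # parameter-free noise, drawn once and then held fixed
mu, sigma, h = 0.5, 2.0, 1e-6

# Central finite differences with eps fixed recover the pathwise derivatives
# dy/dmu = 1 and dy/dsigma = eps, which is what autograd would propagate.
dy_dmu = (rsample_normal(mu + h, sigma, eps) - rsample_normal(mu - h, sigma, eps)) / (2 * h)
dy_dsigma = (rsample_normal(mu, sigma + h, eps) - rsample_normal(mu, sigma - h, eps)) / (2 * h)
print(dy_dmu, dy_dsigma, eps)
```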

sample

sample(x: Tensor, sample_shape: Size = Size()) -> Tensor

Non-reparameterized samples (no gradient through samples).

PARAMETER DESCRIPTION
x

Inputs.

TYPE: Tensor

sample_shape

Additional leading sample dimensions.

TYPE: Size DEFAULT: Size()

RETURNS DESCRIPTION
Tensor

Samples (detached from computation graph).

Source code in src/quivers/continuous/morphisms.py
def sample(
    self, x: torch.Tensor, sample_shape: torch.Size = torch.Size()
) -> torch.Tensor:
    """Non-reparameterized samples (no gradient through samples).

    Parameters
    ----------
    x : torch.Tensor
        Inputs.
    sample_shape : torch.Size
        Additional leading sample dimensions.

    Returns
    -------
    torch.Tensor
        Samples (detached from computation graph).
    """
    with torch.no_grad():
        return self.rsample(x, sample_shape)

__rshift__

__rshift__(other: object) -> ContinuousMorphism

Composition via ancestral sampling: self >> other.

Source code in src/quivers/continuous/morphisms.py
def __rshift__(self, other: object) -> ContinuousMorphism:
    """Composition via ancestral sampling: self >> other."""
    if isinstance(other, ContinuousMorphism):
        return SampledComposition(self, other)
    from quivers.core.morphisms import Morphism

    if isinstance(other, Morphism):
        return SampledComposition(self, DiscreteAsContinuous(other))
    return NotImplemented

__rrshift__

__rrshift__(other: object) -> ContinuousMorphism

Handle discrete_morphism >> continuous_morphism.

Source code in src/quivers/continuous/morphisms.py
def __rrshift__(self, other: object) -> ContinuousMorphism:
    """Handle discrete_morphism >> continuous_morphism."""
    from quivers.core.morphisms import Morphism

    if isinstance(other, Morphism):
        return SampledComposition(DiscreteAsContinuous(other), self)
    return NotImplemented

__matmul__

__matmul__(other: object) -> ProductContinuousMorphism

Independent product: self @ other.

Source code in src/quivers/continuous/morphisms.py
def __matmul__(self, other: object) -> ProductContinuousMorphism:
    """Independent product: self @ other."""
    if isinstance(other, ContinuousMorphism):
        return ProductContinuousMorphism(self, other)
    from quivers.core.morphisms import Morphism

    if isinstance(other, Morphism):
        return ProductContinuousMorphism(self, DiscreteAsContinuous(other))
    return NotImplemented

SampledComposition

SampledComposition(left: ContinuousMorphism, right: ContinuousMorphism, n_intermediate: int = 100)

Bases: ContinuousMorphism

Composition of morphisms via ancestral sampling.

Given f: X -> Y and g: Y -> Z, the composition g . f satisfies:

(g . f)(x, z) = integral f(x, y) g(y, z) dy

This integral is computed:

  • Exactly (finite sum) when Y is discrete.
  • Approximately (Monte Carlo) when Y is continuous.

For rsample: draw y ~ f(x, .), then draw z ~ g(y, .). For log_prob: when Y is discrete, sum g(y_i, z) weighted by f(x, y_i) over every y_i in Y; when Y is continuous, average g(y_i, z) over samples y_i ~ f(x, .).
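
For the discrete case, the finite sum can be carried out in log space. The sketch below is a stdlib-only illustration under assumed inputs: the tables f and g are made-up stochastic matrices, and logsumexp and composed_log_prob are hypothetical helpers. It is not the library's _log_prob_exact, whose body is not shown here.

```python
import math
from typing import Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


# Hypothetical kernels on finite sets: f[x][y] = f(x, y), g[y][z] = g(y, z).
f = [[0.9, 0.1],
     [0.2, 0.8]]
g = [[0.7, 0.3],
     [0.4, 0.6]]


def composed_log_prob(x: int, z: int) -> float:
    """log (g . f)(x, z) = log sum_y f(x, y) g(y, z), summed over every y in Y."""
    terms = [math.log(f[x][y]) + math.log(g[y][z]) for y in range(len(g))]
    return logsumexp(terms)


print(math.exp(composed_log_prob(0, 1)))  # equals 0.9 * 0.3 + 0.1 * 0.6
```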

PARAMETER DESCRIPTION
left

First morphism (applied first).

TYPE: ContinuousMorphism

right

Second morphism (applied second).

TYPE: ContinuousMorphism

n_intermediate

Number of Monte Carlo samples for continuous intermediate spaces. Ignored when the intermediate space is discrete.

TYPE: int DEFAULT: 100

Source code in src/quivers/continuous/morphisms.py
def __init__(
    self,
    left: ContinuousMorphism,
    right: ContinuousMorphism,
    n_intermediate: int = 100,
) -> None:
    super().__init__(left.domain, right.codomain)
    self.left = left
    self.right = right
    self.n_intermediate = n_intermediate

rsample

rsample(x: Tensor, sample_shape: Size = Size()) -> Tensor

Ancestral sampling: y ~ f(x, .), then z ~ g(y, .).

PARAMETER DESCRIPTION
x

Inputs to the composition.

TYPE: Tensor

sample_shape

Additional sample dimensions.

TYPE: Size DEFAULT: Size()

RETURNS DESCRIPTION
Tensor

Samples from the composed morphism.

Source code in src/quivers/continuous/morphisms.py
def rsample(
    self, x: torch.Tensor, sample_shape: torch.Size = torch.Size()
) -> torch.Tensor:
    """Ancestral sampling: y ~ f(x, .), then z ~ g(y, .).

    Parameters
    ----------
    x : torch.Tensor
        Inputs to the composition.
    sample_shape : torch.Size
        Additional sample dimensions.

    Returns
    -------
    torch.Tensor
        Samples from the composed morphism.
    """
    y = self.left.rsample(x, sample_shape)
    if len(sample_shape) > 0:
        leading = y.shape[: len(sample_shape)]
        batch = x.shape[0]
        flat_size = int(torch.tensor(leading).prod().item()) * batch
        if y.dim() > len(sample_shape) + 1:
            event_dims = y.shape[len(sample_shape) + 1 :]
            flat_y = y.reshape(flat_size, *event_dims)
        else:
            flat_y = y.reshape(flat_size)
    else:
        flat_y = y
    z = self.right.rsample(flat_y)
    if len(sample_shape) > 0:
        batch = x.shape[0]
        if z.dim() > 1:
            event_dims = z.shape[1:]
            z = z.reshape(*sample_shape, batch, *event_dims)
        else:
            z = z.reshape(*sample_shape, batch)
    return z

log_prob

log_prob(x: Tensor, y: Tensor) -> Tensor

Log-probability of y given x through the composition.

When the intermediate space is discrete, computes the exact marginalization. When continuous, uses Monte Carlo estimation.

PARAMETER DESCRIPTION
x

Inputs. Shape (batch,) or (batch, dom_dim).

TYPE: Tensor

y

Outputs. Shape (batch,) or (batch, cod_dim).

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Log-probabilities. Shape (batch,).

Source code in src/quivers/continuous/morphisms.py
def log_prob(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Log-probability of y given x through the composition.

    When the intermediate space is discrete, computes the exact
    marginalization. When continuous, uses Monte Carlo estimation.

    Parameters
    ----------
    x : torch.Tensor
        Inputs. Shape (batch,) or (batch, dom_dim).
    y : torch.Tensor
        Outputs. Shape (batch,) or (batch, cod_dim).

    Returns
    -------
    torch.Tensor
        Log-probabilities. Shape (batch,).
    """
    intermediate = self.left.codomain
    if isinstance(intermediate, SetObject):
        return self._log_prob_exact(x, y, intermediate)
    else:
        return self._log_prob_mc(x, y)

ProductContinuousMorphism

ProductContinuousMorphism(left: ContinuousMorphism, right: ContinuousMorphism)

Bases: ContinuousMorphism

Independent product of two continuous morphisms.

Given f: A -> B and g: C -> D, produces f @ g: (A, C) -> (B, D) where p_{f@g}((y,z) | (x,w)) = f(y | x) * g(z | w).

Domain inputs are concatenated: (x, w) as a single vector. Codomain outputs are concatenated: (y, z) as a single vector. For discrete components, indices are embedded as 1-d floats.
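
The factorization can be sketched without tensors: split the concatenated input and output at the left factor's dimensions, then add the two log-probabilities. Everything named below (the two 1-d Gaussian kernels, f_log_prob, g_log_prob, product_log_prob) is assumed for illustration and does not mirror the actual ProductContinuousMorphism code.

```python
import math
from typing import Sequence


def normal_log_pdf(value: float, mean: float, std: float) -> float:
    """Log-density of N(mean, std^2) evaluated at value."""
    z = (value - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


# Two hypothetical 1-d kernels: f(. | x) = N(x, 1) and g(. | w) = N(2w, 0.5).
def f_log_prob(x: Sequence[float], y: Sequence[float]) -> float:
    return normal_log_pdf(y[0], x[0], 1.0)


def g_log_prob(w: Sequence[float], z: Sequence[float]) -> float:
    return normal_log_pdf(z[0], 2.0 * w[0], 0.5)


def product_log_prob(xw: Sequence[float], yz: Sequence[float],
                     left_dom_dim: int = 1, left_cod_dim: int = 1) -> float:
    """Split the concatenated vectors, then add the factor log-probabilities."""
    x, w = xw[:left_dom_dim], xw[left_dom_dim:]
    y, z = yz[:left_cod_dim], yz[left_cod_dim:]
    return f_log_prob(x, y) + g_log_prob(w, z)


joint = product_log_prob([0.0, 1.0], [0.5, 2.0])
print(joint)
```

Adding log-probabilities corresponds to multiplying f(y | x) and g(z | w), which is exactly the independence statement in the formula above.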

PARAMETER DESCRIPTION
left

Left factor morphism.

TYPE: ContinuousMorphism

right

Right factor morphism.

TYPE: ContinuousMorphism

Source code in src/quivers/continuous/morphisms.py
def __init__(self, left: ContinuousMorphism, right: ContinuousMorphism) -> None:
    dom = _combine_spaces(left.domain, right.domain)
    cod = _combine_spaces(left.codomain, right.codomain)
    super().__init__(dom, cod)
    self.left = left
    self.right = right
    self._left_dom_dim = _event_dim(left.domain)
    self._right_dom_dim = _event_dim(right.domain)
    self._left_cod_dim = _event_dim(left.codomain)
    self._right_cod_dim = _event_dim(right.codomain)

FanOutMorphism

FanOutMorphism(components: list)

Bases: ContinuousMorphism

Fan-out morphism: copy input to N morphisms, concatenate outputs.

Given f_1: A -> B_1, f_2: A -> B_2, ..., f_N: A -> B_N, produces fan(f_1, ..., f_N): A -> B_1 * B_2 * ... * B_N where the input A is copied to all N morphisms.

Unlike the tensor product (f @ g), which takes a product domain (A * C), fan-out feeds the same input to all morphisms. This implements the diagonal morphism Delta: A -> A^N followed by the product f_1 @ f_2 @ ... @ f_N.
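
A small stdlib-only sketch of the difference: every component receives the same input, and the outputs are concatenated and scored slice by slice. The factory make_normal and the two Gaussian components are assumptions for this example, not quivers classes.

```python
import math
import random
from typing import List


def normal_log_pdf(value: float, mean: float, std: float) -> float:
    """Log-density of N(mean, std^2) evaluated at value."""
    z = (value - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def make_normal(shift: float, std: float):
    """Hypothetical 1-d kernel N(x + shift, std) as (rsample, log_prob, cod_dim)."""
    def rsample(x: List[float], rng: random.Random) -> List[float]:
        return [x[0] + shift + std * rng.gauss(0.0, 1.0)]

    def log_prob(x: List[float], y: List[float]) -> float:
        return normal_log_pdf(y[0], x[0] + shift, std)

    return rsample, log_prob, 1


components = [make_normal(0.0, 1.0), make_normal(10.0, 0.5)]


def fan_out_rsample(x: List[float], rng: random.Random) -> List[float]:
    out: List[float] = []
    for rsample, _, _ in components:
        out.extend(rsample(x, rng))  # every component sees the same x
    return out


def fan_out_log_prob(x: List[float], y: List[float]) -> float:
    lp, offset = 0.0, 0
    for _, log_prob, d in components:
        lp += log_prob(x, y[offset:offset + d])  # score this component's slice
        offset += d
    return lp


rng = random.Random(0)
x = [1.0]
y = fan_out_rsample(x, rng)
print(len(y), fan_out_log_prob(x, y))
```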

PARAMETER DESCRIPTION
components

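The morphisms to fan out to. All must share the same domain; the constructor checks that every component's domain dimension matches that of the first and raises TypeError otherwise.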

TYPE: list[ContinuousMorphism]

Source code in src/quivers/continuous/morphisms.py
def __init__(self, components: list) -> None:
    from quivers.core.morphisms import Morphism as _CatMorphism

    if not components:
        raise ValueError("fan-out requires at least one component")
    # Backend-agnostic V-Cat morphisms (those that aren't
    # already ContinuousMorphism subclasses) get wrapped in a
    # deterministic continuous adapter so the FanOut's rsample
    # / log_prob loop can dispatch uniformly. The wrapping
    # exposes the V-Cat tensor through a categorical
    # ``rsample`` that gathers / contracts the tensor against
    # the input; ``log_prob`` evaluates the V-Cat tensor as a
    # categorical likelihood when meaningful.
    wrapped_components: list[ContinuousMorphism] = []
    for c in components:
        if isinstance(c, ContinuousMorphism):
            wrapped_components.append(c)
        elif isinstance(c, _CatMorphism):
            wrapped_components.append(DiscreteAsContinuous(c))
        else:
            raise TypeError(
                f"fan-out: component of type "
                f"{type(c).__name__} is neither a "
                f"ContinuousMorphism nor a V-Cat Morphism"
            )
    domain = wrapped_components[0].domain
    for i, c in enumerate(wrapped_components[1:], 1):
        dom_dim = _event_dim(domain)
        c_dim = _event_dim(c.domain)
        if dom_dim != c_dim:
            raise TypeError(
                f"fan-out: component {i} domain dim {c_dim} != component 0 domain dim {dom_dim}"
            )
    codomain = wrapped_components[0].codomain
    for c in wrapped_components[1:]:
        codomain = _combine_spaces(codomain, c.codomain)
    super().__init__(domain, codomain)
    self._components = torch.nn.ModuleList(wrapped_components)
    self._cod_dims = [_event_dim(c.codomain) for c in wrapped_components]

rsample

rsample(x: Tensor, sample_shape: Size = Size()) -> Tensor

Sample from all components and concatenate outputs.

PARAMETER DESCRIPTION
x

Input tensor (broadcast to all components).

TYPE: Tensor

sample_shape

Additional leading sample dimensions.

TYPE: Size DEFAULT: Size()

RETURNS DESCRIPTION
Tensor

Concatenated outputs from all components.

Source code in src/quivers/continuous/morphisms.py
def rsample(
    self, x: torch.Tensor, sample_shape: torch.Size = torch.Size()
) -> torch.Tensor:
    """Sample from all components and concatenate outputs.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor (broadcast to all components).
    sample_shape : torch.Size
        Additional leading sample dimensions.

    Returns
    -------
    torch.Tensor
        Concatenated outputs from all components.
    """
    outs = []
    for comp in self._components:
        y = cast(ContinuousMorphism, comp).rsample(x, sample_shape)
        if y.dim() == 1:
            y = y.unsqueeze(-1)
        outs.append(y)
    return torch.cat(outs, dim=-1)

log_prob

log_prob(x: Tensor, y: Tensor) -> Tensor

Log-probability: sum of component log-probs.

PARAMETER DESCRIPTION
x

Input (same for all components).

TYPE: Tensor

y

Concatenated output values.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Sum of log-probabilities. Shape (batch,).

Source code in src/quivers/continuous/morphisms.py
def log_prob(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Log-probability: sum of component log-probs.

    Parameters
    ----------
    x : torch.Tensor
        Input (same for all components).
    y : torch.Tensor
        Concatenated output values.

    Returns
    -------
    torch.Tensor
        Sum of log-probabilities. Shape ``(batch,)``.
    """
    lp = torch.zeros(x.shape[0], device=x.device)
    offset = 0
    for comp_mod, d in zip(self._components, self._cod_dims):
        comp = cast(ContinuousMorphism, comp_mod)
        y_slice = y[..., offset : offset + d]
        if _is_discrete(comp.codomain):
            y_slice = y_slice.squeeze(-1).long()
        lp = lp + comp.log_prob(x, y_slice)
        offset += d
    return lp

DiscreteAsContinuous

DiscreteAsContinuous(inner: object)

Bases: ContinuousMorphism

Wrap a discrete Morphism as a ContinuousMorphism.

Enables composition between discrete and continuous morphisms via the >> operator. The wrapped morphism's tensor is used for both log_prob evaluation and sampling.

Note

Sampling from a discrete distribution is NOT reparameterizable. Gradients do not flow through the discrete samples back to the left morphism's parameters. Use score function estimators (REINFORCE) if gradients through discrete choices are needed.
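
For readers unfamiliar with the score function estimator mentioned in the note, here is a minimal stdlib-only sketch. It assumes a Bernoulli(p) choice and a made-up reward function; neither is part of quivers, and a practical estimator would usually add a baseline to reduce variance.

```python
import random


def reward(y: int) -> float:
    """Hypothetical downstream objective evaluated at a discrete sample."""
    return 3.0 if y == 1 else 1.0


def score_function_gradient(p: float, n: int, rng: random.Random) -> float:
    """Estimate d/dp E_{y ~ Bernoulli(p)}[reward(y)] without differentiating through y."""
    total = 0.0
    for _ in range(n):
        y = 1 if rng.random() < p else 0
        # Score: d/dp log p(y), i.e. 1/p for y = 1 and -1/(1 - p) for y = 0.
        score = 1.0 / p if y == 1 else -1.0 / (1.0 - p)
        total += reward(y) * score
    return total / n


rng = random.Random(0)
estimate = score_function_gradient(p=0.3, n=50_000, rng=rng)
# E[reward] = p * reward(1) + (1 - p) * reward(0), so the exact gradient is:
exact = reward(1) - reward(0)
print(estimate, exact)
```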

PARAMETER DESCRIPTION
inner

The discrete morphism to wrap.

TYPE: Morphism

Source code in src/quivers/continuous/morphisms.py
def __init__(self, inner: object) -> None:
    from quivers.core.morphisms import Morphism

    if not isinstance(inner, Morphism):
        raise TypeError(f"expected a discrete Morphism, got {type(inner).__name__}")
    super().__init__(inner.domain, inner.codomain)
    self._inner = inner
    self._inner_module = inner.module()

log_prob

log_prob(x: Tensor, y: Tensor) -> Tensor

Log-probability from the discrete tensor.

PARAMETER DESCRIPTION
x

Domain indices. Shape (batch,).

TYPE: Tensor

y

Codomain indices. Shape (batch,).

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Log-probabilities. Shape (batch,).

Source code in src/quivers/continuous/morphisms.py
def log_prob(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Log-probability from the discrete tensor.

    Parameters
    ----------
    x : torch.Tensor
        Domain indices. Shape (batch,).
    y : torch.Tensor
        Codomain indices. Shape (batch,).

    Returns
    -------
    torch.Tensor
        Log-probabilities. Shape (batch,).
    """
    t = self._inner.tensor
    probs = t[x.long(), y.long()]
    return torch.log(probs.clamp(min=1e-07))
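
One consequence of the clamp in the listing above is that outcomes with probability zero receive a finite log-probability of log(1e-07), roughly -16.12, instead of -inf. The following stdlib-only sketch mirrors the lookup-then-clamp step on a made-up table; table and clamped_log_prob are hypothetical names.

```python
import math
from typing import List

# Hypothetical stochastic table: row x holds p(y | x).
table = [[0.25, 0.75, 0.0],
         [1.0, 0.0, 0.0]]


def clamped_log_prob(x: List[int], y: List[int], floor: float = 1e-7) -> List[float]:
    """Look up p(y_i | x_i) for each pair, floor it, then take the log."""
    return [math.log(max(table[xi][yi], floor)) for xi, yi in zip(x, y)]


lp = clamped_log_prob(x=[0, 0, 1], y=[1, 2, 0])
print(lp)  # the middle entry is log(1e-07), not -inf
```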

rsample

rsample(x: Tensor, sample_shape: Size = Size()) -> Tensor

Sample from the categorical distribution defined by the tensor.

Note: not reparameterizable. Gradients do not flow through the returned samples.

PARAMETER DESCRIPTION
x

Domain indices. Shape (batch,).

TYPE: Tensor

sample_shape

Additional sample dimensions.

TYPE: Size DEFAULT: Size()

RETURNS DESCRIPTION
Tensor

Sampled codomain indices. Shape (*sample_shape, batch).

Source code in src/quivers/continuous/morphisms.py
def rsample(
    self, x: torch.Tensor, sample_shape: torch.Size = torch.Size()
) -> torch.Tensor:
    """Sample from the categorical distribution defined by the tensor.

    Note: not reparameterizable. Gradients do not flow through
    the returned samples.

    Parameters
    ----------
    x : torch.Tensor
        Domain indices. Shape (batch,).
    sample_shape : torch.Size
        Additional sample dimensions.

    Returns
    -------
    torch.Tensor
        Sampled codomain indices. Shape (*sample_shape, batch).
    """
    t = self._inner.tensor
    probs = t[x.long()]
    n_samples = (
        int(torch.Size(sample_shape).numel()) if len(sample_shape) > 0 else 1
    )
    samples = torch.multinomial(probs, n_samples, replacement=True)
    if len(sample_shape) == 0:
        return samples.squeeze(-1)
    else:
        return samples.T.reshape(*sample_shape, -1)
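
The sampling step can be pictured without tensors: each input index selects one row of the table, and that row is the categorical distribution to draw from. The sketch below uses the standard library's random.choices on a made-up table and lays draws out as (n_samples, batch), matching the (*sample_shape, batch) convention for a one-dimensional sample_shape; sample_rows is a hypothetical helper.

```python
import random
from collections import Counter
from typing import List

# Hypothetical 2x3 stochastic table: row x holds p(y | x).
table = [[0.1, 0.6, 0.3],
         [0.5, 0.5, 0.0]]


def sample_rows(x: List[int], n_samples: int, rng: random.Random) -> List[List[int]]:
    """Return n_samples draws per input index, laid out as (n_samples, batch)."""
    return [
        [rng.choices(range(len(table[xi])), weights=table[xi])[0] for xi in x]
        for _ in range(n_samples)
    ]


rng = random.Random(0)
draws = sample_rows(x=[0, 1, 1], n_samples=2000, rng=rng)
counts = Counter(row[0] for row in draws)  # draws for the first batch element
print(len(draws), len(draws[0]), counts[1] / 2000)
```

The empirical frequency of y = 1 for the first batch element should sit near its table entry of 0.6, and the zero-probability outcome in the second row is never drawn.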