Scan Morphism

scan(cell) realises the iterated Kleisli composition of a per-step cell across a sequence input. ScanMorphism.rsample(x) runs the per-step kernel forward, threading the hidden state from one step to the next. ScanMorphism.log_joint(x, h) returns the sum of the per-step log-densities. It accepts the hidden-state trajectory either as a positional tensor or as a {state_key: tensor} dict, so the standard inference contract log_joint(x, observations: dict) works without an adapter.

scan

Scan combinator: temporal recurrence over sequences.

A ScanMorphism wraps a recurrent cell and applies it across a sequence, threading hidden state from one time step to the next. This implements the standard RNN pattern:

h_t = cell(x_t, h_{t-1})

where cell : A * H -> H is a morphism (either a plain ContinuousMorphism or a MonadicProgram) whose domain is a product of the per-timestep input space A and the hidden state space H, and whose codomain is H.

Given a cell : A * H -> H, scan(cell) produces a morphism A -> H that, at runtime:

  1. Expects a 3D input tensor of shape (batch, seq_len, dim_A). A 2D input of shape (batch, dim_A) is treated as a sequence of length 1.
  2. Initializes hidden state h_0 (zeros or a learned parameter).
  3. At each step t, concatenates x[:, t, :] with h to form the cell input, then calls cell.rsample to produce the new h.
  4. Returns the final hidden state h_T of shape (batch, dim_H).
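The four steps above can be sketched in plain PyTorch. This is a minimal illustration, not the library's implementation: ToyCell is a hypothetical deterministic stand-in for a cell : A * H -> H, and scan_final_state is an illustrative helper name, neither of which is part of quivers.

```python
import torch
from torch import nn


class ToyCell(nn.Module):
    """Hypothetical deterministic stand-in for a cell : A * H -> H."""

    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim + hidden_dim, hidden_dim)

    def rsample(self, cell_input: torch.Tensor) -> torch.Tensor:
        # Maps the concatenated (x_t, h_{t-1}) vector to the next hidden state.
        return torch.tanh(self.linear(cell_input))


def scan_final_state(cell: ToyCell, x: torch.Tensor, hidden_dim: int) -> torch.Tensor:
    batch, seq_len, _ = x.shape                              # step 1: 3D input
    h = torch.zeros(batch, hidden_dim, dtype=x.dtype, device=x.device)  # step 2: h_0 = 0
    for t in range(seq_len):
        cell_input = torch.cat([x[:, t, :], h], dim=-1)      # step 3: concat x_t with h
        h = cell.rsample(cell_input)
    return h                                                 # step 4: final state h_T


torch.manual_seed(0)
cell = ToyCell(input_dim=32, hidden_dim=64)
x = torch.randn(8, 10, 32)   # batch=8, seq_len=10, input_dim=32
h_T = scan_final_state(cell, x, hidden_dim=64)
```

Only the last hidden state is returned; intermediate states are discarded, which is why the resulting morphism has type A -> H rather than A -> H^T.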

The scan's type in the categorical framework is:

scan(f : A * H -> H) : A -> H

where the sequence structure is implicit in the tensor's time dimension, following standard neural network conventions.

Initialization strategies
  • "zeros": h_0 = 0 (default).
  • "learned": h_0 is a learnable nn.Parameter of shape (hidden_dim,), initialised to zeros and broadcast across the batch.
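
The difference between the two strategies can be seen in a few lines of plain PyTorch. The variable names below are illustrative only; the sketch assumes, as in the constructor shown later on this page, that the learned state is a single vector expanded over the batch.

```python
import torch
from torch import nn

hidden_dim, batch = 4, 3

# "zeros": a fresh, non-trainable tensor built on every call.
h0_zeros = torch.zeros(batch, hidden_dim)

# "learned": one trainable vector, viewed (not copied) once per batch row.
h0_param = nn.Parameter(torch.zeros(hidden_dim))
h0_learned = h0_param.unsqueeze(0).expand(batch, -1)

# A loss that touches every row sends the summed gradient back
# to the single shared parameter.
h0_learned.sum().backward()
```

Because expand returns a view, every sequence in the batch starts from the same learned vector, and its gradient is accumulated over the batch.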

Examples:

>>> import torch
>>> from quivers.continuous.scan import ScanMorphism
>>> from quivers.continuous.spaces import Euclidean, ProductSpace
>>> from quivers.continuous.families import ConditionalNormal
>>> A = Euclidean(name="input", dim=32)
>>> H = Euclidean(name="hidden", dim=64)
>>> cell = ConditionalNormal(ProductSpace(A, H), H, scale=0.1)
>>> scanned = ScanMorphism(cell, init="zeros")
>>> scanned.domain   # Euclidean(name="input", dim=32)
>>> scanned.codomain # Euclidean(name="hidden", dim=64)
>>> x = torch.randn(8, 10, 32)  # batch=8, seq_len=10, input_dim=32
>>> h = scanned.rsample(x)      # (8, 64)

ScanMorphism

ScanMorphism(cell: ContinuousMorphism, init: str = 'zeros')

Bases: ContinuousMorphism

Temporal scan: apply a recurrent cell across a sequence.

Wraps a cell morphism f : A * H -> H and produces a morphism A -> H that iterates over the time dimension of a 3D input tensor, threading hidden state forward.

This implements standard RNN-style recurrence:

h_0 = init
h_t = cell(concat(x_t, h_{t-1}))  for t = 1..T

The scan returns the final hidden state h_T.

PARAMETER DESCRIPTION
cell

The recurrent cell. Must have a product domain A * H and codomain H, where H matches the last component of the product domain.

TYPE: ContinuousMorphism

init

Initialization strategy for h_0. One of "zeros" (default) or "learned" (trainable initial state).

TYPE: str DEFAULT: 'zeros'

Source code in src/quivers/continuous/scan.py
def __init__(self, cell: ContinuousMorphism, init: str = "zeros") -> None:
    input_space = _extract_input_space(cell)
    hidden_space = cell.codomain
    super().__init__(input_space, hidden_space)
    self._cell = cell
    self._init_strategy = init
    self._input_dim = _event_dim(input_space)
    self._hidden_dim = _event_dim(hidden_space)
    if init == "learned":
        self._h0 = nn.Parameter(torch.zeros(self._hidden_dim))
    elif init != "zeros":
        raise ValueError(
            f"unknown init strategy {init!r}; expected 'zeros' or 'learned'"
        )

rsample

rsample(x: Tensor, sample_shape: Size = Size()) -> Tensor

Run the cell across the time dimension of x.

PARAMETER DESCRIPTION
x

Input sequence. Shape (batch, seq_len, input_dim).

TYPE: Tensor

sample_shape

Additional leading sample dimensions. They are passed to the cell's rsample at the first time step only; later steps carry the resulting leading dimensions forward through the hidden state.

TYPE: Size DEFAULT: Size()

RETURNS DESCRIPTION
Tensor

Final hidden state. Shape (batch, hidden_dim), or (*sample_shape, batch, hidden_dim) if sample_shape is non-empty.

Source code in src/quivers/continuous/scan.py
def rsample(
    self, x: torch.Tensor, sample_shape: torch.Size = torch.Size()
) -> torch.Tensor:
    """Run the cell across the time dimension of x.

    Parameters
    ----------
    x : torch.Tensor
        Input sequence. Shape ``(batch, seq_len, input_dim)``.
    sample_shape : torch.Size
        Additional leading sample dimensions (applied to the
        cell's rsample at the first time step only).

    Returns
    -------
    torch.Tensor
        Final hidden state. Shape ``(batch, hidden_dim)``,
        or ``(*sample_shape, batch, hidden_dim)`` if
        sample_shape is non-empty.
    """
    if x.dim() == 2:
        x = x.unsqueeze(1)
    batch, seq_len, _ = x.shape
    if self._init_strategy == "learned":
        h = self._h0.unsqueeze(0).expand(batch, -1)
    else:
        h = torch.zeros(batch, self._hidden_dim, device=x.device, dtype=x.dtype)
    for t in range(seq_len):
        x_t = x[:, t, :]
        cell_input = torch.cat([x_t, h], dim=-1)
        if t == 0 and len(sample_shape) > 0:
            h = self._cell.rsample(cell_input, sample_shape)
            h = self._flatten_cell_output(h)
            if len(sample_shape) > 0 and h.dim() > 2:
                x = x.unsqueeze(0).expand(*sample_shape, *x.shape)
        else:
            if h.dim() > 2:
                x_t = x[..., t, :]
                cell_input = torch.cat([x_t, h], dim=-1)
            h = self._cell.rsample(cell_input)
            h = self._flatten_cell_output(h)
    return h
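
The sample_shape handling in the listing above can be traced with a shape-only sketch. It assumes a hypothetical Gaussian cell built from nn.Linear and torch.distributions.Normal rather than a quivers morphism, and it omits the _flatten_cell_output step, whose behaviour is not shown on this page.

```python
import torch
from torch import nn
from torch.distributions import Normal

torch.manual_seed(0)
batch, seq_len, input_dim, hidden_dim = 8, 10, 32, 64
sample_shape = torch.Size([5])

linear = nn.Linear(input_dim + hidden_dim, hidden_dim)  # hypothetical mean network
x = torch.randn(batch, seq_len, input_dim)
h = torch.zeros(batch, hidden_dim)

for t in range(seq_len):
    if t == 0:
        mean = linear(torch.cat([x[:, 0, :], h], dim=-1))    # (8, 64)
        # Extra sample dimensions are introduced once, at the first step.
        h = Normal(mean, 0.1).rsample(sample_shape)           # (5, 8, 64)
        # Expand x so later slices line up with the new leading dimension.
        x = x.unsqueeze(0).expand(*sample_shape, *x.shape)    # (5, 8, 10, 32)
    else:
        mean = linear(torch.cat([x[..., t, :], h], dim=-1))   # (5, 8, 64)
        # No sample_shape here: h already carries the leading dimension.
        h = Normal(mean, 0.1).rsample()
```

Each of the five leading slices of h is then an independent trajectory through the sequence, which matches the documented return shape (*sample_shape, batch, hidden_dim).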

log_prob

log_prob(x: Tensor, y: Tensor) -> Tensor

Log-density of the scan-induced kernel at the final state.

scan(cell) denotes a Kleisli morphism x_{1:T} -> G(h_T) whose marginal density integrates over all intermediate hidden states. That integral is generally intractable in closed form. However, when the cell is a continuous morphism whose randomness is concentrated in its weight latents (the standard Bayesian-RNN setting), each step is a deterministic function of its inputs once the weights are fixed. In that regime scan denotes a deterministic composition and log_prob is identically zero (a Dirac delta at the realised h_T); the cell's weight latents carry the model's stochasticity and are scored on their own sample steps.

Returns a (batch,)-shaped tensor of zeros so the surrounding log_joint can add it without further special-casing.

Source code in src/quivers/continuous/scan.py
def log_prob(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Log-density of the scan-induced kernel at the final state.

    ``scan(cell)`` denotes a Kleisli morphism
    :math:`\\mathbf{x}_{1:T} \\to \\mathcal{G}(h_T)` whose marginal
    density is the integral over all intermediate hidden states.
    Closed form is generally intractable, but when the cell is
    a continuous morphism whose randomness is concentrated in
    its weight latents (the standard Bayesian-RNN setting) the
    per-step distribution given fixed weights is a deterministic
    function. In that regime ``scan`` denotes a deterministic
    composition and ``log_prob`` is identically zero (a Dirac
    delta at the realised :math:`h_T`); the cell's weight
    latents carry the model's stochasticity and are scored on
    their own ``sample`` steps.

    Returns a ``(batch,)``-shaped tensor of zeros so the surrounding
    ``log_joint`` can add it without further special-casing.
    """
    batch = x.shape[0] if x.dim() >= 1 else 1
    return torch.zeros(batch, device=x.device, dtype=x.dtype)

log_joint

log_joint(x: Tensor, hidden_states: 'torch.Tensor | dict[str, torch.Tensor]', *, state_key: str = 'h') -> Tensor

Joint log-density given all intermediate hidden states.

Computes:

log p(h_1, ..., h_T | x_{1:T}) = sum_{t=1}^{T} log p(h_t | x_t, h_{t-1})

where h_0 is the initial state (zeros or the learned parameter).

PARAMETER DESCRIPTION
x

Input sequence. Shape (batch, seq_len, input_dim).

TYPE: Tensor

hidden_states

The full hidden-state trajectory h_1, ..., h_T (final state included), shape (batch, seq_len, hidden_dim). May be passed positionally as a tensor or via a dict keyed by state_key (so the inference layer's standard log_joint(x, observations: dict) contract works without an adapter).

TYPE: Tensor | dict[str, Tensor]

state_key

Dict key under which the hidden-state tensor is looked up when hidden_states is a dict. Defaults to "h".

TYPE: str DEFAULT: 'h'

RETURNS DESCRIPTION
Tensor

Joint log-density. Shape (batch,).
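
The factorisation and the two accepted input forms can be illustrated with a small stand-alone sketch. It assumes a hypothetical Gaussian cell (an nn.Linear mean with fixed scale 0.1) in place of a quivers morphism; toy_log_joint is an illustrative name, not a library function.

```python
import torch
from torch import nn
from torch.distributions import Normal

torch.manual_seed(0)
batch, seq_len, input_dim, hidden_dim = 4, 6, 3, 5
linear = nn.Linear(input_dim + hidden_dim, hidden_dim)  # hypothetical mean network


def toy_log_joint(x, hidden_states, *, state_key="h"):
    # Accept either a tensor or a {state_key: tensor} dict.
    if isinstance(hidden_states, dict):
        hidden_states = hidden_states[state_key]
    total = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    h = torch.zeros(x.shape[0], hidden_dim, dtype=x.dtype, device=x.device)  # h_0
    for t in range(x.shape[1]):
        h_t = hidden_states[:, t, :]
        mean = linear(torch.cat([x[:, t, :], h], dim=-1))
        # Sum over the event dimension: one log-density per batch row.
        total = total + Normal(mean, 0.1).log_prob(h_t).sum(-1)
        h = h_t  # condition on the supplied state, not on a fresh sample
    return total


x = torch.randn(batch, seq_len, input_dim)
trajectory = torch.randn(batch, seq_len, hidden_dim)
lp_tensor = toy_log_joint(x, trajectory)
lp_dict = toy_log_joint(x, {"h": trajectory})
lp_custom = toy_log_joint(x, {"state": trajectory}, state_key="state")
```

All three calls score the same trajectory, so they return the same (batch,)-shaped tensor. Note that each step conditions on the supplied h_{t-1}, so no sampling happens inside the loop.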

Source code in src/quivers/continuous/scan.py
def log_joint(
    self,
    x: torch.Tensor,
    hidden_states: "torch.Tensor | dict[str, torch.Tensor]",
    *,
    state_key: str = "h",
) -> torch.Tensor:
    """Joint log-density given all intermediate hidden states.

    Computes:
        log p(h_1, ..., h_T | x_{1:T}) =
            sum_t log p(h_t | x_t, h_{t-1})

    Parameters
    ----------
    x : torch.Tensor
        Input sequence. Shape ``(batch, seq_len, input_dim)``.
    hidden_states : torch.Tensor | dict[str, torch.Tensor]
        All hidden states including final, shape
        ``(batch, seq_len, hidden_dim)``. May be passed
        positionally as a tensor or via a dict keyed by
        ``state_key`` (so the inference layer's standard
        ``log_joint(x, observations: dict)`` contract works
        without an adapter).
    state_key : str
        Dict key under which the hidden-state tensor is
        looked up when ``hidden_states`` is a dict. Defaults
        to ``"h"``.

    Returns
    -------
    torch.Tensor
        Joint log-density. Shape ``(batch,)``.
    """
    if isinstance(hidden_states, dict):
        hidden_states = hidden_states[state_key]
    batch, seq_len, _ = x.shape
    total = torch.zeros(batch, device=x.device, dtype=x.dtype)
    if self._init_strategy == "learned":
        h = self._h0.unsqueeze(0).expand(batch, -1)
    else:
        h = torch.zeros(batch, self._hidden_dim, device=x.device, dtype=x.dtype)
    for t in range(seq_len):
        x_t = x[:, t, :]
        h_t = hidden_states[:, t, :]
        cell_input = torch.cat([x_t, h], dim=-1)
        total = total + self._cell.log_prob(cell_input, h_t)
        h = h_t
    return total