Decoders

Decoder is a torch.nn.Module that realizes a Kleisli coalgebra Vec_D → Kern(T_Σ): given an input vector, it defines a distribution over terms of a signature. It exposes two operations:

  • sample(vec, ctx, sort) draws a single Term.
  • log_prob(term, vec, ctx, sort) scores an observed term under the same distribution.

The corecursion over a signature Σ:

  1. At each sort position, the decoder produces logits over its choice set: every constructor and binder whose codomain is that sort, plus the built-in BoundVar whenever the context Γ contains at least one in-scope variable of that sort.
  2. For the chosen op, the parent vector is split into per-child sub-vectors by the per-(sort, arity) factor function, and the decoder recurses on each child.
  3. Data-sorted children are sampled from a closed vocabulary via the per-sort primitive head; index-sorted children are sampled via binder_select over the in-scope variables.
  4. Binder ops extend Γ before recursing on their scoped arguments, exactly mirroring the encoder.
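
The four steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: the toy signature (OPS, NUM_VOCAB), the tuple encoding of terms, the integer context (a count of in-scope variables), and the helpers logits_for, factor, pick, and decode are all hypothetical, and plain lists stand in for tensors. One function handles both sampling and scoring so that log_prob reads off the same per-step distributions that sample draws from.

```python
import math
import random

# Hypothetical toy signature: one object sort ("Expr") and one data sort ("Num").
# Each op maps to the sorts of its children; "lam" is a binder with one scoped body.
OPS = {"lit": ["Num"], "add": ["Expr", "Expr"], "lam": ["Expr"]}
TERMINATING = {"lit", "var"}   # ops without object-sorted children
NUM_VOCAB = [0, 1, 2]          # closed vocabulary for the data sort


def log_softmax(logits):
    m = max(logits)
    z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - z for x in logits]


def logits_for(vec, n):
    """Stand-in for a learned logit head: n deterministic scores from vec."""
    return [math.sin(sum(vec) + i) for i in range(n)]


def factor(vec, n):
    """Stand-in factor function: one same-sized sub-vector per child."""
    return [[x * 0.5 + i for x in vec] for i in range(n)]


def choice_set(ctx, depth):
    """Step 1: ops of the sort, plus "var" only when a variable is in scope."""
    ops = list(OPS) + (["var"] if ctx > 0 else [])
    return [o for o in ops if depth > 1 or o in TERMINATING]


def pick(logp, rng, forced):
    """Return `forced` when scoring; otherwise sample an index from exp(logp)."""
    if forced is not None:
        return forced
    return rng.choices(range(len(logp)), weights=[math.exp(l) for l in logp])[0]


def decode(vec, ctx, depth, rng=None, observed=None):
    """Sample a term (observed is None) or score `observed`; returns (term, log_prob)."""
    ops = choice_set(ctx, depth)
    logp = log_softmax(logits_for(vec, len(ops)))
    # ops.index raises ValueError if the observed op is not allowed here.
    i = pick(logp, rng, None if observed is None else ops.index(observed[0]))
    op, total = ops[i], logp[i]
    if op == "var":    # step 3: index position, select among in-scope variables
        lp = log_softmax(logits_for(vec, ctx))
        j = pick(lp, rng, None if observed is None else observed[1])
        return ("var", j), total + lp[j]
    if op == "lit":    # step 3: data position, select from the closed vocabulary
        lp = log_softmax(logits_for(vec, len(NUM_VOCAB)))
        j = pick(lp, rng, None if observed is None else NUM_VOCAB.index(observed[1]))
        return ("lit", NUM_VOCAB[j]), total + lp[j]
    child_ctx = ctx + 1 if op == "lam" else ctx   # step 4: binders extend the context
    children = []
    for k, sub in enumerate(factor(vec, len(OPS[op]))):   # step 2: split, then recurse
        obs_k = None if observed is None else observed[1 + k]
        child, lp = decode(sub, child_ctx, depth - 1, rng, obs_k)
        children.append(child)
        total += lp
    return (op, *children), total
```

Sampling a term and then scoring it should give the same log-probability, for example `term, lp = decode([0.1, 0.2, 0.3], 0, 4, rng=random.Random(0))` followed by `decode([0.1, 0.2, 0.3], 0, 4, observed=term)`. The depth argument is only there to keep the toy sampler finite.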

Termination is depth-bounded at construction. At the budget limit the choice set is restricted to recursion-terminating ops; if no such op exists at a sort, the decoder raises with a precise diagnostic.

decoder

Decoder runtime: Kleisli coalgebras Vec_D -> Kern(T_Σ).

A Decoder is a torch.nn.Module exposing two operations:

  • sample — draws a single Term from the distribution induced by an input vector.
  • log_prob — scores an observed term under the same distribution.

The corecursion structure over a signature Σ is:

  1. At each sort position, the decoder produces logits over its choice set — every constructor and binder whose codomain is that sort, plus the built-in BOUND_VAR_OP whenever the context Γ contains at least one in-scope variable of that sort.
  2. For the chosen op, the parent vector is split into per-child sub-vectors by the per-(sort, arity) factor function, and the decoder recurses on each child under the same canonical form used by the encoder.
  3. Data-sorted children are sampled from a closed vocabulary via the per-sort primitive head; index-sorted children are sampled via binder_select over the in-scope variables.
  4. Binder ops extend Γ before recursing on their scoped arguments, exactly mirroring the encoder.

Termination is depth-bounded at construction. At the budget limit the choice set is restricted to recursion-terminating ops, that is, ops whose child sorts are all data or index sorts and never object sorts; if no such op exists at a sort, the decoder raises.

No silent type coercion or sentinel value is ever emitted: an observed term whose shape doesn't match the canonical form raises.
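
The fail-loudly policy can be illustrated with a small shape check on observed terms. This is a sketch under assumptions: the tuple encoding (op, *children), the ARITY table, the vocabulary, and the helpers vocab_index and check_term are hypothetical, and ValueError is an assumed exception type.

```python
# Hypothetical canonical form: op name -> number of children.
ARITY = {"lit": 1, "add": 2}
NUM_VOCAB = [0, 1, 2]   # closed vocabulary for the data leaf of "lit"


def vocab_index(value):
    """Index of a data leaf in the closed vocabulary; unknown values raise."""
    try:
        return NUM_VOCAB.index(value)
    except ValueError:
        raise ValueError(f"data leaf {value!r} is not in the closed vocabulary") from None


def check_term(term):
    """Validate an observed term of the form (op, *children) before scoring it."""
    op, *children = term
    if op not in ARITY:
        raise ValueError(f"unknown op {op!r}")
    if len(children) != ARITY[op]:
        raise ValueError(f"op {op!r} expects {ARITY[op]} children, got {len(children)}")
    if op == "lit":
        vocab_index(children[0])      # no fallback token, no coercion
    else:
        for child in children:
            check_term(child)
```

A well-formed term such as `("add", ("lit", 1), ("lit", 2))` passes, while `("add", ("lit", 1))` or `("lit", 7)` raises rather than being scored with a placeholder value.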

Decoder

Decoder(name: str, signature: Signature, sort_dims: dict[str, int], depth: int, structure_fns: dict[str, Callable[[Tensor], Tensor]], primitive_fns: dict[str, Callable[[Tensor], Tensor]], factor_fns: dict[str, dict[int, Callable[[Tensor], tuple[Tensor, ...]]]], binder_select_fn: Callable[[Tensor, list[Tensor]], Tensor], data_vocab: dict[str, list[DataLeaf]], modules_owned: list[Module] | None = None)

Bases: Module

A Kleisli coalgebraic decoder over a signature.

Construction parameters

name : str
    Identifier used in diagnostics.
signature : Signature
    The Σ whose terms this decoder generates.
sort_dims : dict[str, int]
    Per-sort embedding dimension.
depth : int
    Hard upper bound on recursion depth; must be positive, otherwise construction raises ValueError. Sampling beyond this depth is restricted to recursion-terminating ops at each sort.
structure_fns : dict[str, callable]
    Per-sort logit producers (vec) -> Tensor. Indexed by sort name, plus a shared "*" entry consulted when no sort-specific entry exists.
primitive_fns : dict[str, callable]
    Per-data-sort logit producers over the closed token vocab.
factor_fns : dict[str, dict[int, callable]]
    Per-sort, per-arity child-vector projections (parent_vec) -> tuple[Tensor, …] of n sub-vectors.
binder_select_fn : callable
    (parent_vec, list_of_var_embeddings) -> Tensor of logits over the in-scope variables of a sort, used by BOUND_VAR_OP and by index-sorted child positions.
data_vocab : dict[str, list]
    Per-data-sort closed vocabulary aligned with the column order of the corresponding primitive_fns output.
modules_owned : list[Module] | None
    Optional modules registered as submodules of the decoder (under the names _dec_0, _dec_1, and so on).
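
Two of the conventions above, the "*" fallback in structure_fns and the column alignment between primitive_fns output and data_vocab, can be sketched as follows. The helper names resolve_structure_fn and decode_primitive are hypothetical, plain lists stand in for tensors, and the behavior when neither a sort-specific nor a "*" entry exists is an assumption.

```python
def resolve_structure_fn(structure_fns, sort):
    """Prefer the sort-specific logit producer, then the shared "*" entry."""
    if sort in structure_fns:
        return structure_fns[sort]
    if "*" in structure_fns:
        return structure_fns["*"]
    # Assumption: a missing entry is an error rather than a silent default.
    raise KeyError(f"no structure function for sort {sort!r} and no '*' fallback")


def decode_primitive(logits, vocab):
    """Map the highest-scoring column back to its vocabulary entry.

    Column j of the logits corresponds to vocab[j], so the two must have
    the same length.
    """
    if len(logits) != len(vocab):
        raise ValueError(f"{len(logits)} logits for a vocabulary of size {len(vocab)}")
    best = max(range(len(logits)), key=logits.__getitem__)
    return vocab[best]


# Usage with toy callables (lists instead of tensors).
structure_fns = {"Expr": lambda vec: [1.0, 2.0], "*": lambda vec: [0.0]}
expr_logits = resolve_structure_fn(structure_fns, "Expr")([0.5])
stmt_logits = resolve_structure_fn(structure_fns, "Stmt")([0.5])   # falls back to "*"
token = decode_primitive([0.1, 2.0, -1.0], ["a", "b", "c"])
```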

Source code in src/quivers/structural/decoder.py
def __init__(
    self,
    name: str,
    signature: Signature,
    sort_dims: dict[str, int],
    depth: int,
    structure_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
    primitive_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
    factor_fns: dict[
        str, dict[int, Callable[[torch.Tensor], tuple[torch.Tensor, ...]]]
    ],
    binder_select_fn: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor],
    data_vocab: dict[str, list[DataLeaf]],
    modules_owned: list[nn.Module] | None = None,
) -> None:
    super().__init__()
    self.name = name
    self.signature = signature
    self.sort_dims = dict(sort_dims)
    if depth <= 0:
        raise ValueError(f"decoder {name!r}: depth must be positive, got {depth}")
    self.depth = depth
    self.structure_fns = dict(structure_fns)
    self.primitive_fns = dict(primitive_fns)
    self.factor_fns = {s: dict(fs) for s, fs in factor_fns.items()}
    self.binder_select_fn = binder_select_fn
    self.data_vocab = dict(data_vocab)
    self._candidates_by_sort = self._collect_candidates()
    for i, m in enumerate(modules_owned or []):
        self.add_module(f"_dec_{i}", m)