Decoders
A Decoder is a torch.nn.Module that realizes a Kleisli
coalgebra Vec_D → Kern(T_Σ): given an input vector, it defines
a distribution over terms of a signature. It exposes two operations:
- sample(vec, ctx, sort) draws a single Term.
- log_prob(term, vec, ctx, sort) scores an observed term under the same distribution.
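As a minimal sketch of what "the same distribution" means at a single choice point, the snippet below samples an index from a categorical distribution and scores it with the matching log-softmax. It uses plain Python lists in place of tensors, and the helper names (log_softmax, sample_choice, log_prob_choice) are illustrative assumptions, not part of the Decoder API.

```python
import math
import random


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sample_choice(logits, rng):
    """Draw one index from the categorical distribution given by logits."""
    probs = [math.exp(lp) for lp in log_softmax(logits)]
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


def log_prob_choice(logits, index):
    """Score an observed index under the same categorical distribution."""
    return log_softmax(logits)[index]


rng = random.Random(0)          # seeded for reproducibility
logits = [2.0, 0.5, -1.0]       # hypothetical logits over three ops
drawn = sample_choice(logits, rng)
score = log_prob_choice(logits, drawn)

# The per-index probabilities must sum to one.
total = sum(math.exp(log_prob_choice(logits, i)) for i in range(len(logits)))
```

A full decoder would add up such per-choice scores along the whole term, so that sampling and scoring walk the same sequence of decisions.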
The corecursion over a signature Σ:
- At each sort position, the decoder produces logits over its
choice set: every constructor and binder whose codomain is
that sort, plus the built-in
BoundVar whenever the context contains at least one in-scope variable of that sort.
- For the chosen op, the parent vector is split into per-child
sub-vectors by the per-(sort, arity)
factor function, and the decoder recurses on each child.
- Data-sorted children are sampled from a closed vocabulary via
the per-sort
primitive head; index-sorted children are sampled via binder_select over the in-scope variables.
- Binder ops extend Γ before recursing on their scoped arguments, exactly mirroring the encoder.
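The choice-set rule in the first bullet can be sketched as follows. The toy signature, the Op record, and the choice_set helper are hypothetical stand-ins chosen for illustration; they are not the library's Signature type.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """Hypothetical op record: a name, a codomain sort, and child sorts."""
    name: str
    codomain: str
    child_sorts: tuple


BOUND_VAR = "BoundVar"  # label for the built-in bound-variable choice

# A made-up signature with one object sort (Expr) and one data sort (Int).
SIGNATURE = [
    Op("lam", "Expr", ("Expr",)),         # binder-like op
    Op("app", "Expr", ("Expr", "Expr")),
    Op("lit", "Expr", ("Int",)),
]


def choice_set(sort, context):
    """Ops whose codomain is `sort`, plus BoundVar if a variable is in scope.

    `context` is a list of (variable name, sort) pairs.
    """
    names = [op.name for op in SIGNATURE if op.codomain == sort]
    if any(var_sort == sort for _, var_sort in context):
        names.append(BOUND_VAR)
    return names


empty = choice_set("Expr", [])
with_var = choice_set("Expr", [("x", "Expr")])
```

With an empty context the bound-variable choice is absent, so the decoder can never emit a reference to a variable that does not exist.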
Termination is depth-bounded at construction. At the budget limit the choice set is restricted to recursion-terminating ops; if no such op exists at a sort, the decoder raises with a precise diagnostic.
decoder
Decoder runtime: Kleisli coalgebras Vec_D -> Kern(T_Σ).
A Decoder is a torch.nn.Module exposing two operations:
- sample: draws a single Term from the distribution induced by an input vector.
- log_prob: scores an observed term under the same distribution.
The corecursion structure over a signature Σ is:
- At each sort position, the decoder produces logits over its
choice set: every constructor and binder whose codomain is
that sort, plus the built-in
BOUND_VAR_OP whenever the context Γ contains at least one in-scope variable of that sort.
- For the chosen op, the parent vector is split into per-child
sub-vectors by the per-(sort, arity)
factor function, and the decoder recurses on each child under the same canonical form used by the encoder.
- Data-sorted children are sampled from a closed vocabulary via
the per-sort
primitive head; index-sorted children are sampled via binder_select over the in-scope variables.
- Binder ops extend Γ before recursing on their scoped arguments, exactly mirroring the encoder.
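The factor step and the context extension for binders can be sketched together. This is an illustration under simplifying assumptions: the factor function here just slices the parent into equal chunks (a real one would typically be a learned projection), plain lists stand in for tensors, every child of a binder is treated as scoped, and split_factor and child_inputs are hypothetical names.

```python
def split_factor(parent_vec, arity):
    """Hypothetical factor function: slice the parent into `arity` equal chunks."""
    if arity <= 0 or len(parent_vec) % arity != 0:
        raise ValueError(
            f"cannot split a vector of length {len(parent_vec)} into {arity} parts"
        )
    width = len(parent_vec) // arity
    return tuple(parent_vec[i * width:(i + 1) * width] for i in range(arity))


def child_inputs(parent_vec, child_sorts, context, binds=None):
    """Pair each child sort with its sub-vector and the context it is decoded under.

    If `binds` is given (a (name, sort) pair), the context is extended
    before the children are visited, as a binder op would do.
    """
    sub_vecs = split_factor(parent_vec, len(child_sorts))
    inner = context + [binds] if binds is not None else context
    return [(sort, vec, inner) for sort, vec in zip(child_sorts, sub_vecs)]


parent = [0.1, 0.2, 0.3, 0.4]
plain = child_inputs(parent, ("Expr", "Expr"), [])
scoped = child_inputs(parent, ("Expr",), [], binds=("x", "Expr"))
```

Building a new list for the inner context, rather than mutating the outer one, keeps sibling subterms from seeing each other's bound variables.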
Termination is depth-bounded at construction. At the budget limit the choice set is restricted to recursion-terminating ops (every child sort is data / index, never object); if no such op exists at a sort, the decoder raises.
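The budget rule can be sketched as a filter over the choice set. The signature below and the allowed_ops helper are made up for illustration; only the rule itself (at the limit, keep ops with no object-sorted children, and raise if none remain) comes from the description above.

```python
# Sorts whose terms can contain further ops; all other sorts are data or index.
OBJECT_SORTS = {"Expr", "Pair"}

# Hypothetical signature: op name -> (codomain sort, child sorts).
OPS = {
    "app": ("Expr", ("Expr", "Expr")),
    "lit": ("Expr", ("Int",)),
    "pair": ("Pair", ("Expr", "Expr")),
}


def allowed_ops(sort, depth_remaining):
    """Return the ops that may be chosen at `sort` given the remaining depth."""
    candidates = [name for name, (codomain, _) in OPS.items() if codomain == sort]
    if depth_remaining > 0:
        return candidates
    # At the budget limit, keep only ops with no object-sorted child.
    terminal = [
        name for name in candidates
        if not any(child in OBJECT_SORTS for child in OPS[name][1])
    ]
    if not terminal:
        raise ValueError(
            f"sort {sort!r} has no recursion-terminating op at the depth limit"
        )
    return terminal


within_budget = allowed_ops("Expr", 2)
at_limit = allowed_ops("Expr", 0)
```

In this toy signature, Pair has no terminating op, so allowed_ops("Pair", 0) raises instead of returning an empty choice set.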
No silent type coercion or sentinel value is ever emitted: an observed term whose shape doesn't match the canonical form raises.
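One way to picture this strictness at a single choice point is the sketch below, where an op outside the choice set raises instead of being scored as negative infinity or some placeholder. The score_op helper is a hypothetical illustration, not the library's log_prob.

```python
import math


def score_op(observed_op, choice_names, logits):
    """Log-probability of `observed_op`; raises if it is not a legal choice."""
    if len(choice_names) != len(logits):
        raise ValueError("choice set and logits must have the same length")
    if observed_op not in choice_names:
        raise ValueError(
            f"op {observed_op!r} is not in the choice set {choice_names!r}"
        )
    # Stable log-softmax evaluated at the observed index.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[choice_names.index(observed_op)] - log_z


ok = score_op("lit", ["app", "lit"], [0.0, 0.0])
```

Raising keeps malformed training terms from silently contributing a meaningless value to a loss.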
Decoder
Decoder(name: str, signature: Signature, sort_dims: dict[str, int], depth: int, structure_fns: dict[str, Callable[[Tensor], Tensor]], primitive_fns: dict[str, Callable[[Tensor], Tensor]], factor_fns: dict[str, dict[int, Callable[[Tensor], tuple[Tensor, ...]]]], binder_select_fn: Callable[[Tensor, list[Tensor]], Tensor], data_vocab: dict[str, list[DataLeaf]], modules_owned: list[Module] | None = None)
Bases: Module
A Kleisli coalgebraic decoder over a signature.
Construction parameters
name : str
Identifier used in diagnostics.
signature : Signature
The Σ whose terms this decoder generates.
sort_dims : dict[str, int]
Per-sort embedding dimension.
depth : int
Hard upper bound on recursion depth. Sampling beyond this
depth is restricted to recursion-terminating ops at each
sort.
structure_fns : dict[str, callable]
Per-sort logit producers (vec) -> Tensor. Indexed by
sort name, plus a shared "*" entry consulted when no
sort-specific entry exists.
primitive_fns : dict[str, callable]
Per-data-sort logit producers over the closed token vocab.
factor_fns : dict[str, dict[int, callable]]
Per-sort, per-arity child-vector projections
(parent_vec) -> tuple[Tensor, …] of n sub-vectors.
binder_select_fn : callable
(parent_vec, list_of_var_embeddings) -> Tensor of
logits over the in-scope variables of a sort, used by
BOUND_VAR_OP and by index-sorted child positions.
data_vocab : dict[str, list]
Per-data-sort closed vocabulary aligned with the column
order of the corresponding primitive_fns output.
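Two of the conventions above, the "*" fallback in structure_fns and the column alignment between primitive_fns output and data_vocab, can be sketched as follows. The helpers lookup_structure_fn and decode_leaf are hypothetical, lists stand in for tensors, and a greedy argmax is used in place of sampling to keep the example deterministic.

```python
def lookup_structure_fn(structure_fns, sort):
    """Prefer the sort-specific entry, then the shared "*" entry."""
    if sort in structure_fns:
        return structure_fns[sort]
    if "*" in structure_fns:
        return structure_fns["*"]
    raise KeyError(f"no structure function for sort {sort!r} and no '*' fallback")


def decode_leaf(logits, vocab):
    """Map the highest-scoring column to its vocabulary entry.

    Column i of `logits` is assumed to correspond to vocab[i].
    """
    if len(logits) != len(vocab):
        raise ValueError("logit columns and vocabulary must have the same length")
    best = max(range(len(logits)), key=logits.__getitem__)
    return vocab[best]


# Toy logit producers over list "vectors".
structure_fns = {
    "Expr": lambda vec: [sum(vec), 0.0],
    "*": lambda vec: [0.0],
}

expr_logits = lookup_structure_fn(structure_fns, "Expr")([1.0, 2.0])
fallback_logits = lookup_structure_fn(structure_fns, "Type")([1.0])
leaf = decode_leaf([0.1, 2.0, -1.0], ["a", "b", "c"])
```

If the vocabulary list and the logit columns drift out of order, every decoded leaf is silently wrong, which is why the length check is worth keeping even in a sketch.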
Source code in src/quivers/structural/decoder.py