Deduction Systems

The user-facing surface for working with weighted chart deductions: the model type, its abstract primitives, and three orthogonal operations on it. All public symbols are re-exported from quivers.stochastic.deduction.

Submodule    Job
----------   ---
primitives   Abstract building blocks (Axiom, Deduction, Goal, Schedule, DeductiveSystem). Most users do not touch these directly.
fit          Point-estimate gradient fitting (MAP / MLE) of the deduction's learnable log-weights.
bayes        Lift the parameters into a Bayesian MonadicProgram whose posterior NUTS / SVI can target.
sample       Exact length-conditional forward sampling of yields from the chart's distribution.

deduction

The deduction-system surface: model type, abstract primitives, and the three independent operations on a deduction.

This package gathers everything a user needs for working with a weighted deduction system in one importable path. The contents fall into three orthogonal groups:

  • Model type (DeductionSystem) — the concrete agenda-based chart deduction that compiles from a deduction block in the QVR DSL. Re-exported from quivers.stochastic.agenda.
  • Abstract primitives (.primitives) — Axiom, Deduction, Goal, Schedule, DeductiveSystem: the protocol layer the agenda implementation derives from. Most users will not need these directly; they exist for custom-deduction subclasses and for the inside-algorithm framework (quivers.stochastic.inside).
  • Operations — three independent surfaces over a DeductionSystem with no overlap in purpose:
    • .fit — point-estimate gradient fitting (MAP / MLE);
    • .bayes — Bayesian wrapping for NUTS / SVI;
    • .sample — exact length-conditional forward sampling of yields.

These three live in separate submodules because they answer different questions (estimate parameters at a point vs. sample a posterior vs. generate synthetic data) and so that adding more operations in the future does not bloat any one module.

DeductionSystem dataclass

DeductionSystem(rules: tuple[InferenceRule, ...], semiring: ChartSemiring, axiom_injector: Callable[[Any], list[tuple[Item, Tensor]]], goal: Callable[[Item], bool], agenda_factory: Callable[[], Agenda] = FIFOAgenda, chart_factory: Callable[[], Chart] = HashChart, max_iterations: int = 100000, tolerance: float = 0.0)

A weighted deductive system parameterized over its components.

The system is parameterized by:

  • An item algebra (implicit in the patterns the rules use).
  • A tuple of arity-n InferenceRule hyperedges (the rules field).
  • A ChartSemiring for weight aggregation.
  • An axiom injector In -> [(Item, Weight)] producing the input's lexical / boundary items.
  • A goal predicate Item -> bool selecting the result items.
  • (Optional) a chart constructor and an agenda strategy.

The same data structure subsumes CKY (FIFO agenda, span items, Boolean / inside semiring), Viterbi (priority agenda with the current weight as priority, max-times semiring), A* parsing (priority agenda with an admissible heuristic, tropical semiring), MLTT type-checking (LIFO agenda, judgment items, Boolean semiring), and weighted Datalog (FIFO, atoms, any naturally-ordered semiring).
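
The claim that one loop covers several algorithms can be illustrated with a minimal, self-contained sketch. It does not use the quivers API: `Semiring`, `deduce`, and the `(antecedent, conclusion, weight)` rule encoding are hypothetical names chosen for this example, rules are restricted to arity 1, and termination assumes an idempotent `plus` with no weight-improving cycles.

```python
from collections import deque
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Semiring:
    plus: Callable[[Any, Any], Any]   # aggregates alternative derivations
    times: Callable[[Any, Any], Any]  # combines weights along one derivation
    zero: Any                         # identity for plus ("not derived")
    one: Any                          # identity for times


BOOLEAN = Semiring(lambda a, b: a or b, lambda a, b: a and b, False, True)
TROPICAL = Semiring(min, lambda a, b: a + b, float("inf"), 0.0)
VITERBI = Semiring(max, lambda a, b: a * b, 0.0, 1.0)


def deduce(axioms, rules, semiring, lifo=False):
    """Run unary weighted rules to fixpoint.

    axioms: {item: weight}; rules: [(antecedent, conclusion, weight), ...].
    """
    chart = dict(axioms)
    agenda = deque(axioms)
    while agenda:
        item = agenda.pop() if lifo else agenda.popleft()
        for antecedent, conclusion, weight in rules:
            if antecedent != item:
                continue
            candidate = semiring.times(chart[item], weight)
            old = chart.get(conclusion, semiring.zero)
            new = semiring.plus(old, candidate)
            if new != old:
                # The conclusion's weight changed, so revisit its consequences.
                chart[conclusion] = new
                agenda.append(conclusion)
    return chart


reach = deduce({"a": True}, [("a", "b", True), ("b", "c", True)], BOOLEAN)
short = deduce({"a": 0.0}, [("a", "b", 1.0), ("b", "c", 1.0), ("a", "c", 5.0)], TROPICAL)
best = deduce({"a": 1.0}, [("a", "b", 0.5), ("b", "c", 0.5), ("a", "c", 0.1)], VITERBI)
```

Only the semiring and the agenda discipline change between the three calls; the loop itself is untouched, which is the sense in which the components are orthogonal.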

run

run(input_value: Any) -> AgendaResult

Run the deduction system on an input value.

Source code in src/quivers/stochastic/agenda.py
def run(self, input_value: Any) -> AgendaResult:
    """Run the deduction system on an input value."""
    axioms = self.axiom_injector(input_value)
    registry = getattr(self, "_loss_registry", None)
    deduction_name = getattr(self, "_deduction_name", None)
    rule_loss_acc: list[torch.Tensor] = []

    def _rule_callback(
        rule_name: str,
        antecedents: list[tuple[Item, torch.Tensor]],
        conclusion: Item,
        conclusion_w: torch.Tensor,
    ) -> None:
        if registry is None or deduction_name is None:
            return
        env = {
            "rule": rule_name,
            "deduction": deduction_name,
            "antecedents": list(antecedents),
            "conclusion": conclusion,
            "weight": conclusion_w,
        }
        val = registry.evaluate_on(
            "rule",
            target=rule_name,
            env=env,
            rule_deduction=deduction_name,
        )
        rule_loss_acc.append(val)

    # If the user supplied a positive ``tolerance``, propagate
    # it to the chart so its aggregation step terminates on
    # convergence. The default chart_factory is ``HashChart``,
    # whose constructor accepts the tolerance; user-supplied
    # alternative chart factories can ignore the argument.
    try:
        chart_inst = self.chart_factory(tolerance=self.tolerance)
    except TypeError:
        chart_inst = self.chart_factory()
    result = run_agenda(
        axioms=axioms,
        rules=self.rules,
        semiring=self.semiring,
        agenda=self.agenda_factory(),
        goal=self.goal,
        max_iterations=self.max_iterations,
        chart=chart_inst,
        rule_callback=(_rule_callback if registry is not None else None),
    )
    # Propagate any attached item-encoder to the result.
    comp = getattr(self, "_item_encoder", None)
    if comp is not None:
        result.encoder = comp
    # Evaluate chart-attached losses on the completed chart.
    if registry is not None and deduction_name is not None:
        chart_env = {
            "deduction": deduction_name,
            "chart": result.chart,
            "goal_items": result.goal_items,
        }
        chart_loss = registry.evaluate_on(
            "chart",
            target=deduction_name,
            env=chart_env,
        )
        losses = rule_loss_acc + [chart_loss]
    else:
        losses = rule_loss_acc
    if losses:
        total = losses[0]
        for v in losses[1:]:
            total = total + v
        result.attached_loss = total
    return result
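
The tolerance hand-off in the listing above relies on a try/except fallback. The sketch below isolates that pattern with two hypothetical stand-in classes (`ToleranceAwareChart` and `PlainChart` are not quivers types) so the behaviour can be checked without the library.

```python
class ToleranceAwareChart:
    """Stand-in for a chart whose constructor accepts ``tolerance``."""

    def __init__(self, tolerance: float = 0.0) -> None:
        self.tolerance = tolerance


class PlainChart:
    """Stand-in for a user-supplied chart that takes no arguments."""

    def __init__(self) -> None:
        self.items = {}


def make_chart(chart_factory, tolerance: float):
    # Try the keyword form first; fall back to a bare call when the
    # factory rejects the unexpected ``tolerance`` keyword.
    try:
        return chart_factory(tolerance=tolerance)
    except TypeError:
        return chart_factory()


aware = make_chart(ToleranceAwareChart, 1e-6)
plain = make_chart(PlainChart, 1e-6)
```

One property of this pattern worth keeping in mind: a `TypeError` raised for any other reason inside a tolerance-aware factory also triggers the fallback, so the second call is attempted without the tolerance.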

__call__

__call__(input_value: Any) -> ChartView

Run the deduction and return a ChartView.

Convenience for the user-facing presheaf-evaluation API: the chart's weights are differentiable tensors, and the view exposes weight, enumerate, derivations, and goal_weight methods for downstream programs.

Source code in src/quivers/stochastic/agenda.py
def __call__(self, input_value: Any) -> ChartView:
    """Run the deduction and return a `ChartView`.

    Convenience for the user-facing presheaf-evaluation API:
    the chart's weights are differentiable tensors, and the
    view exposes ``weight``, ``enumerate``, ``derivations``,
    and ``goal_weight`` methods for downstream programs.
    """
    return ChartView(self.run(input_value))

parameters

parameters(recurse: bool = True) -> Iterable[Parameter]

Yield every learnable parameter owned by this system.

Walks the optional _axiom_module (lexicon log-weights) and _rule_module (per-rule, per-binding log-weights) submodules attached by the compiler. The recurse flag is the standard torch.nn.Module.parameters signature so user code can pass a DeductionSystem anywhere a nn.Module parameter iterator is expected.

Source code in src/quivers/stochastic/agenda.py
def parameters(self, recurse: bool = True) -> Iterable[torch.nn.Parameter]:
    """Yield every learnable parameter owned by this system.

    Walks the optional ``_axiom_module`` (lexicon log-weights)
    and ``_rule_module`` (per-rule, per-binding log-weights)
    submodules attached by the compiler. The ``recurse`` flag
    is the standard `torch.nn.Module.parameters` signature
    so user code can pass a ``DeductionSystem`` anywhere a
    ``nn.Module`` parameter iterator is expected.
    """
    for attr in ("_axiom_module", "_rule_module"):
        mod = getattr(self, attr, None)
        if mod is not None and hasattr(mod, "parameters"):
            yield from mod.parameters(recurse=recurse)

named_parameters

named_parameters(prefix: str = '', recurse: bool = True) -> Iterable[tuple[str, Parameter]]

Yield (name, parameter) pairs over all learnable parameters.

Source code in src/quivers/stochastic/agenda.py
def named_parameters(
    self,
    prefix: str = "",
    recurse: bool = True,
) -> Iterable[tuple[str, torch.nn.Parameter]]:
    """Yield ``(name, parameter)`` pairs over all learnable parameters."""
    for attr in ("_axiom_module", "_rule_module"):
        mod = getattr(self, attr, None)
        if mod is not None and hasattr(mod, "named_parameters"):
            sub_prefix = f"{prefix}.{attr}" if prefix else attr
            for n, p in mod.named_parameters(
                prefix=sub_prefix,
                recurse=recurse,
            ):
                yield n, p

Axiom

Bases: Module

Initial items in a weighted deductive system.

An axiom creates the chart tensor and populates it with initial weights derived from the input. For span-based chart parsing, this is the lexical step.

forward abstractmethod

forward(input: Tensor, semiring: ChartSemiring) -> Tensor

Create and populate the initial chart.

PARAMETER DESCRIPTION
input

Raw input (e.g. token indices).

TYPE: Tensor

semiring

The scoring semiring.

TYPE: ChartSemiring

RETURNS DESCRIPTION
Tensor

The chart tensor with initial items filled in.

Source code in src/quivers/stochastic/deduction/primitives.py
@abstractmethod
def forward(
    self,
    input: torch.Tensor,
    semiring: ChartSemiring,
) -> torch.Tensor:
    """Create and populate the initial chart.

    Parameters
    ----------
    input : torch.Tensor
        Raw input (e.g. token indices).
    semiring : ChartSemiring
        The scoring semiring.

    Returns
    -------
    torch.Tensor
        The chart tensor with initial items filled in.
    """
    ...

Deduction

Bases: Module

A weighted inference step in a deductive system.

A deduction reads from the chart, applies inference rules, and writes updated weights. This is a morphism in the V-enriched category of chart states.

forward abstractmethod

forward(chart: Tensor, semiring: ChartSemiring, **context) -> Tensor

Apply this deduction step.

PARAMETER DESCRIPTION
chart

Current chart state.

TYPE: Tensor

semiring

The scoring semiring.

TYPE: ChartSemiring

**context

Schedule-provided context (e.g. span_length, span_start, split for CKY).

DEFAULT: {}

RETURNS DESCRIPTION
Tensor

Updated chart.

Source code in src/quivers/stochastic/deduction/primitives.py
@abstractmethod
def forward(
    self,
    chart: torch.Tensor,
    semiring: ChartSemiring,
    **context,
) -> torch.Tensor:
    """Apply this deduction step.

    Parameters
    ----------
    chart : torch.Tensor
        Current chart state.
    semiring : ChartSemiring
        The scoring semiring.
    **context
        Schedule-provided context (e.g. ``span_length``,
        ``span_start``, ``split`` for CKY).

    Returns
    -------
    torch.Tensor
        Updated chart.
    """
    ...

DeductiveSystem

DeductiveSystem(axiom: Axiom, deductions: list[Deduction], goal: Goal, schedule: Schedule, semiring: ChartSemiring | None = None)

Bases: Module

A weighted deductive system evaluated to fixpoint.

This is the abstract core of all parsing algorithms:

    input -> axiom -> schedule(deductions) -> goal -> output

The system is an nn.Module whose learnable parameters come from the axiom (e.g. lexical weights) and deductions (e.g. rule weights).

PARAMETER DESCRIPTION
axiom

Initial item population.

TYPE: Axiom

deductions

Inference rules.

TYPE: list[Deduction]

goal

Result extraction.

TYPE: Goal

schedule

Evaluation strategy.

TYPE: Schedule

semiring

Scoring algebra (default: LOG_PROB).

TYPE: ChartSemiring DEFAULT: None

Source code in src/quivers/stochastic/deduction/primitives.py
def __init__(
    self,
    axiom: Axiom,
    deductions: list[Deduction],
    goal: Goal,
    schedule: Schedule,
    semiring: ChartSemiring | None = None,
) -> None:
    super().__init__()
    self.axiom = axiom
    self.deductions = nn.ModuleList(deductions)
    self.goal = goal
    self._schedule = schedule
    self._semiring = semiring or LOG_PROB

semiring property

semiring: ChartSemiring

The scoring algebra.

forward

forward(input: Tensor) -> Tensor

Run the deductive system on input.

PARAMETER DESCRIPTION
input

Raw input (e.g. token indices).

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Goal item weights.

Source code in src/quivers/stochastic/deduction/primitives.py
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Run the deductive system on input.

    Parameters
    ----------
    input : torch.Tensor
        Raw input (e.g. token indices).

    Returns
    -------
    torch.Tensor
        Goal item weights.
    """
    chart = self.axiom(input, self._semiring)
    chart = self._schedule.run(chart, self.deductions, self._semiring)
    return self.goal(chart)
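
The axiom -> schedule(deductions) -> goal pipeline can be sketched without torch. The example below is an assumption-laden toy, not the library's implementation: the chart is a plain dict over spans, the semiring is ordinary counting (+, ×), and `axiom`, `combine`, `cky_schedule`, `goal`, and `run_pipeline` are hypothetical names. It counts the binary bracketings of a non-empty token list, with the schedule passing `span_start`, `span_length`, and `split` as context.

```python
from typing import Callable, Dict, List, Tuple

Chart = Dict[Tuple[int, int], int]


def axiom(tokens: List[str]) -> Chart:
    # Lexical step: every width-1 span has exactly one derivation.
    return {(i, i + 1): 1 for i in range(len(tokens))}


def combine(chart: Chart, *, span_start: int, span_length: int, split: int) -> Chart:
    # One binary inference step: spans (i, k) and (k, j) derive span (i, j).
    i, k, j = span_start, split, span_start + span_length
    left, right = chart.get((i, k), 0), chart.get((k, j), 0)
    chart[(i, j)] = chart.get((i, j), 0) + left * right
    return chart


def cky_schedule(chart: Chart, deductions: List[Callable[..., Chart]]) -> Chart:
    # Bottom-up order: shorter spans are complete before longer spans read them.
    n = max(j for _, j in chart)
    for span_length in range(2, n + 1):
        for span_start in range(0, n - span_length + 1):
            for split in range(span_start + 1, span_start + span_length):
                for deduction in deductions:
                    chart = deduction(
                        chart,
                        span_start=span_start,
                        span_length=span_length,
                        split=split,
                    )
    return chart


def goal(chart: Chart) -> int:
    # The answer is the weight of the span covering the whole input.
    n = max(j for _, j in chart)
    return chart.get((0, n), 0)


def run_pipeline(tokens: List[str]) -> int:
    chart = axiom(tokens)
    chart = cky_schedule(chart, [combine])
    return goal(chart)


n_trees = run_pipeline(["the", "old", "man", "sleeps"])
```

For four tokens the count is the Catalan number C(3) = 5. Swapping `combine` or `cky_schedule` for another callable leaves the other stages unchanged, mirroring the separation of concerns in the class.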

Goal

Bases: Module

Extract the result from a completed chart.

A goal identifies which chart items constitute the answer and extracts their weights.

forward abstractmethod

forward(chart: Tensor) -> Tensor

Extract goal items from the chart.

PARAMETER DESCRIPTION
chart

Completed chart.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Goal item weights.

Source code in src/quivers/stochastic/deduction/primitives.py
@abstractmethod
def forward(self, chart: torch.Tensor) -> torch.Tensor:
    """Extract goal items from the chart.

    Parameters
    ----------
    chart : torch.Tensor
        Completed chart.

    Returns
    -------
    torch.Tensor
        Goal item weights.
    """
    ...

Schedule

Bases: ABC

Evaluation strategy for a deductive system.

Different schedules compute the same fixpoint in different orders. CKY processes spans bottom-up; an agenda schedule uses a priority queue. The schedule is independent of the deduction rules.
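
A small sketch can make the order-independence claim concrete. It is illustrative only and assumes a Boolean setting: `closure` is a hypothetical helper that computes the transitive closure of a set of edges under a FIFO queue, a LIFO stack, or a priority heap, and records the order in which items were first derived.

```python
import heapq
from collections import deque


def closure(edges, order):
    """Transitive closure of ``edges``; ``order`` is 'fifo', 'lifo', or 'priority'."""
    agenda = list(edges)
    if order == "priority":
        heapq.heapify(agenda)
    else:
        agenda = deque(agenda)

    def pop():
        if order == "priority":
            return heapq.heappop(agenda)
        return agenda.popleft() if order == "fifo" else agenda.pop()

    def push(item):
        if order == "priority":
            heapq.heappush(agenda, item)
        else:
            agenda.append(item)

    derived, trace = set(), []
    while agenda:
        item = pop()
        if item in derived:
            continue
        derived.add(item)
        trace.append(item)
        x, y = item
        # Rule: path(x, y), path(y, z) |- path(x, z), with the new item
        # tried in both antecedent positions against everything derived so far.
        for u, v in list(derived):
            if u == y:
                push((x, v))
            if v == x:
                push((u, y))
    return derived, trace


EDGES = [(1, 2), (2, 3), (3, 4)]
RESULTS = {order: closure(EDGES, order) for order in ("fifo", "lifo", "priority")}
```

All three disciplines reach the same set of six paths, while the recorded traces differ (FIFO starts from `(1, 2)`, LIFO from `(3, 4)`).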

run abstractmethod

run(chart: Tensor, deductions: ModuleList, semiring: ChartSemiring) -> Tensor

Execute deductions on chart to fixpoint.

PARAMETER DESCRIPTION
chart

Chart with axioms filled in.

TYPE: Tensor

deductions

The deduction steps.

TYPE: ModuleList

semiring

The scoring semiring.

TYPE: ChartSemiring

RETURNS DESCRIPTION
Tensor

Completed chart.

Source code in src/quivers/stochastic/deduction/primitives.py
@abstractmethod
def run(
    self,
    chart: torch.Tensor,
    deductions: nn.ModuleList,
    semiring: ChartSemiring,
) -> torch.Tensor:
    """Execute deductions on chart to fixpoint.

    Parameters
    ----------
    chart : torch.Tensor
        Chart with axioms filled in.
    deductions : nn.ModuleList
        The deduction steps.
    semiring : ChartSemiring
        The scoring semiring.

    Returns
    -------
    torch.Tensor
        Completed chart.
    """
    ...

nuts_program_from_deduction

nuts_program_from_deduction(ded: DeductionSystem, corpus: Sequence[Sequence[str]], *, prior_scale: float = 1.0, site_prefix: str = 'log_w') -> tuple[MonadicProgram, Tensor, dict[str, Tensor]]

Lift a deduction system's learnable parameters to a MonadicProgram suitable for NUTS / SVI.

The returned program has one torch.distributions.Normal sample site per learnable parameter (lexicon entries and rule bindings alike) plus one score step that substitutes the sampled values into the deduction's parameter slots and adds :math:`\sum_n \log Z(s_n; \mathbf{w})` to the joint.

PARAMETER DESCRIPTION
ded

Deduction whose parameters are lifted.

TYPE: DeductionSystem

corpus

Corpus the score step closes over.

TYPE: sequence of sentences

prior_scale

Standard deviation of the Normal prior on each parameter.

TYPE: float DEFAULT: 1.0

site_prefix

Stem of each sample-site's name (the parameter's path is appended for round-trip mapping).

TYPE: str DEFAULT: 'log_w'

RETURNS DESCRIPTION
(model, x, observations)

The lifted program plus a (1, 1) placeholder input and an empty observation dict, ready to feed to quivers.inference.MCMC.run.

Source code in src/quivers/stochastic/deduction/bayes.py
def nuts_program_from_deduction(
    ded: DeductionSystem,
    corpus: Sequence[Sequence[str]],
    *,
    prior_scale: float = 1.0,
    site_prefix: str = "log_w",
) -> tuple[MonadicProgram, torch.Tensor, dict[str, torch.Tensor]]:
    """Lift a deduction system's learnable parameters to a
    `MonadicProgram` suitable for NUTS / SVI.

    The returned program has one
    `torch.distributions.Normal` sample site per learnable
    parameter (lexicon entries and rule bindings alike) plus one
    score step that substitutes the sampled values into the
    deduction's parameter slots and adds
    :math:`\\sum_n \\log Z(s_n; \\mathbf{w})` to the joint.

    Parameters
    ----------
    ded : DeductionSystem
        Deduction whose parameters are lifted.
    corpus : sequence of sentences
        Corpus the score step closes over.
    prior_scale : float
        Standard deviation of the Normal prior on each parameter.
    site_prefix : str
        Stem of each sample-site's name (the parameter's path is
        appended for round-trip mapping).

    Returns
    -------
    (model, x, observations)
        The lifted program plus a ``(1, 1)`` placeholder input and
        an empty observation dict, ready to feed to
        [`quivers.inference.MCMC`][quivers.inference.MCMC] ``.run``.
    """
    materialise_parameters(ded, corpus)
    locator, paths, _ = build_locator(ded)
    if not paths:
        raise ValueError(
            "nuts_program_from_deduction: the deduction has no "
            "learnable parameters (neither lexicon nor rules are "
            "marked #[learnable])"
        )
    site_names: list[str] = []
    for path in paths:
        safe = path.replace("/", "__").replace(".", "_")
        site_names.append(f"{site_prefix}__{safe}")

    prior_morph = _make_normal_prior_morphism(prior_scale)
    steps: list[tuple] = [((site,), prior_morph, None) for site in site_names]

    def _score_fn(
        env: dict[str, torch.Tensor],
        _ded: DeductionSystem = ded,
        _corpus: list[list[str]] = [list(s) for s in corpus],
        _site_names: list[str] = list(site_names),
        _paths: list[str] = list(paths),
        _locator: Callable[[str], tuple[torch.nn.Module, str]] = locator,
    ) -> torch.Tensor:
        site_values = [env[name] for name in _site_names]
        batch = site_values[0].shape[0]
        out = torch.zeros(batch, dtype=torch.get_default_dtype())
        for b in range(batch):
            overrides = {path: v[b].reshape(()) for path, v in zip(_paths, site_values)}
            with _swap_named_parameters(_locator, overrides):
                log_z = torch.zeros(())
                for sentence in _corpus:
                    log_z = log_z + _ded(sentence).goal_weight()
                out[b] = log_z
        return out

    steps.append((("log_Z",), None, _score_fn, True))
    model = MonadicProgram(
        domain=Unit,
        codomain=Unit,
        steps=steps,
        return_vars=("log_Z",),
    )
    return model, torch.zeros(1, 1), {}
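
The site-naming step in the listing above can be tried in isolation. The sketch below copies the two `replace` calls into a hypothetical helper, `site_name`; the parameter paths used are made-up examples, not names the compiler is known to produce.

```python
def site_name(path: str, site_prefix: str = "log_w") -> str:
    # Same mangling as the listing: "/" becomes "__", "." becomes "_".
    safe = path.replace("/", "__").replace(".", "_")
    return f"{site_prefix}__{safe}"


# Hypothetical parameter paths, for illustration only.
paths = ["_axiom_module.weights", "_rule_module/binary.log_w"]
sites = [site_name(p) for p in paths]

# The mangling is not injective, so the path cannot be recovered from the
# site name alone; keeping the two lists aligned preserves the mapping.
site_to_path = dict(zip(sites, paths))
collision = site_name("a.b") == site_name("a_b")
```

This is presumably why the score function in the listing zips `_paths` with the sampled site values rather than parsing site names back into paths.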

adam_fit_deduction

adam_fit_deduction(ded: DeductionSystem, corpus: Sequence[Sequence[str]], *, steps: int = 300, lr: float = 0.05, prior_scale: float | None = None) -> list[float]

Maximise the corpus log-marginal under an optional Normal prior on the parameters.

PARAMETER DESCRIPTION
ded

Deduction whose _axiom_module and _rule_module parameters are optimised.

TYPE: DeductionSystem

corpus

Each sentence is a sequence of token strings the deduction's axiom injector accepts.

TYPE: sequence of sentences

steps

Adam steps.

TYPE: int DEFAULT: 300

lr

Adam learning rate.

TYPE: float DEFAULT: 0.05

prior_scale

If supplied, adds a Gaussian regulariser :math:`\tfrac{1}{2\sigma^2}\lVert \mathbf{w} \rVert^2`, with :math:`\sigma` equal to prior_scale, to the loss (MAP). Defaults to None (MLE).

TYPE: float DEFAULT: None

RETURNS DESCRIPTION
list[float]

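The loss trajectory, one entry per Adam step (length == steps). An empty list is returned when the deduction has no learnable parameters.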

Source code in src/quivers/stochastic/deduction/fit.py
def adam_fit_deduction(
    ded: DeductionSystem,
    corpus: Sequence[Sequence[str]],
    *,
    steps: int = 300,
    lr: float = 5e-2,
    prior_scale: float | None = None,
) -> list[float]:
    """Maximise the corpus log-marginal under an optional Normal
    prior on the parameters.

    Parameters
    ----------
    ded : DeductionSystem
        Deduction whose ``_axiom_module`` and ``_rule_module``
        parameters are optimised.
    corpus : sequence of sentences
        Each sentence is a sequence of token strings the
        deduction's axiom injector accepts.
    steps : int
        Adam steps.
    lr : float
        Adam learning rate.
    prior_scale : float, optional
        If supplied, adds a Gaussian regulariser
        :math:`\\tfrac{1}{2\\sigma^2}\\lVert \\mathbf{w} \\rVert^2`
        to the loss (MAP). Defaults to ``None`` (MLE).

    Returns
    -------
    list[float]
        The loss trajectory; length == ``steps``.
    """
    materialise_parameters(ded, corpus)
    params = list(ded.parameters())
    if not params:
        return []
    optim = torch.optim.Adam(params, lr=lr)
    history: list[float] = []
    for _ in range(steps):
        optim.zero_grad()
        log_z = torch.zeros(())
        for sentence in corpus:
            log_z = log_z + ded(list(sentence)).goal_weight()
        loss = -log_z
        if prior_scale is not None:
            inv_var = 1.0 / (prior_scale**2)
            for p in params:
                loss = loss + 0.5 * inv_var * (p**2).sum()
        loss.backward()
        optim.step()
        history.append(float(loss.detach()))
    return history
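
The loss assembled in the listing above can be checked by hand with plain floats. This is a sketch of the arithmetic only, with no optimiser and no torch; `map_loss` and `neg_log_normal` are hypothetical helper names. It also shows that the regulariser equals the negative log-density of a zero-mean Normal prior up to a constant that does not depend on the parameters.

```python
import math
from typing import List, Optional


def map_loss(log_z: float, params: List[float], prior_scale: Optional[float] = None) -> float:
    # MLE: the loss is the negative corpus log-marginal.
    loss = -log_z
    if prior_scale is not None:
        # MAP: add (1 / (2 * sigma^2)) * ||w||^2.
        inv_var = 1.0 / (prior_scale ** 2)
        loss += 0.5 * inv_var * sum(p * p for p in params)
    return loss


def neg_log_normal(params: List[float], scale: float) -> float:
    # Negative log-density of independent N(0, scale^2) priors.
    const = math.log(scale * math.sqrt(2.0 * math.pi))
    return sum(0.5 * (p / scale) ** 2 + const for p in params)


w = [1.0, -2.0]
mle = map_loss(-3.0, w)                      # 3.0
map_unit = map_loss(-3.0, w, 1.0)            # 3.0 + 0.5 * 5.0 = 5.5
map_wide = map_loss(-3.0, w, 2.0)            # 3.0 + 0.125 * 5.0 = 3.625

# The penalty and the exact negative log-prior differ by a constant in w.
gap_a = neg_log_normal(w, 2.0) - (map_wide - mle)
gap_b = neg_log_normal([0.3, 0.4], 2.0) - (map_loss(0.0, [0.3, 0.4], 2.0))
```

Because the gap is constant in the parameters, dropping it changes the reported loss values but not the gradients or the optimum.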

sample_corpus

sample_corpus(ded: DeductionSystem, *, length: int, n_samples: int, seed: int | None = None) -> list[list[str]]

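Sample n_samples yields of length length from the chart's length-conditional distribution under the deduction's current parameters. The sampler is exact because it enumerates every token sequence of the requested length over the vocabulary and scores each one with the deduction, so its cost grows as |V|^length for a vocabulary of size |V|.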

PARAMETER DESCRIPTION
ded

The deduction with materialised parameters.

TYPE: DeductionSystem

length

Length of yields to enumerate.

TYPE: int

n_samples

Number of sentences to draw.

TYPE: int

seed

Seed for the multinomial draws.

TYPE: int DEFAULT: None

Source code in src/quivers/stochastic/deduction/sample.py
def sample_corpus(
    ded: DeductionSystem,
    *,
    length: int,
    n_samples: int,
    seed: int | None = None,
) -> list[list[str]]:
    """Sample ``n_samples`` yields of length ``length`` from the
    chart's length-conditional distribution under the deduction's
    current parameters.

    Parameters
    ----------
    ded : DeductionSystem
        The deduction with materialised parameters.
    length : int
        Length of yields to enumerate.
    n_samples : int
        Number of sentences to draw.
    seed : int, optional
        Seed for the multinomial draws.
    """
    vocab = _vocabulary(ded)
    if not vocab:
        raise ValueError(
            "sample_corpus: cannot determine the deduction's "
            "vocabulary; set ``ded._vocabulary`` explicitly or "
            "call ``materialise_parameters`` first"
        )

    yields: list[list[str]] = []
    log_weights: list[torch.Tensor] = []
    for combo in itertools.product(vocab, repeat=length):
        chart = ded(list(combo))
        w = chart.goal_weight()
        if torch.isfinite(w):
            yields.append(list(combo))
            log_weights.append(w)
    if not yields:
        raise ValueError(
            f"sample_corpus: no yield of length {length} parses under "
            f"the deduction's current parameters"
        )
    logw = torch.stack([w.detach() for w in log_weights])
    probs = torch.softmax(logw, dim=0)
    gen = torch.Generator()
    if seed is not None:
        gen.manual_seed(seed)
    idxs = torch.multinomial(probs, n_samples, replacement=True, generator=gen)
    return [yields[int(i.item())] for i in idxs]
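
The enumerate, normalise, and draw steps in the listing above can be sketched with the standard library alone. This is not the library's code: `sample_by_enumeration` and `toy_log_weight` are hypothetical, the scoring function stands in for `ded(...).goal_weight()`, and `random.Random.choices` replaces `torch.softmax` plus `torch.multinomial`.

```python
import itertools
import math
import random


def toy_log_weight(tokens):
    # Hypothetical scorer: sequences with a repeated adjacent token do not
    # parse (log-weight -inf); otherwise each "b" costs one nat.
    if any(a == b for a, b in zip(tokens, tokens[1:])):
        return -math.inf
    return -float(tokens.count("b"))


def sample_by_enumeration(vocab, log_weight, *, length, n_samples, seed=None):
    yields, log_ws = [], []
    for combo in itertools.product(vocab, repeat=length):
        w = log_weight(list(combo))
        if math.isfinite(w):
            yields.append(list(combo))
            log_ws.append(w)
    if not yields:
        raise ValueError(f"no yield of length {length} has finite weight")
    # Subtract the maximum before exponentiating for numerical stability;
    # ``choices`` accepts unnormalised relative weights.
    top = max(log_ws)
    weights = [math.exp(w - top) for w in log_ws]
    rng = random.Random(seed)
    return rng.choices(yields, weights=weights, k=n_samples)


corpus = sample_by_enumeration(["a", "b"], toy_log_weight, length=3, n_samples=20, seed=0)
```

With this toy scorer only `a b a` and `b a b` have finite weight, so every draw is one of those two sequences, and reusing the same seed reproduces the same corpus.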

bayesian_lift_parameters

bayesian_lift_parameters(inner_model: Module, x: Tensor, observations: dict[str, Tensor], *, prior_scale: float = 1.0, site_prefix: str = 'theta', additional_latents: dict[str, tuple[int, ...]] | None = None, latent_placeholder_scale: float = 10.0) -> tuple[MonadicProgram, Tensor, dict[str, Tensor]]

Lift every learnable parameter of inner_model into a Normal-prior sample site, and optionally lift named latent sites of the inner program into NUTS-sampleable variables.

Mathematics

Let :math:`\theta` denote the inner model's learnable parameters and :math:`\mathbf{z}` an optional collection of intermediate latents named in additional_latents. The target joint posterior is

.. math::
    p(\theta, \mathbf{z} \mid x, y) \;\propto\; p(\theta) \, p_{\mathrm{inner}}(\mathbf{z}, y \mid x, \theta).

The lifted program declares

  • one Normal sample site :math:`\theta_i \sim \mathcal{N}(0, \sigma_\theta^2)` per parameter (prior_scale);
  • one Normal sample site per latent in additional_latents with a placeholder scale :math:`\sigma_z` (latent_placeholder_scale); and
  • one score step that, after substituting :math:`\theta` into the inner's parameter slots, computes

.. math::
    \log p_{\mathrm{inner}}(\mathbf{z}, y \mid x, \theta) \;-\; \sum_{z \in \text{latents}} \log \mathcal{N}(z; 0, \sigma_z^2).

The placeholder priors on :math:`\mathbf{z}` cancel exactly, so the lifted log-density equals :math:`\log p(\theta) + \log p_{\mathrm{inner}}(\mathbf{z}, y \mid x, \theta)` pointwise. NUTS therefore samples :math:`(\theta, \mathbf{z})` from the exact joint posterior, and the log-density is deterministic given the full state (no MC noise across leapfrog steps).
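
The cancellation can be verified numerically with scalars. The sketch below is a stand-alone check under stated assumptions: `inner_log_joint` is a hypothetical two-line model (z ~ N(θ, 1), y ~ N(z, 1)), and `lifted_log_density` simply adds the three contributions described above (parameter site, placeholder latent site, score step).

```python
import math


def normal_logpdf(value: float, scale: float) -> float:
    return -0.5 * (value / scale) ** 2 - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def inner_log_joint(z: float, y: float, theta: float) -> float:
    # Hypothetical inner model: z ~ N(theta, 1), y ~ N(z, 1).
    return normal_logpdf(z - theta, 1.0) + normal_logpdf(y - z, 1.0)


def lifted_log_density(theta, z, y, prior_scale=1.0, latent_placeholder_scale=10.0):
    site_theta = normal_logpdf(theta, prior_scale)          # parameter sample site
    site_z = normal_logpdf(z, latent_placeholder_scale)     # placeholder latent site
    # Score step: inner log-joint minus the placeholder log-density.
    score = inner_log_joint(z, y, theta) - normal_logpdf(z, latent_placeholder_scale)
    return site_theta + site_z + score


theta, z, y = 0.4, 0.7, 1.2
target = normal_logpdf(theta, 1.0) + inner_log_joint(z, y, theta)
values = [lifted_log_density(theta, z, y, latent_placeholder_scale=s) for s in (0.1, 10.0, 1000.0)]
```

Every entry of `values` agrees with `target` up to floating-point rounding, whatever placeholder scale is used.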

Methodological notes

Why a Normal prior on parameters?

  1. Maximum entropy. Among all distributions on :math:`\mathbb{R}^n` with a given mean and covariance, the Gaussian has maximum entropy. Once a zero mean and a per-coordinate variance :math:`\sigma_\theta^2` are fixed, :math:`\mathcal{N}(0, \sigma_\theta^2)` is therefore the least informative choice.
  2. Equivalence to weight decay. A :math:`\mathcal{N}(0, \sigma_\theta^2)` prior is the MAP equivalent of L2 regularization with coefficient :math:`1/(2\sigma_\theta^2)`. Standard frequentist weight decay and gradient-descent training inherit a direct Bayesian reading under this prior.
  3. Computational fit with NUTS. The unconstrained support :math:`\mathbb{R}^n` matches NUTS's native state space, so no bijector is needed between latent and prior support. The log-density is smooth everywhere, so leapfrog dynamics are well-behaved.

Assumptions the user must respect.

  1. Parameters must be unconstrained reals. If a learnable represents a variance, rate, probability, or simplex component, a Normal prior is mathematically invalid (it places mass on the forbidden region). Models must use the unconstrained parameterization (log-scale, logit-p, log-rate, soft-max logits, etc.). In QVR's standard model definitions, distribution families are parameterized in this way and torch.nn.Parameter values are unconstrained reals.
  2. The default prior_scale=1.0 assumes O(1) parameter magnitude. This is consistent with typical neural-network initialization schemes (Xavier, Kaiming). Override prior_scale for models with very different expected parameter magnitudes.
  3. A Normal prior is generic, not informed. The lift is a one-size-fits-all wrapper. Users with substantive domain knowledge about :math:`\theta` should write a program block with explicit sample priors instead of relying on the lift.

Why a placeholder Normal on lifted latents?

The placeholder prior :math:`\mathbf{z}_i \sim \mathcal{N}(0, \sigma_z^2)` on each lifted latent is algebraically irrelevant by construction: the lifted sample-site prior and the placeholder cancellation in the score step sum to zero pointwise. The target distribution NUTS samples is the true joint posterior regardless of :math:`\sigma_z`. A placeholder exists at all because NUTS's LatentRegistry enumerates dimensions from declared sample sites; each site needs a base measure to define the unconstrained support and to seed mass-matrix and step-size adaptation during warmup. Normal is the standard choice for an unconstrained latent.

:math:`\sigma_z` affects mixing efficiency, not correctness. A placeholder scale mismatched to the true posterior scale of :math:`\mathbf{z}` lengthens warmup (the mass matrix has to adapt further) without biasing the chain. The default latent_placeholder_scale=10.0 is large enough that initial NUTS proposals span a meaningful neighbourhood of zero, small enough that they do not immediately diverge.

PARAMETER DESCRIPTION
inner_model

Module exposing log_joint(x, observations) -> Tensor.

TYPE: Module

x

Input tensor, passed straight through to the inner's log_joint.

TYPE: Tensor

observations

Observation dict, passed straight through to the inner's log_joint. When additional_latents is supplied, observations must not contain those keys; they are supplied per NUTS evaluation from the env.

TYPE: dict[str, Tensor]

prior_scale

:math:`\sigma_\theta`. Standard deviation of the Normal prior on each parameter site.

TYPE: float DEFAULT: 1.0

site_prefix

Stem of each parameter sample-site's name.

TYPE: str DEFAULT: 'theta'

additional_latents

Mapping from intermediate-latent site name (a key the inner's log_joint expects in its observations dict) to the latent's tensor shape (without the batch dim). When None (the default), the lift is parameter-only and inner.log_joint must accept observations as passed in.

TYPE: dict[str, tuple[int, ...]] | None DEFAULT: None

latent_placeholder_scale

:math:`\sigma_z`. Standard deviation of the placeholder Normal prior on each lifted latent. Any positive value works (it cancels exactly in the score step); a moderate value keeps NUTS's initial proposal magnitudes in a reasonable range.

TYPE: float DEFAULT: 10.0

RETURNS DESCRIPTION
(model, x_, observations_)

The lifted MonadicProgram and the placeholder input / empty observation dict the inference layer feeds it.

Source code in src/quivers/inference/lifts.py
def bayesian_lift_parameters(
    inner_model: nn.Module,
    x: torch.Tensor,
    observations: dict[str, torch.Tensor],
    *,
    prior_scale: float = 1.0,
    site_prefix: str = "theta",
    additional_latents: dict[str, tuple[int, ...]] | None = None,
    latent_placeholder_scale: float = 10.0,
) -> tuple[MonadicProgram, torch.Tensor, dict[str, torch.Tensor]]:
    """Lift every learnable parameter of ``inner_model`` into a
    Normal-prior sample site, and optionally lift named latent
    sites of the inner program into NUTS-sampleable variables.

    Mathematics
    -----------
    Let :math:`\\theta` denote the inner model's learnable
    parameters and :math:`\\mathbf{z}` an optional collection of
    intermediate latents named in ``additional_latents``. The
    target joint posterior is

    .. math::
        p(\\theta, \\mathbf{z} \\mid x, y) \\;\\propto\\;
        p(\\theta) \\, p_{\\mathrm{inner}}(\\mathbf{z}, y \\mid x, \\theta).

    The lifted program declares

    * one Normal sample site
      :math:`\\theta_i \\sim \\mathcal{N}(0, \\sigma_\\theta^2)` per
      parameter (``prior_scale``);
    * one Normal sample site per latent in
      ``additional_latents`` with a *placeholder* scale
      :math:`\\sigma_z` (``latent_placeholder_scale``); and
    * one score step that, after substituting :math:`\\theta`
      into the inner's parameter slots, computes

      .. math::
          \\log p_{\\mathrm{inner}}(\\mathbf{z}, y \\mid x, \\theta)
          \\;-\\; \\sum_{z \\in \\text{latents}} \\log \\mathcal{N}(z; 0, \\sigma_z^2).

    The placeholder priors on :math:`\\mathbf{z}` cancel exactly,
    so the lifted log-density equals
    :math:`\\log p(\\theta) + \\log p_{\\mathrm{inner}}(\\mathbf{z}, y \\mid x, \\theta)`
    pointwise. NUTS therefore samples
    :math:`(\\theta, \\mathbf{z})` from the exact joint posterior,
    and the log-density is deterministic given the full state
    (no MC noise across leapfrog steps).

    Methodological notes
    --------------------
    *Why a Normal prior on parameters?*

    1. **Maximum entropy.** Among all distributions on
       :math:`\\mathbb{R}^n` with mean zero and covariance
       :math:`\\sigma_\\theta^2 I`,
       :math:`\\mathcal{N}(0, \\sigma_\\theta^2 I)` has maximum
       entropy. Given only that second-moment constraint, it is
       the least informative prior.
    2. **Equivalence to weight decay.** MAP estimation under a
       :math:`\\mathcal{N}(0, \\sigma_\\theta^2)` prior is
       equivalent to L2 regularization with penalty
       :math:`\\lVert\\theta\\rVert^2 / (2\\sigma_\\theta^2)`.
       Standard frequentist weight decay and gradient-descent
       training inherit a direct Bayesian reading under this
       prior.
    3. **Computational fit with NUTS.** The unconstrained support
       :math:`\\mathbb{R}^n` matches NUTS's native state space, so
       no bijector is needed between latent and prior support.
       The log-density is smooth everywhere, so leapfrog dynamics
       are well-behaved.

    *Assumptions the user must respect.*

    1. **Parameters must be unconstrained reals.** If a learnable
       represents a variance, rate, probability, or simplex
       component, a Normal prior is mathematically invalid (it
       places mass on the forbidden region). Models must use the
       unconstrained parameterization (log-scale, logit-p,
       log-rate, soft-max logits, etc.). In QVR's standard model
       definitions, distribution families are parameterized in
       this way and `torch.nn.Parameter`\\ s are
       unconstrained reals.
    2. **The default ``prior_scale=1.0`` assumes O(1) parameter
       magnitude.** This is consistent with typical neural-network
       initialization schemes (Xavier, Kaiming). Override
       ``prior_scale`` for models with very different expected
       parameter magnitudes.
    3. **A Normal prior is generic, not informed.** The lift is a
       one-size-fits-all wrapper. Users with substantive domain
       knowledge about :math:`\\theta` should write a ``program``
       block with explicit ``sample`` priors instead of relying on
       the lift.

    *Why a placeholder Normal on lifted latents?*

    The placeholder prior on each
    :math:`\\mathbf{z}_i \\sim \\mathcal{N}(0, \\sigma_z^2)` is
    *algebraically irrelevant* by construction: the lifted
    sample-site prior and the placeholder cancellation in the
    score step sum to zero pointwise. The target distribution NUTS
    samples is the true joint posterior regardless of
    :math:`\\sigma_z`. A placeholder exists at all because NUTS's
    `LatentRegistry` enumerates dimensions from declared
    sample sites; each site needs a base measure to define the
    unconstrained support and to seed mass-matrix and step-size
    adaptation during warmup. Normal is the standard choice for an
    unconstrained latent.

    :math:`\\sigma_z` affects *mixing efficiency*, not
    *correctness*. A placeholder scale mismatched to the true
    posterior scale of :math:`\\mathbf{z}` lengthens warmup (the
    mass matrix has to adapt further) without biasing the chain.
    The default ``latent_placeholder_scale=10.0`` is large enough
    that initial NUTS proposals span a meaningful neighbourhood of
    zero, small enough that they do not immediately diverge.

    Parameters
    ----------
    inner_model : nn.Module
        Module exposing ``log_joint(x, observations) -> Tensor``.
    x, observations : tensor, dict
        Passed straight through to the inner model's ``log_joint``.
        When ``additional_latents`` is supplied, ``observations``
        must not contain any of its keys (a ``ValueError`` is
        raised otherwise); the score step reads those latents from
        ``env`` on each NUTS log-density evaluation.
    prior_scale : float
        :math:`\\sigma_\\theta`. Standard deviation of the Normal
        prior on each parameter site.
    site_prefix : str
        Stem of each parameter sample-site's name.
    additional_latents : dict[str, tuple[int, ...]] | None
        Mapping from intermediate-latent site name (a key the
        inner's ``log_joint`` expects in its observations dict)
        to the latent's tensor shape (without the batch dim).
        When ``None`` (the default), the lift is parameter-only
        and ``inner.log_joint`` must accept ``observations`` as
        passed in.
    latent_placeholder_scale : float
        :math:`\\sigma_z`. Standard deviation of the placeholder
        Normal prior on each lifted latent. Any positive value
        works (it cancels exactly in the score step); a moderate
        value keeps NUTS's initial proposal magnitudes in a
        reasonable range.

    Returns
    -------
    (model, x_, observations_)
        The lifted ``MonadicProgram`` and the placeholder
        input / empty observation dict the inference layer feeds
        it.
    """
    if not hasattr(inner_model, "log_joint"):
        raise ValueError(
            "bayesian_lift_parameters: inner_model must expose "
            "``log_joint(x, observations)``"
        )
    if additional_latents and any(k in observations for k in additional_latents):
        bad = sorted(k for k in additional_latents if k in observations)
        raise ValueError(
            "bayesian_lift_parameters: keys "
            f"{bad!r} appear in both ``observations`` and "
            "``additional_latents``; a latent cannot be both "
            "observed and sampled"
        )

    inner_log_joint = inner_model.log_joint

    params = list(inner_model.named_parameters())
    if not params and not additional_latents:
        raise ValueError(
            "bayesian_lift_parameters: inner_model has no parameters "
            "and no additional_latents were declared"
        )

    safe_paths = [n.replace(".", "_") for n, _ in params]
    param_sites = [f"{site_prefix}__{p}" for p in safe_paths]
    locator, _paths, _ = _build_locator(inner_model)
    flat_dims = [int(p.numel()) for _, p in params]
    param_shapes = [tuple(p.shape) for _, p in params]
    param_paths = [n for n, _ in params]

    latent_names: list[str] = []
    latent_shapes: list[tuple[int, ...]] = []
    latent_flat_dims: list[int] = []
    latent_sites: list[str] = []
    if additional_latents:
        for name, shape in additional_latents.items():
            latent_names.append(name)
            latent_shapes.append(tuple(int(s) for s in shape))
            flat = 1
            for s in shape:
                flat *= int(s)
            latent_flat_dims.append(flat)
            latent_sites.append(f"latent__{name}")

    steps: list[tuple] = []
    for site, dim in zip(param_sites, flat_dims):
        steps.append(
            (
                (site,),
                _make_normal_prior_morphism(prior_scale, dim=dim),
                None,
            )
        )
    for site, dim in zip(latent_sites, latent_flat_dims):
        steps.append(
            (
                (site,),
                _make_normal_prior_morphism(latent_placeholder_scale, dim=dim),
                None,
            )
        )

    placeholder_scale_f = float(latent_placeholder_scale)

    def _score_fn(
        env: dict[str, torch.Tensor],
        _param_paths: list[str] = list(param_paths),
        _param_shapes: list[tuple[int, ...]] = list(param_shapes),
        _param_sites: list[str] = list(param_sites),
        _latent_names: list[str] = list(latent_names),
        _latent_shapes: list[tuple[int, ...]] = list(latent_shapes),
        _latent_sites: list[str] = list(latent_sites),
        _locator: Callable[[str], tuple[nn.Module, str]] = locator,
        _inner_x: torch.Tensor = x,
        _inner_obs: dict[str, torch.Tensor] = dict(observations),
        _placeholder_scale: float = placeholder_scale_f,
    ) -> torch.Tensor:
        ref = env[_param_sites[0]] if _param_sites else env[_latent_sites[0]]
        batch = ref.shape[0]
        out = torch.zeros(batch, dtype=torch.get_default_dtype())
        for b in range(batch):
            overrides: dict[str, torch.Tensor] = {}
            for path, shape, name in zip(_param_paths, _param_shapes, _param_sites):
                v = env[name][b]
                overrides[path] = v.reshape(shape) if shape else v.reshape(())
            latent_dict: dict[str, torch.Tensor] = {}
            placeholder_log_prior = torch.zeros((), dtype=torch.get_default_dtype())
            for name, shape, site in zip(_latent_names, _latent_shapes, _latent_sites):
                flat_b = env[site][b]
                latent_dict[name] = (
                    flat_b.reshape(shape) if shape else flat_b.reshape(())
                )
                placeholder_log_prior = (
                    placeholder_log_prior
                    + D.Normal(
                        torch.zeros_like(flat_b),
                        torch.full_like(flat_b, _placeholder_scale),
                    )
                    .log_prob(flat_b)
                    .sum()
                )
            with _swap_named_parameters(_locator, overrides):
                merged_obs = {**_inner_obs, **latent_dict}
                ll = inner_log_joint(_inner_x, merged_obs)
                ll_b = ll.sum() if ll.dim() > 0 else ll
                # Cancel the placeholder priors on lifted latents
                # so the net log-density is exactly
                # ``log p(theta) + log p_inner(z, y | x, theta)``.
                out[b] = ll_b - placeholder_log_prior
        return out

    steps.append((("log_lik",), None, _score_fn, True))
    lifted = MonadicProgram(
        domain=Unit,
        codomain=Unit,
        steps=steps,
        return_vars=("log_lik",),
    )
    return lifted, torch.zeros(1, 1), {}
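
The weight-decay note in the docstring can be verified numerically. The negative log-density of a zero-mean Normal with standard deviation prior_scale differs from the L2 penalty with coefficient 1 / (2 * prior_scale^2) only by a constant that does not depend on the parameters, so both have the same minimizer when added to a loss. A minimal stdlib-only sketch, with hypothetical helper names and arbitrary example values:

```python
import math


def neg_log_normal_prior(theta, prior_scale):
    """-log N(theta; 0, prior_scale^2 I) for a flat list of parameters."""
    quad = sum(t * t for t in theta) / (2.0 * prior_scale ** 2)
    log_norm = len(theta) * (math.log(prior_scale) + 0.5 * math.log(2.0 * math.pi))
    return quad + log_norm


def l2_penalty(theta, prior_scale):
    """Weight-decay penalty with coefficient 1 / (2 * prior_scale^2)."""
    return sum(t * t for t in theta) / (2.0 * prior_scale ** 2)


prior_scale = 2.0
theta_a = [0.5, -1.0, 3.0]
theta_b = [0.0, 0.25, -2.0]

# The gap between the two objectives is the same at both points.
gap_a = neg_log_normal_prior(theta_a, prior_scale) - l2_penalty(theta_a, prior_scale)
gap_b = neg_log_normal_prior(theta_b, prior_scale) - l2_penalty(theta_b, prior_scale)
print(gap_a, gap_b)
```

The constant gap is the Normal's log-normalizer, n * (log(prior_scale) + 0.5 * log(2 * pi)). It depends on prior_scale but not on theta, which is why it drops out of MAP estimation at a fixed prior scale.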