Bayesian Wrap

nuts_program_from_deduction lifts the deduction's learnable log-weights into a MonadicProgram whose joint log-density is \(-\tfrac{1}{2\sigma^2}\lVert \mathbf{w} \rVert^2 + \sum_n \log Z(s_n; \mathbf{w})\), ready for MCMC.

The sampler targets exactly that joint with a deterministic log-density and exact gradients. Whether the joint is the Bayesian posterior \(p(\mathbf{w} \mid S)\) depends on the modelling reading (CRF / globally normalised vs. PCFG / locally normalised); see the module docstring for the precise statement and the cancellation condition.

bayes

Bayesian posterior wrapping for weighted deduction systems.

nuts_program_from_deduction lifts the deduction's learnable log-weights into a quivers.continuous.programs.MonadicProgram whose log_joint is \(-\tfrac{1}{2\sigma^2}\lVert \mathbf{w} \rVert^2 + \sum_n \log Z(s_n; \mathbf{w})\), where \(\sigma\) is the prior_scale argument and \(n\) ranges over the corpus sentences. The resulting program is ready for quivers.inference.MCMC with quivers.inference.NUTSKernel.

Modelling note

The sampler targets exactly \(\pi(\mathbf{w}) \propto \exp\bigl(-\lVert \mathbf{w} \rVert^2/(2\sigma^2) + \sum_n \log Z(s_n; \mathbf{w})\bigr)\) with a deterministic log-density and exact gradients. Whether that joint is the Bayesian posterior \(p(\mathbf{w} \mid S)\) depends on the modelling reading:

  • Undirected / globally-normalised (CRF / log-linear / energy-based): \(\pi(\mathbf{w})\) is the posterior; the implementation is exact.
  • Directed / locally-normalised PCFG: the true sentence likelihood is \(Z(s; \mathbf{w}) / \sum_{s'} Z(s'; \mathbf{w})\); the global normaliser depends on \(\mathbf{w}\) and is intractable. The sampler then targets a pseudo-posterior: the true posterior equals \(\pi(\mathbf{w})\) multiplied by \(\bigl(\sum_{s'} Z(s'; \mathbf{w})\bigr)^{-N}\), where \(N\) is the number of corpus sentences. Users committed to this reading should constrain rule weights to local simplices via a Dirichlet + softmax surface rather than the free-parameter Normal lift this function provides.
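
The gap between the two readings can be checked numerically on a toy problem. The sketch below is not part of quivers: it assumes a finite universe of three sentences with \(Z(s; \mathbf{w}) = \exp(\mathbf{w} \cdot f(s))\), and the feature table and helper names are hypothetical. In a real PCFG the universe of sentences is infinite, so the global normaliser cannot be summed like this.

```python
import math

# Hypothetical toy universe: each sentence s has a feature vector f(s),
# and Z(s; w) = exp(w . f(s)) stands in for a real inside score.
FEATURES = {
    "a b": (1.0, 0.0),
    "b a": (0.0, 1.0),
    "a a": (2.0, 0.0),
}


def log_z(sentence, w):
    """log Z(s; w) for the toy model."""
    return sum(wi * fi for wi, fi in zip(w, FEATURES[sentence]))


def log_pi(w, corpus, sigma=1.0):
    """Unnormalised log-density the sampler targets."""
    prior = -sum(wi * wi for wi in w) / (2.0 * sigma ** 2)
    return prior + sum(log_z(s, w) for s in corpus)


def log_global_normaliser(w):
    """log sum_{s'} Z(s'; w); tractable only because the toy universe is finite."""
    return math.log(sum(math.exp(log_z(s, w)) for s in FEATURES))


def log_directed_posterior(w, corpus, sigma=1.0):
    """Directed reading: each likelihood term is Z(s; w) / sum_{s'} Z(s'; w)."""
    return log_pi(w, corpus, sigma) - len(corpus) * log_global_normaliser(w)


corpus = ["a b", "a a"]
w = (0.3, -0.2)
w_zero = (0.0, 0.0)

# The two log-densities differ by N * log sum_{s'} Z(s'; w), which moves with w.
gap = log_pi(w, corpus) - log_directed_posterior(w, corpus)
gap_at_zero = log_pi(w_zero, corpus) - log_directed_posterior(w_zero, corpus)
```

Because the gap changes with \(\mathbf{w}\), it does not cancel in the sampler's acceptance ratio, which is why the two readings lead to different targets.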

nuts_program_from_deduction

nuts_program_from_deduction(ded: DeductionSystem, corpus: Sequence[Sequence[str]], *, prior_scale: float = 1.0, site_prefix: str = 'log_w') -> tuple[MonadicProgram, Tensor, dict[str, Tensor]]

Lift a deduction system's learnable parameters to a MonadicProgram suitable for NUTS / SVI.

The returned program has one torch.distributions.Normal sample site per learnable parameter (lexicon entries and rule bindings alike) plus one score step that substitutes the sampled values into the deduction's parameter slots and adds \(\sum_n \log Z(s_n; \mathbf{w})\) to the joint.
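
To see how per-site Normal priors plus a single score term relate to the quadratic form quoted in the module docstring, here is a minimal sketch in plain Python. It assumes zero-mean priors with standard deviation `sigma`; `toy_score` is a hypothetical stand-in for \(\sum_n \log Z(s_n; \mathbf{w})\) and none of these helpers exist in quivers.

```python
import math


def normal_log_pdf(x, sigma):
    """Log-density of a zero-mean Normal with standard deviation sigma."""
    return -0.5 * (x / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def log_joint(w, score, sigma=1.0):
    """One Normal log-density per parameter, plus one score term."""
    prior = sum(normal_log_pdf(wi, sigma) for wi in w)
    return prior + score(w)


def toy_score(w):
    """Hypothetical stand-in for sum_n log Z(s_n; w)."""
    return 2.0 * w[0] + w[1]


w = [0.5, -1.0, 0.25]
sigma = 2.0

# The summed priors equal the quadratic term plus a constant that does not
# depend on w, so the constant has no effect on MCMC.
quadratic = -sum(wi * wi for wi in w) / (2.0 * sigma ** 2)
constant = -len(w) * (math.log(sigma) + 0.5 * math.log(2.0 * math.pi))
total = log_joint(w, toy_score, sigma)
```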

Parameters:

  • ded (DeductionSystem): Deduction whose parameters are lifted.
  • corpus (sequence of sentences): Corpus the score step closes over.
  • prior_scale (float, default 1.0): Standard deviation of the Normal prior on each parameter.
  • site_prefix (str, default 'log_w'): Stem of each sample-site name (the parameter's path is appended for round-trip mapping).

Returns:

  • (model, x, observations): The lifted program plus a (1, 1) placeholder input and an empty observation dict, ready to feed to quivers.inference.MCMC.run.
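
The site names built from site_prefix follow the sanitising rule visible in the source listing: "/" becomes "__" and "." becomes "_". The sketch below reproduces that rule on two hypothetical parameter paths (real paths depend on the deduction). Since the sanitising step is not guaranteed to be invertible, the sketch keeps an explicit site-to-path lookup, much as the source keeps the path and site-name lists in parallel.

```python
def site_name(path, site_prefix="log_w"):
    """Sanitise a parameter path into a sample-site name."""
    safe = path.replace("/", "__").replace(".", "_")
    return f"{site_prefix}__{safe}"


# Hypothetical parameter paths, for illustration only.
paths = ["lexicon/the.weight", "rules/binary.weight"]
site_names = [site_name(p) for p in paths]

# Explicit lookup for mapping posterior draws back to parameter paths.
site_to_path = dict(zip(site_names, paths))
```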

Source code in src/quivers/stochastic/deduction/bayes.py
def nuts_program_from_deduction(
    ded: DeductionSystem,
    corpus: Sequence[Sequence[str]],
    *,
    prior_scale: float = 1.0,
    site_prefix: str = "log_w",
) -> tuple[MonadicProgram, torch.Tensor, dict[str, torch.Tensor]]:
    """Lift a deduction system's learnable parameters to a
    `MonadicProgram` suitable for NUTS / SVI.

    The returned program has one
    `torch.distributions.Normal` sample site per learnable
    parameter (lexicon entries and rule bindings alike) plus one
    score step that substitutes the sampled values into the
    deduction's parameter slots and adds
    :math:`\\sum_n \\log Z(s_n; \\mathbf{w})` to the joint.

    Parameters
    ----------
    ded : DeductionSystem
        Deduction whose parameters are lifted.
    corpus : sequence of sentences
        Corpus the score step closes over.
    prior_scale : float
        Standard deviation of the Normal prior on each parameter.
    site_prefix : str
        Stem of each sample-site's name (the parameter's path is
        appended for round-trip mapping).

    Returns
    -------
    (model, x, observations)
        The lifted program plus a ``(1, 1)`` placeholder input and
        an empty observation dict, ready to feed to
        [`quivers.inference.MCMC`][quivers.inference.MCMC] ``.run``.
    """
    materialise_parameters(ded, corpus)
    locator, paths, _ = build_locator(ded)
    if not paths:
        raise ValueError(
            "nuts_program_from_deduction: the deduction has no "
            "learnable parameters (neither lexicon nor rules are "
            "marked #[learnable])"
        )
    site_names: list[str] = []
    for path in paths:
        safe = path.replace("/", "__").replace(".", "_")
        site_names.append(f"{site_prefix}__{safe}")

    prior_morph = _make_normal_prior_morphism(prior_scale)
    steps: list[tuple] = [((site,), prior_morph, None) for site in site_names]

    def _score_fn(
        env: dict[str, torch.Tensor],
        _ded: DeductionSystem = ded,
        _corpus: list[list[str]] = [list(s) for s in corpus],
        _site_names: list[str] = list(site_names),
        _paths: list[str] = list(paths),
        _locator: Callable[[str], tuple[torch.nn.Module, str]] = locator,
    ) -> torch.Tensor:
        site_values = [env[name] for name in _site_names]
        batch = site_values[0].shape[0]
        out = torch.zeros(batch, dtype=torch.get_default_dtype())
        for b in range(batch):
            overrides = {path: v[b].reshape(()) for path, v in zip(_paths, site_values)}
            with _swap_named_parameters(_locator, overrides):
                log_z = torch.zeros(())
                for sentence in _corpus:
                    log_z = log_z + _ded(sentence).goal_weight()
                out[b] = log_z
        return out

    steps.append((("log_Z",), None, _score_fn, True))
    model = MonadicProgram(
        domain=Unit,
        codomain=Unit,
        steps=steps,
        return_vars=("log_Z",),
    )
    return model, torch.zeros(1, 1), {}
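
The score step's control flow (for each batch row, substitute that row's sampled values, sum log Z over the corpus, then restore the originals) can be sketched without torch. Everything below is hypothetical: `swap_values` only illustrates a swap-and-restore pattern and makes no claim about how `_swap_named_parameters` is implemented, and a plain dict stands in for the deduction's parameter slots.

```python
from contextlib import contextmanager


@contextmanager
def swap_values(store, overrides):
    """Temporarily replace entries of `store`, restoring them on exit."""
    saved = {key: store[key] for key in overrides}
    store.update(overrides)
    try:
        yield store
    finally:
        store.update(saved)  # restore even if scoring raises


def score_batch(store, paths, site_values, corpus, log_z):
    """Return one summed log Z per batch row."""
    batch = len(site_values[0])
    out = []
    for b in range(batch):
        # Pick row b of every site and key it by parameter path.
        overrides = {path: values[b] for path, values in zip(paths, site_values)}
        with swap_values(store, overrides):
            out.append(sum(log_z(store, sentence) for sentence in corpus))
    return out


def toy_log_z(store, sentence):
    """Hypothetical log Z: the sum of per-token log-weights."""
    return sum(store[f"w_{token}"] for token in sentence)


store = {"w_a": 0.0, "w_b": 0.0}
paths = ["w_a", "w_b"]
site_values = [[1.0, 2.0], [0.5, -0.5]]  # one list per site, one entry per batch row
corpus = [["a", "b"], ["a"]]

scores = score_batch(store, paths, site_values, corpus, toy_log_z)
```

After the call, `store` holds its original values again, so repeated evaluations at different draws do not leak state into one another.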