Inside Algorithm

CKY inside algorithm for differentiable PCFG parsing. Computes sentence log-probabilities by dynamic programming over all parse trees.

inside

CKY inside algorithm for probabilistic context-free grammars.

Implements the inside algorithm as a differentiable PyTorch module, enabling end-to-end gradient-based learning of PCFG parameters.

A PCFG is specified by two stochastic morphisms:

  • binary : N -> N * N — binary production probabilities. For each nonterminal A, binary[A, B, C] is the probability of the rule A -> B C.
  • lexical : N -> T — terminal production probabilities. For each nonterminal A, lexical[A, t] is the probability of the rule A -> t.

The inside algorithm computes, for each nonterminal A and span (i, j) of the input sentence, the probability that A derives the words in that span:

beta(A, i, j) = P(w_i ... w_{j-1} | A)

Width-1 spans are scored by the lexical rules, beta(A, i, i+1) = lexical[A, w_i]. Wider spans marginalize over split points and binary rules:

beta(A, i, j) = sum over k, B, C of binary[A, B, C] * beta(B, i, k) * beta(C, k, j)

The sentence log-probability is log beta(start, 0, L), where L is the sentence length and start is the start symbol index (default 0).

All computation is done in log-space for numerical stability, using logsumexp for marginalization. This preserves gradient flow for learning rule probabilities end-to-end.
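
The following is a minimal sketch of that log-space recursion for a single sentence, assuming the rule probabilities are given as plain tensors log_binary of shape (|N|, |N|, |N|) and log_lexical of shape (|N|, |T|); the library's actual batched chart-filling routine is not shown on this page.

import torch

def inside_log_prob(log_binary, log_lexical, tokens, start=0):
    # log_binary  : (N, N, N) tensor with log P(A -> B C)
    # log_lexical : (N, T)    tensor with log P(A -> t)
    # tokens      : (L,)      integer tensor of terminal indices
    n_nonterm, length = log_binary.shape[0], tokens.shape[0]
    # chart[A, i, j] = log beta(A, i, j) = log P(w_i .. w_{j-1} | A)
    chart = torch.full((n_nonterm, length, length + 1), float("-inf"))

    # Base case: width-1 spans are scored by the lexical rules.
    for i in range(length):
        chart[:, i, i + 1] = log_lexical[:, tokens[i]]

    # Recursive case: combine two adjacent sub-spans with a binary rule,
    # marginalizing (via logsumexp) over the split point k and children B, C.
    for width in range(2, length + 1):
        for i in range(length - width + 1):
            j = i + width
            scores = []
            for k in range(i + 1, j):
                s = (log_binary                           # log P(A -> B C)
                     + chart[:, i, k].view(1, -1, 1)      # log beta(B, i, k)
                     + chart[:, k, j].view(1, 1, -1))     # log beta(C, k, j)
                scores.append(s.reshape(n_nonterm, -1))
            chart[:, i, j] = torch.logsumexp(torch.cat(scores, dim=1), dim=1)

    return chart[start, 0, length]

InsideAlgorithm batches this computation and reads the rule tables from the two stochastic morphisms, but the recursion is the same.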

Categorical perspective

The inside algorithm implements a morphism

inside(binary, lexical) : FreeMonoid(T) -> 1

that maps variable-length terminal strings to their probability under the grammar. This is the counit of the adjunction between the free monad on the polynomial functor induced by the grammar and the forgetful functor to strings.

Examples:

>>> import torch
>>> from quivers.core.objects import FinSet, ProductSet
>>> from quivers.stochastic.morphisms import StochasticMorphism
>>> from quivers.stochastic.inside import InsideAlgorithm
>>> N = FinSet(name="N", cardinality=5)
>>> T = FinSet(name="T", cardinality=10)
>>> binary = StochasticMorphism(N, ProductSet(N, N))
>>> lexical = StochasticMorphism(N, T)
>>> cky = InsideAlgorithm(binary, lexical, start=0)
>>> tokens = torch.randint(0, 10, (4, 6))  # batch=4, length=6
>>> log_probs = cky(tokens)  # (4,)

InsideAlgorithm

InsideAlgorithm(binary: Morphism, lexical: Morphism, start: int = 0)

Bases: Module

CKY inside algorithm for differentiable PCFG parsing.

Computes sentence log-probabilities under a PCFG defined by binary and lexical production rules, both expressed as stochastic morphisms.

PARAMETER DESCRIPTION
binary

Binary production rules. Must be a morphism N -> N * N where N is a finite set of nonterminals. The tensor has shape (|N|, |N|, |N|) with binary[A, B, C] = P(A -> B C).

TYPE: Morphism

lexical

Lexical (terminal) production rules. Must be a morphism N -> T where T is a finite set of terminals. The tensor has shape (|N|, |T|) with lexical[A, t] = P(A -> t).

TYPE: Morphism

start

Index of the start symbol in N (default 0).

TYPE: int DEFAULT: 0

RAISES DESCRIPTION
TypeError

If the binary morphism's codomain is not a ProductSet, or if the binary and lexical morphisms do not share the same domain of nonterminals.

Source code in src/quivers/stochastic/inside.py
def __init__(self, binary: Morphism, lexical: Morphism, start: int = 0) -> None:
    super().__init__()
    if not isinstance(binary.codomain, ProductSet):
        raise TypeError(
            f"binary morphism codomain must be a ProductSet, got {binary.codomain!r}"
        )
    if binary.domain != lexical.domain:
        raise TypeError(
            f"binary and lexical must share the same domain (nonterminals), got {binary.domain!r} and {lexical.domain!r}"
        )
    self._binary = binary
    self._lexical = lexical
    self._start = start
    self._n_nonterm = binary.domain.size
    self._n_term = lexical.codomain.size
    self._binary_mod = binary.module()
    self._lexical_mod = lexical.module()
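
The two checks above can be exercised directly. A doctest-style sketch, reusing N, T, binary and the imports from the example above (the exact object reprs in the message depend on FinSet):

>>> M = FinSet(name="M", cardinality=3)
>>> mismatched = StochasticMorphism(M, T)
>>> InsideAlgorithm(binary, mismatched)
Traceback (most recent call last):
    ...
TypeError: binary and lexical must share the same domain (nonterminals), ...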

n_nonterminals property

n_nonterminals: int

Number of nonterminal symbols.

n_terminals property

n_terminals: int

Number of terminal symbols.

start property

start: int

Index of the start symbol.

forward

forward(tokens: Tensor) -> Tensor

Compute sentence log-probabilities via the inside algorithm.

PARAMETER DESCRIPTION
tokens

Integer tensor of terminal indices. Shape (batch, seq_len) or (seq_len,) for a single sentence.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Log-probability of each sentence under the grammar. Shape (batch,) or scalar for a single sentence.

Source code in src/quivers/stochastic/inside.py
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    """Compute sentence log-probabilities via the inside algorithm.

    Parameters
    ----------
    tokens : torch.Tensor
        Integer tensor of terminal indices. Shape
        ``(batch, seq_len)`` or ``(seq_len,)`` for a single
        sentence.

    Returns
    -------
    torch.Tensor
        Log-probability of each sentence under the grammar.
        Shape ``(batch,)`` or scalar for a single sentence.
    """
    squeeze = False
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)
        squeeze = True
    if tokens.shape[1] == 0:
        raise ValueError("cannot parse empty sentences")
    chart = self._fill_chart(tokens)
    result = chart[:, self._start, 0, tokens.shape[1]]
    if squeeze:
        return result.squeeze(0)
    return result
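
Since forward returns differentiable log-probabilities, the negative log-likelihood can be minimized directly. A minimal training sketch, continuing the example above and assuming the modules returned by binary.module() and lexical.module() carry the learnable rule parameters (so that cky.parameters() is non-empty):

import torch

optimizer = torch.optim.Adam(cky.parameters(), lr=1e-2)

for _ in range(100):
    batch = torch.randint(0, 10, (4, 6))   # toy corpus; replace with real token ids
    optimizer.zero_grad()
    loss = -cky(batch).mean()              # negative mean sentence log-likelihood
    loss.backward()
    optimizer.step()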

inside_chart

inside_chart(tokens: Tensor) -> Tensor

Compute the full inside chart (for analysis/debugging).

PARAMETER DESCRIPTION
tokens

Integer tensor of terminal indices. Shape (batch, seq_len) or (seq_len,).

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

The full inside chart in log-space. Shape (batch, N, seq_len, seq_len+1) where entry [b, A, i, j] is log P(w_i..w_{j-1} | A).

Source code in src/quivers/stochastic/inside.py
def inside_chart(self, tokens: torch.Tensor) -> torch.Tensor:
    """Compute the full inside chart (for analysis/debugging).

    Parameters
    ----------
    tokens : torch.Tensor
        Integer tensor of terminal indices. Shape
        ``(batch, seq_len)`` or ``(seq_len,)``.

    Returns
    -------
    torch.Tensor
        The full inside chart in log-space. Shape
        ``(batch, N, seq_len, seq_len+1)`` where entry
        ``[b, A, i, j]`` is ``log P(w_i..w_{j-1} | A)``.
    """
    squeeze = False
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)
        squeeze = True
    if tokens.shape[1] == 0:
        raise ValueError("cannot parse empty sentences")
    chart = self._fill_chart(tokens)
    if squeeze:
        return chart.squeeze(0)
    return chart
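
A small doctest-style sketch, continuing the example above, that reads the root cell of the chart for a single sentence and checks it against forward:

>>> sent = tokens[0]                           # one sentence of length 6
>>> chart = cky.inside_chart(sent)             # shape (N, 6, 7), log-space
>>> root = chart[cky.start, 0, sent.shape[0]]  # log beta(start, 0, L)
>>> torch.allclose(root, cky(sent))
True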