Continuous Programs

Probabilistic programs in continuous domains.

programs

Monadic programs: sequenced probabilistic programs as ContinuousMorphisms.

A MonadicProgram defines a ContinuousMorphism via monadic sequencing of draw steps. Each step samples from a named morphism, optionally conditioned on previously drawn variables, and binds the result. The program returns one or more of the bound variables as its output.

This corresponds to the Kleisli composition pattern used in probabilistic programming languages like PDS (Grove & White), where a sequence of let' x ~ D in ... bindings threads probabilistic state through a generative model.

Features
  • Single and tuple returns: return x or return (x, y, z)
  • Named input parameters for product-domain sub-programs
  • Multi-argument draw steps: draw z ~ f(x, y)
  • Destructuring draws from tuple-returning sub-programs: draw (a, b) ~ sub_prog(x)

Example

Given morphisms f : A -> B and g : B -> C, the monadic program:

program p : A -> C
    draw x ~ f
    draw y ~ g(x)
    return y

is equivalent to the composition f >> g, but the program form allows fan-out (using the input in multiple draws) and non-linear variable dependency graphs.
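The equivalence with f >> g can be illustrated outside the library. The following is a plain-Python sketch, not quivers code. It models each morphism as a function (input, rng) -> draw, which is a hypothetical stand-in for rsample. It also assumes that Kleisli composition means "draw from f, then feed that draw into g". A second helper shows fan-out, where the input feeds two draws.

```python
import random
from typing import Callable

# Hypothetical stand-in for a morphism's rsample: (input, rng) -> one draw.
Kernel = Callable[[float, random.Random], float]


def compose(f: Kernel, g: Kernel) -> Kernel:
    """Kleisli composition f >> g: draw from f, then feed the draw into g."""
    def h(a: float, rng: random.Random) -> float:
        return g(f(a, rng), rng)
    return h


def program_p(a: float, rng: random.Random, f: Kernel, g: Kernel) -> float:
    x = f(a, rng)   # draw x ~ f      (f reads the program input a)
    y = g(x, rng)   # draw y ~ g(x)
    return y        # return y


def program_fanout(a: float, rng: random.Random) -> float:
    # Fan-out: the input a feeds two draws, and the second draw depends on
    # both a and the intermediate x. As a plain composition this would need
    # extra copy/pairing plumbing; the program form states it directly.
    x = rng.gauss(a, 1.0)
    z = rng.gauss(a + x, 0.5)
    return z


def f(a: float, rng: random.Random) -> float:
    return rng.gauss(a, 1.0)          # hypothetical f : A -> B


def g(b: float, rng: random.Random) -> float:
    return rng.gauss(2.0 * b, 0.5)    # hypothetical g : B -> C
```

With the same seed, program_p(1.0, random.Random(0), f, g) and compose(f, g)(1.0, random.Random(0)) consume the random stream in the same order. They therefore return the same value.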

PDS-style nested programs:

program cg_update(y, z) : Belief * Belief -> Truth * Truth
    draw c ~ bern_c(y)
    draw d ~ bern_d(z)
    return (c, d)

program factivityPrior : Entity -> Truth * Truth * Truth
    draw x ~ prior_x
    draw y ~ prior_y
    draw z ~ prior_z
    draw b ~ bern_b(x)
    draw (c, d) ~ cg_update(y, z)
    return (b, c, d)
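The nesting can be made concrete with a plain-Python sketch of the two programs above, in which each draw is an ordinary function call. This is not how quivers executes programs, and several pieces are hypothetical stand-ins. The Bernoulli helper bern replaces bern_b, bern_c and bern_d. The Uniform(0, 1) priors replace prior_x, prior_y and prior_z. Values are single Python scalars rather than batched tensors. The point to notice is that the destructuring draw corresponds to unpacking the tuple returned by the sub-program.

```python
import random


def bern(p: float, rng: random.Random) -> bool:
    """One Bernoulli draw: True with probability p."""
    return rng.random() < p


def cg_update(y: float, z: float, rng: random.Random) -> tuple[bool, bool]:
    c = bern(y, rng)        # draw c ~ bern_c(y)
    d = bern(z, rng)        # draw d ~ bern_d(z)
    return c, d             # return (c, d)


def factivity_prior(entity: str, rng: random.Random) -> tuple[bool, bool, bool]:
    # Hypothetical priors: prior_x / prior_y / prior_z are Uniform(0, 1)
    # and ignore the Entity input in this sketch.
    x = rng.random()
    y = rng.random()
    z = rng.random()
    b = bern(x, rng)                  # draw b ~ bern_b(x)
    c, d = cg_update(y, z, rng)       # draw (c, d) ~ cg_update(y, z)
    return b, c, d                    # return (b, c, d)
```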

MonadicProgram

MonadicProgram(domain: AnySpace, codomain: AnySpace, steps: list[tuple], return_vars: tuple[str, ...], params: tuple[str, ...] | None = None, return_labels: tuple[str, ...] | None = None, effect_set: frozenset[str] | None = None)

Bases: ContinuousMorphism

A probabilistic program defined by monadic sequencing of draw steps.

Each draw step samples from a ContinuousMorphism and binds the result to one or more named variables. Later steps can reference earlier bindings as their input. The program's output is the value(s) of the designated return variable(s).

PARAMETER DESCRIPTION
domain

The program's input space.

TYPE: SetObject or ContinuousSpace

codomain

The program's output space.

TYPE: SetObject or ContinuousSpace

steps

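Each entry is either (var_names, morphism, arg_names) for a draw step, or (var_names, None, value) for a let binding. For a let binding, value is a float constant, a str variable reference, or a callable evaluated on the current environment. An entry may carry an optional fourth boolean element. For a draw step, this flag marks the drawn variable(s) as observed. For a let step whose value is callable, it marks the binding as a score step, whose result is added to the joint log-density in log_joint.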

TYPE: list[tuple]

return_vars

Name(s) of the bound variable(s) whose value(s) are the program output.

TYPE: tuple[str, ...]

params

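Named input parameters for product-domain programs. When more than one name is given, the program input is split along its last (feature) dimension into one component per name. Each component is pre-bound in the environment under its name. Discrete components of width 1 are squeezed, while continuous components keep their trailing dimension.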

TYPE: tuple[str, ...] or None DEFAULT: None

return_labels

Optional labels for tuple return fields. When set, the output dict uses these labels as keys instead of the variable names. Length must match return_vars.

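TYPE: tuple[str, ...] or None DEFAULT: None

effect_set

Optional effect-row annotation. It holds the declared capability set (Sample, Score, Marginal, Pure) and is stored on the instance as effect_set, where downstream inference and dispatch code can introspect it. None when the program is unannotated.

TYPE: frozenset[str] or None DEFAULT: None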
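The params splitting rule can be sketched in plain Python. The sketch assumes a single un-batched input row given as a Python list rather than a (batch, features) tensor. The function name split_params and the component names are hypothetical. Widths are passed in explicitly, whereas the class computes them from the product domain.

```python
def split_params(
    x: list[float],
    names: tuple[str, ...],
    dims: tuple[int, ...],
    is_continuous: tuple[bool, ...],
) -> dict[str, object]:
    """Split one input row into named components along the feature axis."""
    env: dict[str, object] = {}
    start = 0
    for name, width, cont in zip(names, dims, is_continuous):
        chunk = x[start:start + width]
        start += width
        # discrete components of width 1 are squeezed to a scalar;
        # continuous components keep their trailing (list) dimension
        env[name] = chunk[0] if (not cont and width == 1) else chunk
    return env
```

Consider an input [3.0, 0.2, 0.7, 1.0] with names ("e", "v", "t"), widths (1, 2, 1) and continuity flags (False, True, True). The discrete component e becomes the scalar 3.0. The continuous component t stays a one-element list.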

Source code in src/quivers/continuous/programs.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
def __init__(
    self,
    domain: AnySpace,
    codomain: AnySpace,
    steps: list[tuple],
    return_vars: tuple[str, ...],
    params: tuple[str, ...] | None = None,
    return_labels: tuple[str, ...] | None = None,
    effect_set: frozenset[str] | None = None,
) -> None:
    super().__init__(domain, codomain)
    self._return_vars = return_vars
    self._return_is_single = len(return_vars) == 1
    self._params = params
    self._return_labels = return_labels
    # The v0.5 effect-row annotation. None when unannotated;
    # otherwise carries the declared capability set
    # (Sample / Score / Marginal / Pure) for introspection by
    # downstream inference / dispatch code.
    self.effect_set: frozenset[str] | None = effect_set
    self._step_specs: list[_StepSpec | _LetSpec | _ScoreSpec] = []

    # compute input component dimensions for param splitting
    if params is not None and len(params) > 1:
        self._param_dims = self._compute_component_dims(domain)
        self._param_is_continuous = self._compute_component_continuous(domain)

    else:
        self._param_dims = None
        self._param_is_continuous = None

    # register each morphism as a named submodule so parameters
    # are visible to optimizers; let bindings become _LetSpec
    for step in steps:
        # The fourth tuple element is `is_observed` for draw steps
        # (when morph is not None) and `is_score` for let steps
        # (when morph is None). The two flags are mutually
        # exclusive by the morph-None partition, so the slot is
        # reused without ambiguity.
        if len(step) == 4:
            var_names, morph, arg_or_value, fourth_flag = step

        else:
            var_names, morph, arg_or_value = step
            fourth_flag = False

        if morph is None:
            # let-or-score binding: arg_or_value is float | str | callable.
            # A True `is_score` flag (only meaningful when the value is a
            # callable) routes the result through `total += score(env)`
            # in `log_joint` instead of contributing 0.
            if fourth_flag is True and callable(arg_or_value):
                self._step_specs.append(_ScoreSpec(var_names[0], arg_or_value))

            else:
                self._step_specs.append(_LetSpec(var_names[0], arg_or_value))

        else:
            is_observed = fourth_flag
            key = f"_step_{var_names[0]}"
            from quivers.core.morphisms import as_torch_module

            wrapped = as_torch_module(morph)
            self.add_module(key, wrapped)
            # The runtime accesses ``morph`` via
            # ``self._modules[key]`` and expects an object whose
            # ``rsample`` / ``log_prob`` methods are callable.
            # When ``morph`` was already an `nn.Module`
            # (a ContinuousMorphism), the wrapped value is the
            # morphism itself and those methods are available.
            # When ``morph`` is a backend-agnostic
            # `Morphism` (e.g. a ComposedMorphism from
            # the V-Cat hierarchy), the wrapper exposes the
            # morphism's parameters; the categorical object is
            # attached as ``wrapped._morphism`` so a runtime
            # path that needs it can recover via
            # `extract_morphism`.
            self._step_specs.append(
                _StepSpec(var_names, key, arg_or_value, is_observed)
            )
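The constructor above partitions each step tuple by whether its morphism slot is None. It reuses the optional fourth slot as is_observed for draw steps and as is_score for let steps. Below is a minimal pure-Python sketch of that partition. The Draw, Let and Score dataclasses are hypothetical stand-ins for the private _StepSpec, _LetSpec and _ScoreSpec classes, whose real fields are not documented here, and the sketch does not register any modules.

```python
from dataclasses import dataclass
from typing import Any, Callable, Union


# Hypothetical counterparts of _StepSpec, _LetSpec and _ScoreSpec.
@dataclass
class Draw:
    vars: tuple[str, ...]
    morph: Any
    args: Any          # None means "read the program input"
    is_observed: bool


@dataclass
class Let:
    var: str
    value: Any         # float constant, str alias, or callable(env)


@dataclass
class Score:
    var: str
    score: Callable[[dict], Any]


def classify(step: tuple) -> Union[Draw, Let, Score]:
    """Partition a step tuple by whether its morphism slot is None."""
    if len(step) == 4:
        var_names, morph, arg_or_value, flag = step
    else:
        var_names, morph, arg_or_value = step
        flag = False
    if morph is None:
        # here the flag means is_score, and it only counts for callables
        if flag is True and callable(arg_or_value):
            return Score(var_names[0], arg_or_value)
        return Let(var_names[0], arg_or_value)
    # here the flag means is_observed
    return Draw(tuple(var_names), morph, arg_or_value, flag)
```

Because the two flags live on opposite sides of the morph-is-None split, one slot can carry either meaning without ambiguity. A True flag on a non-callable let value is simply ignored.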

observed_names property

observed_names: set[str]

Return the set of variable names marked as observed in the DSL.

rsample

rsample(x: Tensor, sample_shape: Size = Size(), observations: dict[str, Tensor] | None = None) -> Tensor | dict[str, Tensor]

Run the program forward, returning the designated output(s).

Each draw step is executed in order. Steps that reference the program input use x directly; steps that reference bound variables use those variables' sampled values.

PARAMETER DESCRIPTION
x

Program input.

TYPE: Tensor

sample_shape

Additional leading sample dimensions. They are applied only when the program's first step is a draw from the program input; later draws inherit the shape through their inputs.

TYPE: Size DEFAULT: Size()

observations

Values to clamp observed variables to. Keys are variable names, values are tensors of the appropriate shape.

TYPE: dict[str, Tensor] or None DEFAULT: None

RETURNS DESCRIPTION
Tensor or dict[str, Tensor]

The value of the return variable(s). For a single-variable return this is a tensor. For a tuple return it is a dict, keyed by the return labels if return_labels is set and by the variable names otherwise.

Source code in src/quivers/continuous/programs.py
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
def rsample(  # type: ignore[override]
    self,
    x: torch.Tensor,
    sample_shape: torch.Size = torch.Size(),
    observations: dict[str, torch.Tensor] | None = None,
) -> torch.Tensor | dict[str, torch.Tensor]:
    """Run the program forward, returning the designated output(s).

    Each draw step is executed in order. Steps that reference
    the program input use ``x`` directly; steps that reference
    bound variables use those variables' sampled values.

    Parameters
    ----------
    x : torch.Tensor
        Program input.
    sample_shape : torch.Size
        Additional leading sample dimensions (applied to the
        first draw only; subsequent draws inherit the shape).
    observations : dict[str, torch.Tensor] or None
        Values to clamp observed variables to. Keys are variable
        names, values are tensors of the appropriate shape.

    Returns
    -------
    torch.Tensor or dict[str, torch.Tensor]
        The value of the return variable(s). Returns a tensor
        for single-variable returns, or a dict keyed by variable
        name for tuple returns.
    """
    if observations is None:
        observations = {}

    env: dict[str, torch.Tensor] = {}
    # Reserved synthetic key: compiler-emitted let-callables that
    # need the program input (e.g. captured observes inside a
    # grouped marginalize block) read ``env["_x_input"]``.
    env["_x_input"] = x

    # pre-populate env with named params (split product input)
    if self._params is not None and self._param_dims is not None:
        splits = torch.split(x, self._param_dims, dim=-1)

        assert self._param_is_continuous is not None
        for pname, chunk, is_cont in zip(
            self._params, splits, self._param_is_continuous
        ):
            # only squeeze discrete components (continuous dim=1 should stay 2D)
            if not is_cont and chunk.shape[-1] == 1:
                env[pname] = chunk.squeeze(-1)

            else:
                env[pname] = chunk

    for i, spec in enumerate(self._step_specs):
        if isinstance(spec, _ScoreSpec):
            # forward path: bind the score callable's result like a
            # let, score contribution is only meaningful for log_joint.
            env[spec.var] = cast(torch.Tensor, spec.score(env))
            continue

        if isinstance(spec, _LetSpec):
            # deterministic binding: constant, alias, or expression
            if isinstance(spec.value, str):
                env[spec.var] = env[spec.value]

            elif callable(spec.value):
                env[spec.var] = cast(torch.Tensor, spec.value(env))

            else:
                env[spec.var] = torch.full(
                    (x.shape[0],),
                    spec.value,
                    device=x.device,
                )

            continue

        assert self._modules[spec.morphism_name] is not None
        bound = self._modules[spec.morphism_name]
        inp = self._resolve_input(spec, x, env)

        # A bound module may be either a ContinuousMorphism
        # (probabilistic; has rsample / log_prob) or the wrapper
        # produced by `as_torch_module` around a V-Cat
        # `Morphism` (deterministic; has ``_morphism``
        # attached). The deterministic path materialises the
        # morphism's tensor and contracts it against the input,
        # binding the result like a let-step.
        from quivers.core.morphisms import (
            extract_morphism,
        )

        cat_morph = extract_morphism(bound)
        if cat_morph is not None and not isinstance(bound, ContinuousMorphism):
            value = self._apply_categorical_morphism(cat_morph, inp, x.shape[0])
            self._bind_result(spec, value, env)
            continue
        morph = cast(ContinuousMorphism, bound)

        # check if any vars in this step are observed
        if len(spec.vars) == 1:
            var_name = spec.vars[0]

            if spec.is_observed and var_name in observations:
                # clamp to observed value
                env[var_name] = observations[var_name]
                continue

        else:
            # destructuring: if observed vars are present, clamp them
            any_clamped = False

            for v in spec.vars:
                if spec.is_observed and v in observations:
                    env[v] = observations[v]
                    any_clamped = True

            if any_clamped:
                # for partially observed destructuring, sample the rest
                all_clamped = all(v in observations for v in spec.vars)

                if not all_clamped:
                    result = morph.rsample(inp)
                    # only bind un-clamped vars
                    if isinstance(result, dict):
                        result_dict = cast(dict[str, torch.Tensor], result)
                        for v in spec.vars:
                            if v not in observations:
                                env[v] = result_dict[v]

                    else:
                        dims = self._compute_component_dims(morph.codomain)
                        splits = torch.split(result, dims, dim=-1)

                        for v, chunk in zip(spec.vars, splits):
                            if v not in observations:
                                env[v] = (
                                    chunk.squeeze(-1)
                                    if chunk.shape[-1] == 1
                                    else chunk
                                )

                continue

        # only apply sample_shape to the first draw from input
        if i == 0 and spec.args is None and len(sample_shape) > 0:
            result = morph.rsample(inp, sample_shape)

        else:
            result = morph.rsample(inp)

        self._bind_result(spec, result, env)

    # return
    if self._return_is_single:
        return env[self._return_vars[0]]

    # use labels as keys if available, otherwise variable names
    keys = self._return_labels if self._return_labels else self._return_vars
    return {k: env[v] for k, v in zip(keys, self._return_vars)}
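The core control flow of rsample can be summarised in a stripped-down pure-Python sketch with three parts. Steps run in order. A variable is clamped only when it is both flagged observed and present in observations. Tuple returns are keyed by labels when they are given. The step encoding, the function name run and the sampler signature are hypothetical. The sketch covers single-variable draws only and omits the let, score, destructuring, param-splitting and sample_shape branches.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple, Union

# Hypothetical draw step: (var, sampler, arg_names, is_observed).
# sampler(inputs, rng) -> value; arg_names None means "use the program input".
Sampler = Callable[[List[float], random.Random], float]
Step = Tuple[str, Sampler, Optional[Tuple[str, ...]], bool]


def run(
    x: float,
    steps: List[Step],
    return_vars: Tuple[str, ...],
    rng: random.Random,
    observations: Optional[Dict[str, float]] = None,
    return_labels: Optional[Tuple[str, ...]] = None,
) -> Union[float, Dict[str, float]]:
    observations = observations or {}
    env: Dict[str, float] = {}
    for var, sampler, args, is_observed in steps:
        # clamp only variables that are both flagged observed and supplied
        if is_observed and var in observations:
            env[var] = observations[var]
            continue
        inputs = [x] if args is None else [env[a] for a in args]
        env[var] = sampler(inputs, rng)
    if len(return_vars) == 1:
        return env[return_vars[0]]
    # tuple return: key by labels when given, else by variable names
    keys = return_labels if return_labels else return_vars
    return {k: env[v] for k, v in zip(keys, return_vars)}
```

A supplied observation for a variable that is not flagged observed is ignored, and that variable is sampled as usual. This matches the spec.is_observed check in the source above.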

log_prob

log_prob(x: Tensor, y: Tensor) -> Tensor

Log-probability is not supported for monadic programs.

Computing log p(y | x) for a monadic program requires marginalizing over all intermediate variables, which is intractable in general. Use rsample for forward sampling, and use log_joint when values for every intermediate variable are available. Condition via score-function estimators or variational methods.

RAISES DESCRIPTION
NotImplementedError

Always.
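A toy model helps show why the marginal is expensive even in the easiest case. The sketch computes log p(y | x) exactly for a chain of k binary intermediates, x -> z_1 -> ... -> z_k -> y. Doing so requires enumerating all 2**k joint configurations. With continuous intermediates, the sum becomes an integral that generally has no closed form. The model and its probabilities are hypothetical and unrelated to any quivers morphism.

```python
import itertools
import math


def bern_logpmf(v: int, p: float) -> float:
    """log P(V = v) for V ~ Bernoulli(p), with 0 < p < 1."""
    return math.log(p) if v == 1 else math.log(1.0 - p)


def log_marginal(y: int, x: float, k: int) -> float:
    """log p(y | x) for a chain x -> z_1 -> ... -> z_k -> y, by enumeration.

    Hypothetical model: z_1 ~ Bern(x), z_{i+1} ~ Bern(0.9 if z_i else 0.1),
    y ~ Bern(0.8 if z_k else 0.2). Requires k >= 1 and 0 < x < 1.
    """
    terms = []
    for zs in itertools.product((0, 1), repeat=k):   # 2**k configurations
        lp = bern_logpmf(zs[0], x)
        for prev, cur in zip(zs, zs[1:]):
            lp += bern_logpmf(cur, 0.9 if prev else 0.1)
        lp += bern_logpmf(y, 0.8 if zs[-1] else 0.2)
        terms.append(lp)
    # log-sum-exp over all joint configurations of the intermediates
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))
```

For k = 1 and x = 0.5, p(y = 1 | x) = 0.5 * 0.8 + 0.5 * 0.2 = 0.5. For k = 30, the loop would already visit over a billion configurations.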

Source code in src/quivers/continuous/programs.py
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
def log_prob(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Log-probability is not supported for monadic programs.

    Computing log p(y | x) for a monadic program requires
    marginalizing over all intermediate variables, which is
    intractable in general. Use ``rsample`` for forward sampling
    and condition via score function estimators or variational
    methods.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError(
        "log_prob is not supported for monadic programs; "
        "computing p(y | x) requires marginalizing over all "
        "intermediate draws, which is intractable in general. "
        "use rsample() for forward sampling."
    )

log_joint

log_joint(x: Tensor, intermediates: dict[str, Tensor]) -> Tensor

Joint log-density given all intermediate values.

When all intermediate variables are observed (e.g. during inference with HMC/NUTS), computes the joint log-density:

log p(x_1, ..., x_n | input) = sum_i log p(x_i | pa(x_i))

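where pa(x_i) denotes the parents of step i: the program input, or the previously bound variables passed as that step's arguments. Let bindings add nothing to the sum. Score steps (e.g. compiled marginalize blocks) add the value returned by their callable.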

For destructuring draw steps (tuple-returning sub-programs), the intermediates dict should contain entries for each individual variable name.
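The factorisation can be worked through by hand for a two-step program. Take draw x ~ Normal(a, 1); draw y ~ Normal(2x, 0.5); return y. When both x and y are supplied, the joint is just the sum of the two conditional log-densities. The program and function names below are hypothetical, and the code uses only the standard library.

```python
import math


def normal_logpdf(v: float, mean: float, std: float) -> float:
    """log density of Normal(mean, std) at v."""
    return -0.5 * ((v - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def log_joint_chain(a: float, x: float, y: float) -> float:
    # Hypothetical program: draw x ~ Normal(a, 1); draw y ~ Normal(2x, 0.5)
    # log p(x, y | a) = log p(x | a) + log p(y | x)
    return normal_logpdf(x, a, 1.0) + normal_logpdf(y, 2.0 * x, 0.5)
```

At a = x = y = 0, the two terms are -0.5 log(2 pi) and log 2 - 0.5 log(2 pi). Their sum is log 2 - log(2 pi).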

PARAMETER DESCRIPTION
x

Program input.

TYPE: Tensor

intermediates

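Values for every variable bound by a draw step, keyed by variable name; return variables may instead be keyed by their return label when labels are set. Let-bound variables that are missing are recomputed, and missing named parameters are split from x.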

TYPE: dict[str, Tensor]

RETURNS DESCRIPTION
Tensor

Joint log-density. Shape (batch,).

Source code in src/quivers/continuous/programs.py
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
def log_joint(
    self,
    x: torch.Tensor,
    intermediates: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Joint log-density given all intermediate values.

    When all intermediate variables are observed (e.g. during
    inference with HMC/NUTS), computes the joint log-density:

        log p(x_1, ..., x_n | input) = sum_i log p(x_i | pa(x_i))

    where pa(x_i) is the parent variable of step i (either the
    program input or a previously drawn variable).

    For destructuring draw steps (tuple-returning sub-programs),
    the intermediates dict should contain entries for each
    individual variable name.

    Parameters
    ----------
    x : torch.Tensor
        Program input.
    intermediates : dict[str, torch.Tensor]
        Values for ALL bound variables (keyed by variable name
        or by return label if labels are set).

    Returns
    -------
    torch.Tensor
        Joint log-density. Shape (batch,).
    """
    total = torch.zeros(x.shape[0], device=x.device)

    # copy so the caller's intermediates dict is not mutated
    env = dict(intermediates)
    # Reserved synthetic key: compiler-emitted let-callables that
    # need the program input (e.g. captured observes inside a
    # grouped marginalize block when the family takes the program
    # input directly) read ``env["_x_input"]``.
    env["_x_input"] = x

    # if labels are used, map label keys back to variable names
    if self._return_labels:
        for label, var in zip(self._return_labels, self._return_vars):
            if label in env and var not in env:
                env[var] = env[label]

    if self._params is not None and self._param_dims is not None:
        splits = torch.split(x, self._param_dims, dim=-1)

        assert self._param_is_continuous is not None
        for pname, chunk, is_cont in zip(
            self._params, splits, self._param_is_continuous
        ):
            if pname not in env:
                if not is_cont and chunk.shape[-1] == 1:
                    env[pname] = chunk.squeeze(-1)

                else:
                    env[pname] = chunk

    for spec in self._step_specs:
        if isinstance(spec, _ScoreSpec):
            # Score step (e.g. compiled marginalize): the callable
            # returns a (batch,)-shaped log-density contribution
            # that is both bound to env (for any later step that
            # references it) and added to the joint.
            val = cast(torch.Tensor, spec.score(env))
            env[spec.var] = val
            total = total + val
            continue

        if isinstance(spec, _LetSpec):
            # deterministic binding: populate env, contribute 0
            if spec.var not in env:
                if isinstance(spec.value, str):
                    env[spec.var] = env[spec.value]

                elif callable(spec.value):
                    env[spec.var] = cast(torch.Tensor, spec.value(env))

                else:
                    env[spec.var] = torch.full(
                        (x.shape[0],),
                        spec.value,
                        device=x.device,
                    )

            continue

        assert self._modules[spec.morphism_name] is not None
        morph = cast(ContinuousMorphism, self._modules[spec.morphism_name])
        inp = self._resolve_input(spec, x, env)

        if len(spec.vars) == 1:
            val = env[spec.vars[0]]
            total = total + morph.log_prob(inp, val)

        else:
            # destructuring step: if sub-program, call its log_joint
            # with the individual intermediate values
            if hasattr(morph, "log_joint") and hasattr(morph, "_return_vars"):
                # reconstruct the sub-program's intermediates from
                # the overall intermediates dict
                sub_morph = cast(MonadicProgram, morph)
                sub_intermediates = {}

                for sub_spec in sub_morph._step_specs:
                    if isinstance(sub_spec, (_LetSpec, _ScoreSpec)):
                        continue

                    for sv in sub_spec.vars:
                        if sv in env:
                            sub_intermediates[sv] = env[sv]

                total = total + sub_morph.log_joint(inp, sub_intermediates)

            else:
                # product-codomain morphism: reconstruct stacked output
                parts = [env[v] for v in spec.vars]
                val = self._stack_tensors(parts)
                total = total + morph.log_prob(inp, val)

    return total
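The accumulation rule in log_joint above has three cases. Draw steps add log p(value | inputs). Let steps add nothing and are recomputed only when their value was not supplied. Score steps are both bound and added. A toy pure-Python version of this loop follows. The tuple-based step encoding and the name toy_log_joint are hypothetical, and the destructuring and sub-program branches are omitted.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical step encodings, for illustration only:
#   ("draw",  var, logpdf(inputs, value), arg_names or None)
#   ("let",   var, value_fn(env))
#   ("score", var, score_fn(env))


def toy_log_joint(x: float, steps: List[Tuple[Any, ...]], intermediates: Dict[str, float]) -> float:
    env = dict(intermediates)          # copy; the caller's dict is untouched
    total = 0.0
    for step in steps:
        kind, var = step[0], step[1]
        if kind == "score":
            val = step[2](env)         # bound AND added to the joint
            env[var] = val
            total += val
        elif kind == "let":
            if var not in env:         # supplied values win; otherwise recompute
                env[var] = step[2](env)
        else:
            logpdf, args = step[2], step[3]
            inputs = [x] if args is None else [env[a] for a in args]
            total += logpdf(inputs, env[var])   # draw values must be supplied
    return total
```

Take a draw z, a let w = 2z, and a score s = w - 1. With x = 3 and z = 1, the total is -(1 - 3)**2 + (2 - 1) = -3. If w = 10 is supplied as well, it is not recomputed, and the score contributes 9 instead.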