Continuous Programs

Probabilistic programs in continuous domains.

programs

Monadic programs: sequenced probabilistic programs as ContinuousMorphisms.

A MonadicProgram defines a ContinuousMorphism via monadic sequencing of draw steps. Each step samples from a named morphism, optionally conditioned on previously drawn variables, and binds the result. The program returns one or more of the bound variables as its output.

This corresponds to the Kleisli composition pattern used in probabilistic programming languages such as PDS (Grove & White), where a sequence of let x ~ D in ... bindings threads probabilistic state through a generative model.

Features
  • Single and tuple returns: return x or return (x, y, z)
  • Named input parameters for product-domain sub-programs
  • Multi-argument draw steps: draw z ~ f(x, y)
  • Destructuring draws from tuple-returning sub-programs: draw (a, b) ~ sub_prog(x)
Example

Given morphisms f : A -> B and g : B -> C, the monadic program::

program p : A -> C
    draw x ~ f
    draw y ~ g(x)
    return y

is equivalent to the composition f >> g, but the program form allows fan-out (using the input in multiple draws) and non-linear variable dependency graphs.
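The env-threading semantics can be sketched in plain Python. This is a toy interpreter, not the library API; the function and variable names are hypothetical stand-ins, and deterministic functions replace the probabilistic morphisms:

```python
# Minimal sketch of monadic sequencing (hypothetical, not the quivers API):
# each "morphism" is a plain function; draw steps thread an environment
# of bound variables, mirroring Kleisli composition with fan-out.

def run_program(steps, return_var, x):
    """Execute draw steps in order; each binds a named variable."""
    env = {"_input": x}
    for var, fn, arg in steps:
        env[var] = fn(env[arg])
    return env[return_var]

# deterministic stand-ins for morphisms f : A -> B and g : B -> C
f = lambda a: a + 1
g = lambda b: b * 2

# program p : A -> C  =  draw x ~ f; draw y ~ g(x); return y
p = [("x", f, "_input"), ("y", g, "x")]
result = run_program(p, "y", 3)  # g(f(3)) = 8
```

Because later steps look up any earlier binding (or the input) by name, the same machinery supports fan-out and non-linear dependency graphs, not just linear pipelines.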

PDS-style nested programs::

program cg_update(y, z) : Belief * Belief -> Truth * Truth
    draw c ~ bern_c(y)
    draw d ~ bern_d(z)
    return (c, d)

program factivityPrior : Entity -> Truth * Truth * Truth
    draw x ~ prior_x
    draw y ~ prior_y
    draw z ~ prior_z
    draw b ~ bern_b(x)
    draw (c, d) ~ cg_update(y, z)
    return (b, c, d)

MonadicProgram

MonadicProgram(domain: AnySpace, codomain: AnySpace, steps: list[tuple], return_vars: tuple[str, ...], params: tuple[str, ...] | None = None, return_labels: tuple[str, ...] | None = None)

Bases: ContinuousMorphism

A probabilistic program defined by monadic sequencing of draw steps.

Each draw step samples from a ContinuousMorphism and binds the result to one or more named variables. Later steps can reference earlier bindings as their input. The program's output is the value(s) of the designated return variable(s).

PARAMETER DESCRIPTION
domain

The program's input space.

TYPE: SetObject or ContinuousSpace

codomain

The program's output space.

TYPE: SetObject or ContinuousSpace

steps

Each entry is either (var_names, morphism, arg_names) for draw steps, or (var_names, None, value) for let bindings, where value is a float constant or a str variable reference. A draw entry may carry an optional fourth element, is_observed, marking its variable(s) as observed.

TYPE: list[tuple]

return_vars

Name(s) of the bound variable(s) whose value(s) are the program output.

TYPE: tuple[str, ...]

params

Named input parameters for product-domain programs. When set, the program input is split along the feature dimension and each component is pre-bound in the env.

TYPE: tuple[str, ...] or None DEFAULT: None

return_labels

Optional labels for tuple return fields. When set, the output dict uses these labels as keys instead of the variable names. Length must match return_vars.

TYPE: tuple[str, ...] or None DEFAULT: None
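The steps encoding can be illustrated with a toy interpreter. This is a self-contained sketch of the tuple format only; the real class additionally registers each morphism as a submodule and handles multi-argument and destructuring draws:

```python
# Toy illustration of the `steps` encoding (stand-alone sketch):
#   draw step: (var_names, morphism, arg_names)
#   let step:  (var_names, None, value)  -- value: float constant or str alias

def interpret(steps, return_vars, x):
    env = {"_input": x}
    for var_names, morph, arg_or_value in steps:
        if morph is None:  # let binding: constant or alias
            v = arg_or_value
            env[var_names[0]] = env[v] if isinstance(v, str) else v
        else:              # draw step (single-argument case only)
            env[var_names[0]] = morph(env[arg_or_value])
    return {v: env[v] for v in return_vars}

steps = [
    (("x",), lambda a: a * 2, "_input"),  # draw x ~ double(input)
    (("c",), None, 0.5),                  # let c = 0.5
    (("y",), None, "x"),                  # let y = x (alias)
]
out = interpret(steps, ("x", "c", "y"), 3)
```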

Source code in src/quivers/continuous/programs.py
def __init__(
    self,
    domain: AnySpace,
    codomain: AnySpace,
    steps: list[tuple],
    return_vars: tuple[str, ...],
    params: tuple[str, ...] | None = None,
    return_labels: tuple[str, ...] | None = None,
) -> None:
    super().__init__(domain, codomain)
    self._return_vars = return_vars
    self._return_is_single = len(return_vars) == 1
    self._params = params
    self._return_labels = return_labels
    self._step_specs: list[_StepSpec | _LetSpec] = []

    # compute input component dimensions for param splitting
    if params is not None and len(params) > 1:
        self._param_dims = self._compute_component_dims(domain)
        self._param_is_continuous = self._compute_component_continuous(domain)

    else:
        self._param_dims = None
        self._param_is_continuous = None

    # register each morphism as a named submodule so parameters
    # are visible to optimizers; let bindings become _LetSpec
    for step in steps:
        # support both 3-element (backward compat) and 4-element tuples
        if len(step) == 4:
            var_names, morph, arg_or_value, is_observed = step

        else:
            var_names, morph, arg_or_value = step
            is_observed = False

        if morph is None:
            # let binding: arg_or_value is float | str
            self._step_specs.append(_LetSpec(var_names[0], arg_or_value))

        else:
            key = f"_step_{var_names[0]}"
            self.add_module(key, morph)
            self._step_specs.append(
                _StepSpec(var_names, key, arg_or_value, is_observed)
            )

observed_names property

observed_names: set[str]

Return the set of variable names marked as observed in the DSL.

rsample

rsample(x: Tensor, sample_shape: Size = Size(), observations: dict[str, Tensor] | None = None) -> Tensor | dict[str, Tensor]

Run the program forward, returning the designated output(s).

Each draw step is executed in order. Steps that reference the program input use x directly; steps that reference bound variables use those variables' sampled values.

PARAMETER DESCRIPTION
x

Program input.

TYPE: Tensor

sample_shape

Additional leading sample dimensions (applied to the first draw only; subsequent draws inherit the shape).

TYPE: Size DEFAULT: Size()

observations

Values to clamp observed variables to. Keys are variable names, values are tensors of the appropriate shape.

TYPE: dict[str, Tensor] or None DEFAULT: None

RETURNS DESCRIPTION
Tensor or dict[str, Tensor]

The value of the return variable(s). Returns a tensor for single-variable returns, or a dict keyed by variable name for tuple returns.
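The sample_shape rule (applied to the first draw only, inherited by later draws) can be sketched in isolation. A hypothetical pure-Python analogue, not the library implementation:

```python
import random

# Sketch of sample_shape semantics (hypothetical): only the first draw
# receives sample_shape; later draws are applied elementwise, so their
# outputs inherit the leading sample dimension automatically.

def first_draw(x, sample_shape=1):
    # produces sample_shape-many samples: the leading sample dimension
    return [x + random.gauss(0, 1) for _ in range(sample_shape)]

def later_draw(xs):
    # elementwise: one output per incoming sample, shape inherited
    return [v * 2 for v in xs]

random.seed(0)
samples = later_draw(first_draw(0.0, sample_shape=5))
n = len(samples)  # 5: the sample dimension threads through the program
```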

Source code in src/quivers/continuous/programs.py
def rsample(  # type: ignore[override]
    self,
    x: torch.Tensor,
    sample_shape: torch.Size = torch.Size(),
    observations: dict[str, torch.Tensor] | None = None,
) -> torch.Tensor | dict[str, torch.Tensor]:
    """Run the program forward, returning the designated output(s).

    Each draw step is executed in order. Steps that reference
    the program input use ``x`` directly; steps that reference
    bound variables use those variables' sampled values.

    Parameters
    ----------
    x : torch.Tensor
        Program input.
    sample_shape : torch.Size
        Additional leading sample dimensions (applied to the
        first draw only; subsequent draws inherit the shape).
    observations : dict[str, torch.Tensor] or None
        Values to clamp observed variables to. Keys are variable
        names, values are tensors of the appropriate shape.

    Returns
    -------
    torch.Tensor or dict[str, torch.Tensor]
        The value of the return variable(s). Returns a tensor
        for single-variable returns, or a dict keyed by variable
        name for tuple returns.
    """
    if observations is None:
        observations = {}

    env: dict[str, torch.Tensor] = {}

    # pre-populate env with named params (split product input)
    if self._params is not None and self._param_dims is not None:
        splits = torch.split(x, self._param_dims, dim=-1)

        assert self._param_is_continuous is not None
        for pname, chunk, is_cont in zip(
            self._params, splits, self._param_is_continuous
        ):
            # only squeeze discrete components (continuous dim=1 should stay 2D)
            if not is_cont and chunk.shape[-1] == 1:
                env[pname] = chunk.squeeze(-1)

            else:
                env[pname] = chunk

    for i, spec in enumerate(self._step_specs):
        if isinstance(spec, _LetSpec):
            # deterministic binding: constant, alias, or expression
            if isinstance(spec.value, str):
                env[spec.var] = env[spec.value]

            elif callable(spec.value):
                env[spec.var] = cast(torch.Tensor, spec.value(env))

            else:
                env[spec.var] = torch.full(
                    (x.shape[0],),
                    spec.value,
                    device=x.device,
                )

            continue

        assert self._modules[spec.morphism_name] is not None
        morph = cast(ContinuousMorphism, self._modules[spec.morphism_name])
        inp = self._resolve_input(spec, x, env)

        # check if any vars in this step are observed
        if len(spec.vars) == 1:
            var_name = spec.vars[0]

            if spec.is_observed and var_name in observations:
                # clamp to observed value
                env[var_name] = observations[var_name]
                continue

        else:
            # destructuring: if observed vars are present, clamp them
            any_clamped = False

            for v in spec.vars:
                if spec.is_observed and v in observations:
                    env[v] = observations[v]
                    any_clamped = True

            if any_clamped:
                # for partially observed destructuring, sample the rest
                all_clamped = all(v in observations for v in spec.vars)

                if not all_clamped:
                    result = morph.rsample(inp)
                    # only bind un-clamped vars
                    if isinstance(result, dict):
                        result_dict = cast(dict[str, torch.Tensor], result)
                        for v in spec.vars:
                            if v not in observations:
                                env[v] = result_dict[v]

                    else:
                        dims = self._compute_component_dims(morph.codomain)
                        splits = torch.split(result, dims, dim=-1)

                        for v, chunk in zip(spec.vars, splits):
                            if v not in observations:
                                env[v] = (
                                    chunk.squeeze(-1)
                                    if chunk.shape[-1] == 1
                                    else chunk
                                )

                continue

        # only apply sample_shape to the first draw from input
        if i == 0 and spec.args is None and len(sample_shape) > 0:
            result = morph.rsample(inp, sample_shape)

        else:
            result = morph.rsample(inp)

        self._bind_result(spec, result, env)

    # return
    if self._return_is_single:
        return env[self._return_vars[0]]

    # use labels as keys if available, otherwise variable names
    keys = self._return_labels if self._return_labels else self._return_vars
    return {k: env[v] for k, v in zip(keys, self._return_vars)}

log_prob

log_prob(x: Tensor, y: Tensor) -> Tensor

Log-probability is not supported for monadic programs.

Computing log p(y | x) for a monadic program requires marginalizing over all intermediate variables, which is intractable in general. Use rsample for forward sampling and condition via score function estimators or variational methods.

RAISES DESCRIPTION
NotImplementedError

Always.
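The intractability can be made concrete with the smallest possible case: a single discrete intermediate already forces a sum over its values. An illustrative toy only (the library raises NotImplementedError rather than attempting this):

```python
import math

# Why log p(y | x) needs marginalization: with one intermediate draw z,
#   p(y | x) = sum_z p(z | x) * p(y | z)
# and the sum (or integral) grows with every intermediate variable.

def bern(p, v):                  # p(v) for a Bernoulli(p), v in {0, 1}
    return p if v == 1 else 1 - p

p_z_given_x = 0.3                # draw z ~ Bernoulli(0.3)
p_y_given_z = {0: 0.1, 1: 0.8}   # draw y ~ Bernoulli(p_y_given_z[z])

# marginal p(y = 1 | x): explicit sum over the intermediate z
p_y1 = sum(bern(p_z_given_x, z) * bern(p_y_given_z[z], 1) for z in (0, 1))
log_p = math.log(p_y1)           # 0.7*0.1 + 0.3*0.8 = 0.31
```

With continuous intermediates the sum becomes an integral with no closed form in general, which is why forward sampling via rsample (plus score-function or variational conditioning) is the supported route.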

Source code in src/quivers/continuous/programs.py
def log_prob(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Log-probability is not supported for monadic programs.

    Computing log p(y | x) for a monadic program requires
    marginalizing over all intermediate variables, which is
    intractable in general. Use ``rsample`` for forward sampling
    and condition via score function estimators or variational
    methods.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError(
        "log_prob is not supported for monadic programs; "
        "computing p(y | x) requires marginalizing over all "
        "intermediate draws, which is intractable in general. "
        "use rsample() for forward sampling."
    )

log_joint

log_joint(x: Tensor, intermediates: dict[str, Tensor]) -> Tensor

Joint log-density given all intermediate values.

When all intermediate variables are observed (e.g. during inference with HMC/NUTS), computes the joint log-density:

log p(x_1, ..., x_n | input) = sum_i log p(x_i | pa(x_i))

where pa(x_i) is the parent variable of step i (either the program input or a previously drawn variable).

For destructuring draw steps (tuple-returning sub-programs), the intermediates dict should contain entries for each individual variable name.

PARAMETER DESCRIPTION
x

Program input.

TYPE: Tensor

intermediates

Values for ALL bound variables (keyed by variable name or by return label if labels are set).

TYPE: dict[str, Tensor]

RETURNS DESCRIPTION
Tensor

Joint log-density. Shape (batch,).
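The factorization sum_i log p(x_i | pa(x_i)) can be sketched with scalar unit-variance Gaussian draw steps. Hypothetical helper names; this is the formula, not the library's implementation:

```python
import math

# Sketch of the log_joint factorization:
#   log p(x_1, ..., x_n | input) = sum_i log p(x_i | pa(x_i))
# with each draw step modeled as x_i ~ Normal(pa(x_i), 1).

def gauss_logpdf(value, mean):
    # log density of Normal(mean, 1) at value
    return -0.5 * math.log(2 * math.pi) - 0.5 * (value - mean) ** 2

def log_joint(x, intermediates, parents):
    """parents maps each variable to its parent ('_input' or a variable)."""
    env = {"_input": x, **intermediates}
    return sum(gauss_logpdf(env[v], env[pa]) for v, pa in parents.items())

# draw a ~ N(input, 1); draw b ~ N(a, 1)
parents = {"a": "_input", "b": "a"}
lj = log_joint(0.0, {"a": 1.0, "b": 1.5}, parents)
# equals gauss_logpdf(1.0, 0.0) + gauss_logpdf(1.5, 1.0)
```

Because every intermediate is supplied, no marginalization is needed: each term conditions only on its parent, which is exactly the setting exploited by HMC/NUTS inference.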

Source code in src/quivers/continuous/programs.py
def log_joint(
    self,
    x: torch.Tensor,
    intermediates: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Joint log-density given all intermediate values.

    When all intermediate variables are observed (e.g. during
    inference with HMC/NUTS), computes the joint log-density:

        log p(x_1, ..., x_n | input) = sum_i log p(x_i | pa(x_i))

    where pa(x_i) is the parent variable of step i (either the
    program input or a previously drawn variable).

    For destructuring draw steps (tuple-returning sub-programs),
    the intermediates dict should contain entries for each
    individual variable name.

    Parameters
    ----------
    x : torch.Tensor
        Program input.
    intermediates : dict[str, torch.Tensor]
        Values for ALL bound variables (keyed by variable name
        or by return label if labels are set).

    Returns
    -------
    torch.Tensor
        Joint log-density. Shape (batch,).
    """
    total = torch.zeros(x.shape[0], device=x.device)

    # if labels are used, map label keys back to variable names
    env = dict(intermediates)

    if self._return_labels:
        for label, var in zip(self._return_labels, self._return_vars):
            if label in env and var not in env:
                env[var] = env[label]

    if self._params is not None and self._param_dims is not None:
        splits = torch.split(x, self._param_dims, dim=-1)

        assert self._param_is_continuous is not None
        for pname, chunk, is_cont in zip(
            self._params, splits, self._param_is_continuous
        ):
            if pname not in env:
                if not is_cont and chunk.shape[-1] == 1:
                    env[pname] = chunk.squeeze(-1)

                else:
                    env[pname] = chunk

    for spec in self._step_specs:
        if isinstance(spec, _LetSpec):
            # deterministic binding: populate env, contribute 0
            if spec.var not in env:
                if isinstance(spec.value, str):
                    env[spec.var] = env[spec.value]

                elif callable(spec.value):
                    env[spec.var] = cast(torch.Tensor, spec.value(env))

                else:
                    env[spec.var] = torch.full(
                        (x.shape[0],),
                        spec.value,
                        device=x.device,
                    )

            continue

        assert self._modules[spec.morphism_name] is not None
        morph = cast(ContinuousMorphism, self._modules[spec.morphism_name])
        inp = self._resolve_input(spec, x, env)

        if len(spec.vars) == 1:
            val = env[spec.vars[0]]
            total = total + morph.log_prob(inp, val)

        else:
            # destructuring step: if sub-program, call its log_joint
            # with the individual intermediate values
            if hasattr(morph, "log_joint") and hasattr(morph, "_return_vars"):
                # reconstruct the sub-program's intermediates from
                # the overall intermediates dict
                sub_morph = cast(MonadicProgram, morph)
                sub_intermediates = {}

                for sub_spec in sub_morph._step_specs:
                    if isinstance(sub_spec, _LetSpec):
                        continue

                    for sv in sub_spec.vars:
                        if sv in env:
                            sub_intermediates[sv] = env[sv]

                total = total + sub_morph.log_joint(inp, sub_intermediates)

            else:
                # product-codomain morphism: reconstruct stacked output
                parts = [env[v] for v in spec.vars]
                val = self._stack_tensors(parts)
                total = total + morph.log_prob(inp, val)

    return total