Trace

Program trace data structures and trace-based inference.

trace

Execution trace for monadic programs.

A trace records every site visited during program execution. For each sample site it captures the morphism, the sampled or observed value, and the log-density. Deterministic let bindings are recorded as sites too. Traces are the foundation for the inference algorithms: SVI uses traces to compute the ELBO, and conditioning works by clamping trace sites to observed data.
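As a minimal sketch of how SVI can consume traces, the code below computes a single-sample ELBO estimate, log p(x, z) minus log q(z), from a model trace and a guide trace. Each trace is reduced to a plain dict of per-site log-densities. The helper elbo_estimate and this dict layout are hypothetical simplifications, not part of the quivers API. Real traces hold batched tensors.

```python
import math


def elbo_estimate(model_log_probs: dict[str, float],
                  guide_log_probs: dict[str, float]) -> float:
    """Single-sample ELBO estimate: log p(x, z) - log q(z).

    model_log_probs: log-density of every stochastic site in the model
    trace (latent sites plus observed sites clamped to data).
    guide_log_probs: log-density of every latent site in the guide trace.
    """
    log_joint = sum(model_log_probs.values())  # log p(x, z)
    log_guide = sum(guide_log_probs.values())  # log q(z)
    return log_joint - log_guide


# hypothetical per-site log-densities: one latent z, one observed y
model_lps = {"z": math.log(0.5), "y": math.log(0.25)}
guide_lps = {"z": math.log(0.5)}
elbo = elbo_estimate(model_lps, guide_lps)
print(elbo)  # approximately -1.386, i.e. log(0.25)
```

Maximizing this quantity over guide parameters is what SVI does. Conditioning enters only through the observed sites in the model trace.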

The trace function is a free function that operates on any MonadicProgram without modifying it. It walks the program's step specs and resolves inputs using the program's existing infrastructure, for example its _resolve_input and _bind_result helpers.

SampleSite dataclass

SampleSite(name: str, morphism: ContinuousMorphism | None, value: Tensor, log_prob: Tensor, is_observed: bool = False, is_deterministic: bool = False)

Record of a single sample site in a program trace.

Holds torch.Tensor fields (value and log_prob), so it is not a value type.

PARAMETER DESCRIPTION
name

Variable name bound at this site.

TYPE: str

morphism

The distribution morphism (None for let bindings).

TYPE: ContinuousMorphism or None

value

The sampled or observed value.

TYPE: Tensor

log_prob

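Log-density of the value under the morphism. Shape (batch,). Zero for let bindings. For a destructuring step, each variable's site holds an equal share of the step's joint log-density.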

TYPE: Tensor

is_observed

Whether this site was clamped to an observed value.

TYPE: bool DEFAULT: False

is_deterministic

Whether this is a deterministic let binding.

TYPE: bool DEFAULT: False
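The sketch below mirrors the SampleSite fields with scalar floats to show what each kind of site records. A stochastic site carries its morphism and the log-density of its value. A let binding has no morphism and a zero log-density. The scalar SampleSite stand-in and the normal_log_prob helper are assumptions for illustration. The real class stores torch.Tensor values of shape (batch,) and a ContinuousMorphism.

```python
import math
from dataclasses import dataclass
from typing import Callable, Optional


def normal_log_prob(mu: float, sigma: float) -> Callable[[float], float]:
    """Return v -> log N(v; mu, sigma); stands in for a morphism's log_prob."""
    def lp(v: float) -> float:
        return (-0.5 * ((v - mu) / sigma) ** 2
                - math.log(sigma) - 0.5 * math.log(2 * math.pi))
    return lp


@dataclass
class SampleSite:
    """Scalar stand-in for SampleSite (floats instead of tensors)."""
    name: str
    morphism: Optional[Callable[[float], float]]  # None for let bindings
    value: float
    log_prob: float                               # 0.0 for let bindings
    is_observed: bool = False
    is_deterministic: bool = False


# stochastic site clamped to an observed value: log_prob is the
# morphism's log-density evaluated at that value
morph = normal_log_prob(0.0, 1.0)
y = SampleSite("y", morph, value=0.0, log_prob=morph(0.0), is_observed=True)

# let binding: no morphism and zero log-density
s = SampleSite("s", None, value=2.0, log_prob=0.0, is_deterministic=True)

print(round(y.log_prob, 4))  # -0.9189
```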

Trace dataclass

Trace(sites: dict[str, SampleSite] = dict(), output: Tensor | dict[str, Tensor] | None = None, log_joint: Tensor | None = None)

Complete execution trace of a monadic program.

Mutable accumulator: sites grows as the program executes; not a value type.

PARAMETER DESCRIPTION
sites

All sample sites keyed by variable name.

TYPE: dict[str, SampleSite] DEFAULT: dict()

output

The program's return value.

TYPE: Tensor or dict[str, Tensor] DEFAULT: None

log_joint

Sum of log-densities across all stochastic sites, both latent and observed. Shape (batch,).

TYPE: Tensor DEFAULT: None

stochastic_sites property

stochastic_sites: dict[str, SampleSite]

Return only stochastic (non-deterministic) sites.

latent_sites property

latent_sites: dict[str, SampleSite]

Return only latent (non-observed, non-deterministic) sites.

observed_sites property

observed_sites: dict[str, SampleSite]

Return only observed sites.
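The three properties are filters over sites, keyed on the is_observed and is_deterministic flags. A minimal sketch of that logic follows. It assumes simplified Site and Trace stand-ins with float log-densities, not the quivers classes. It also shows why the sites default is presumably a default_factory. A dataclass rejects a mutable dict() as a plain default, and a per-instance dict keeps traces independent.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Site:
    """Minimal stand-in for SampleSite: only the fields the filters read."""
    name: str
    log_prob: float
    is_observed: bool = False
    is_deterministic: bool = False


@dataclass
class Trace:
    # default_factory gives every Trace its own dict, so sites
    # recorded in one trace never leak into another
    sites: dict[str, Site] = field(default_factory=dict)
    log_joint: Optional[float] = None

    @property
    def stochastic_sites(self) -> dict[str, Site]:
        return {k: s for k, s in self.sites.items() if not s.is_deterministic}

    @property
    def latent_sites(self) -> dict[str, Site]:
        return {k: s for k, s in self.sites.items()
                if not s.is_observed and not s.is_deterministic}

    @property
    def observed_sites(self) -> dict[str, Site]:
        return {k: s for k, s in self.sites.items() if s.is_observed}


tr = Trace()
tr.sites["z"] = Site("z", -0.9)                          # latent
tr.sites["y"] = Site("y", -1.2, is_observed=True)        # observed
tr.sites["s"] = Site("s", 0.0, is_deterministic=True)    # let binding

print(sorted(tr.stochastic_sites))  # ['y', 'z']
print(sorted(tr.latent_sites))      # ['z']
print(sorted(tr.observed_sites))    # ['y']
```

An observed site is stochastic but not latent, so it appears in stochastic_sites and observed_sites only.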

trace

trace(program: MonadicProgram, x: Tensor, observations: dict[str, Tensor] | None = None) -> Trace

Execute a program and record all sample sites.

Walks the program's step specs in order, sampling from each morphism (or clamping to observed values) and recording the value and log-density at each site.
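The walk can be sketched with two hypothetical step types standing in for the let and draw specs. Let binds a deterministic value. Draw samples from a normal distribution whose mean depends on earlier bindings, or clamps to an observed value and scores it. The names Let, Draw and trace_steps, the scalar values, and the fixed normal family are all assumptions for illustration. The real function works on batched tensors and calls each morphism's rsample and log_prob.

```python
import math
import random
from dataclasses import dataclass
from typing import Callable, Union

Env = dict[str, float]


@dataclass
class Let:
    """Deterministic binding: var = fn(env)."""
    var: str
    fn: Callable[[Env], float]


@dataclass
class Draw:
    """Stochastic step: var ~ Normal(mean_fn(env), sigma)."""
    var: str
    mean_fn: Callable[[Env], float]
    sigma: float


def normal_lp(v: float, mu: float, sigma: float) -> float:
    return (-0.5 * ((v - mu) / sigma) ** 2
            - math.log(sigma) - 0.5 * math.log(2 * math.pi))


def trace_steps(steps: list[Union[Let, Draw]], observations: Env,
                rng: random.Random):
    """Walk steps in order, clamping observed vars and summing log-densities."""
    env: Env = {}
    sites = {}  # var -> (value, log_prob, is_observed)
    log_joint = 0.0
    for step in steps:
        if isinstance(step, Let):
            env[step.var] = step.fn(env)
            sites[step.var] = (env[step.var], 0.0, False)  # contributes nothing
            continue
        mu = step.mean_fn(env)
        is_obs = step.var in observations
        # clamp to the observed value, otherwise sample
        val = observations[step.var] if is_obs else rng.gauss(mu, step.sigma)
        lp = normal_lp(val, mu, step.sigma)
        env[step.var] = val
        sites[step.var] = (val, lp, is_obs)
        log_joint += lp
    return sites, log_joint


steps = [
    Draw("z", lambda env: 0.0, 1.0),           # latent
    Let("two_z", lambda env: 2 * env["z"]),    # deterministic
    Draw("y", lambda env: env["two_z"], 0.5),  # observed below
]
sites, log_joint = trace_steps(steps, {"y": 1.0}, random.Random(0))
```

As in the real function, let sites contribute zero log-density, so log_joint is the sum over the stochastic sites only.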

PARAMETER DESCRIPTION
program

The program to trace.

TYPE: MonadicProgram

x

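Program input. Shape (batch, ...). If the program declares named parameters, x is split along its last dimension into one chunk per parameter.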

TYPE: Tensor

observations

Values to clamp observed variables to. Keys are variable names, values are tensors of the appropriate shape.

TYPE: dict[str, Tensor] or None DEFAULT: None
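When the program declares named parameters, trace first splits x along its last dimension into one chunk per parameter. Discrete parameters of width 1 are squeezed to shape (batch,). The sketch below reproduces that bookkeeping with nested Python lists in place of tensors. split_params and its arguments are hypothetical names for what the source reads from program._params, program._param_dims, and program._param_is_continuous.

```python
def split_params(x: list[list[float]], names: list[str], dims: list[int],
                 is_continuous: list[bool]) -> dict[str, list]:
    """Split each row of x into named chunks of widths `dims`.

    Discrete params of width 1 are squeezed to one scalar per row,
    mirroring squeeze(-1) on a (batch, 1) tensor.
    """
    # like torch.split with a list of sizes, the widths must cover the row
    assert sum(dims) == len(x[0]), "param widths must cover the last dim"
    env: dict[str, list] = {}
    start = 0
    for name, width, cont in zip(names, dims, is_continuous):
        chunk = [row[start:start + width] for row in x]
        if not cont and width == 1:
            env[name] = [c[0] for c in chunk]  # (batch, 1) -> (batch,)
        else:
            env[name] = chunk                  # stays (batch, width)
        start += width
    return env


x = [[0.1, 0.2, 3.0],
     [0.4, 0.5, 1.0]]
env = split_params(x, ["loc", "k"], [2, 1], [True, False])
print(env["loc"])  # [[0.1, 0.2], [0.4, 0.5]]
print(env["k"])    # [3.0, 1.0]
```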

RETURNS DESCRIPTION
Trace

Complete execution trace with all sites, output, and log-joint.

Source code in src/quivers/inference/trace.py
def trace(
    program: MonadicProgram,
    x: torch.Tensor,
    observations: dict[str, torch.Tensor] | None = None,
) -> Trace:
    """Execute a program and record all sample sites.

    Walks the program's step specs in order, sampling from each
    morphism (or clamping to observed values) and recording the
    value and log-density at each site.

    Parameters
    ----------
    program : MonadicProgram
        The program to trace.
    x : torch.Tensor
        Program input. Shape (batch, ...).
    observations : dict[str, torch.Tensor] or None
        Values to clamp observed variables to. Keys are variable
        names, values are tensors of the appropriate shape.

    Returns
    -------
    Trace
        Complete execution trace with all sites, output, and log-joint.
    """
    if observations is None:
        observations = {}

    env: dict[str, torch.Tensor] = {}
    tr = Trace()
    total_lp = torch.zeros(x.shape[0], device=x.device)

    # pre-populate env with named params
    if program._params is not None and program._param_dims is not None:
        splits = torch.split(x, program._param_dims, dim=-1)

        assert program._param_is_continuous is not None
        for pname, chunk, is_cont in zip(
            program._params,
            splits,
            program._param_is_continuous,
        ):
            if not is_cont and chunk.shape[-1] == 1:
                env[pname] = chunk.squeeze(-1)

            else:
                env[pname] = chunk

    for spec in program._step_specs:
        if isinstance(spec, _LetSpec):
            # deterministic binding
            if isinstance(spec.value, str):
                env[spec.var] = env[spec.value]

            elif callable(spec.value):
                env[spec.var] = cast(torch.Tensor, spec.value(env))

            else:
                env[spec.var] = torch.full(
                    (x.shape[0],),
                    spec.value,
                    device=x.device,
                )

            tr.sites[spec.var] = SampleSite(
                name=spec.var,
                morphism=None,
                value=env[spec.var],
                log_prob=torch.zeros(x.shape[0], device=x.device),
                is_deterministic=True,
            )
            continue

        # stochastic draw step
        assert program._modules[spec.morphism_name] is not None
        morph = cast(ContinuousMorphism, program._modules[spec.morphism_name])
        inp = program._resolve_input(spec, x, env)

        if len(spec.vars) == 1:
            var_name = spec.vars[0]
            is_obs = var_name in observations

            if is_obs:
                # clamp to observed value
                val = observations[var_name]
                env[var_name] = val

            else:
                # sample from the morphism
                val = morph.rsample(inp)
                env[var_name] = val

            lp = morph.log_prob(inp, val)
            total_lp = total_lp + lp

            tr.sites[var_name] = SampleSite(
                name=var_name,
                morphism=morph,
                value=val,
                log_prob=lp,
                is_observed=is_obs,
            )

        else:
            # destructuring step
            # check which destructured vars are observed
            if all(v in observations for v in spec.vars):
                # fully observed: clamp every var, no sampling needed
                for v in spec.vars:
                    env[v] = observations[v]

            else:
                # sample the whole step, then clamp any observed vars so
                # their observed values are not overwritten by the sample
                result = morph.rsample(inp)
                program._bind_result(spec, result, env)

                for v in spec.vars:
                    if v in observations:
                        env[v] = observations[v]

            # compute log-prob for the full step
            if hasattr(morph, "log_joint") and hasattr(morph, "_return_vars"):
                # sub-program: reconstruct intermediates
                sub_morph = cast(MonadicProgram, morph)
                sub_intermediates = {}

                for sub_spec in sub_morph._step_specs:
                    if isinstance(sub_spec, _LetSpec):
                        continue

                    for sv in sub_spec.vars:
                        if sv in env:
                            sub_intermediates[sv] = env[sv]

                lp = sub_morph.log_joint(inp, sub_intermediates)

            else:
                # product morphism: stack and evaluate
                parts = [env[v] for v in spec.vars]
                stacked = program._stack_tensors(parts)
                lp = morph.log_prob(inp, stacked)

            total_lp = total_lp + lp

            # record each destructured variable as a site
            for v in spec.vars:
                tr.sites[v] = SampleSite(
                    name=v,
                    morphism=morph,
                    value=env[v],
                    log_prob=lp / len(spec.vars),  # split log-prob evenly
                    is_observed=v in observations,
                )

    # compute output
    if program._return_is_single:
        tr.output = env[program._return_vars[0]]

    else:
        keys = (
            program._return_labels if program._return_labels else program._return_vars
        )
        tr.output = {k: env[v] for k, v in zip(keys, program._return_vars)}

    tr.log_joint = total_lp
    return tr
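For a destructuring step, trace computes one joint log-density and records lp / len(spec.vars) at each destructured variable's site. A small sketch with plain lists shows the consequence. The helper name split_step_log_prob is hypothetical. The per-site shares sum back to the step's joint log-density, so summing log_prob over a trace's stochastic sites still matches log_joint. An individual share, however, is not a marginal density of that variable.

```python
def split_step_log_prob(step_lp: list[float],
                        var_names: list[str]) -> dict[str, list[float]]:
    """Give each destructured var an equal share of the step's log-prob,
    per batch element."""
    n = len(var_names)
    return {v: [lp / n for lp in step_lp] for v in var_names}


step_lp = [-3.0, -4.5]  # joint log-prob of one destructuring step, batch of 2
shares = split_step_log_prob(step_lp, ["a", "b", "c"])

# summing the shares per batch element recovers the joint log-prob
recovered = [sum(col) for col in zip(*shares.values())]
print(recovered)  # [-3.0, -4.5]
```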