Flows

Normalizing flows and flow-based transformations.

flows

Normalizing flows as continuous morphisms.

A normalizing flow defines a bijective map between a simple base distribution (here a standard normal) and a complex target distribution. The map is a composition of invertible layers with tractable Jacobian determinants, enabling exact log-density computation via the change-of-variables formula.

For conditional flows, each layer's parameters depend on the conditioning input x, making the flow a ContinuousMorphism:

p(y | x) = p_base(f^{-1}(y; x)) * |det df^{-1}/dy|
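
A minimal sketch (not part of this module) illustrating the change-of-variables formula with a 1-d affine bijection y = a*z + b, dropping the conditioning on x for simplicity:

import math
import torch

a, b = torch.tensor(2.0), torch.tensor(1.0)
y = torch.tensor(3.0)

z = (y - b) / a                                        # f^{-1}(y) = (y - b) / a
log_base = -0.5 * z**2 - 0.5 * math.log(2 * math.pi)   # log N(z; 0, 1)
log_det = -torch.log(a.abs())                          # log |det df^{-1}/dy| = -log |a|
log_p_y = log_base + log_det                           # exact log p(y)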

This module provides:

AffineCouplingLayer — single invertible affine coupling layer
ConditionalFlow     — stack of coupling layers as a ContinuousMorphism

AffineCouplingLayer

AffineCouplingLayer(domain: AnySpace, dim: int, mask_even: bool = True, hidden_dim: int = 64)

Bases: Module

Single affine coupling layer (RealNVP-style).

Splits the input z into two halves (z_a, z_b). One half passes through unchanged while the other is affinely transformed based on the first half (and the conditioning variable x):

z_a' = z_a                               (unchanged)
z_b' = z_b * exp(s(x, z_a)) + t(x, z_a) (transformed)

The Jacobian is triangular, so its log-determinant is simply sum(s(x, z_a)).
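
A small numeric sketch, with hypothetical s and t values standing in for the layer's learned networks, showing the transform, its inverse, and why the log-determinant reduces to a sum of scales:

import torch

z_a = torch.tensor([0.5, -1.0])     # half passed through unchanged
z_b = torch.tensor([2.0, 0.3])      # half to be transformed
s = torch.tensor([0.1, -0.2])       # stand-in for the scale net output s(x, z_a)
t = torch.tensor([1.0, 0.0])        # stand-in for the shift net output t(x, z_a)

z_b_out = z_b * s.exp() + t         # forward affine transform
log_det = s.sum()                   # triangular Jacobian: log-det is just sum(s)

z_b_rec = (z_b_out - t) * (-s).exp()   # inverse recovers z_b exactly
assert torch.allclose(z_b_rec, z_b)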

PARAMETER DESCRIPTION
domain

Conditioning space.

TYPE: AnySpace

dim

Total dimensionality of z.

TYPE: int

mask_even

If True, z_a = even indices, z_b = odd indices. If False, reversed.

TYPE: bool DEFAULT: True

hidden_dim

Hidden layer width for the scale/shift network.

TYPE: int DEFAULT: 64

Source code in src/quivers/continuous/flows.py
def __init__(
    self,
    domain: AnySpace,
    dim: int,
    mask_even: bool = True,
    hidden_dim: int = 64,
) -> None:
    super().__init__()
    self._dim = dim
    self._mask_even = mask_even

    # determine split sizes
    if mask_even:
        self._fixed_idx = torch.arange(0, dim, 2)
        self._transform_idx = torch.arange(1, dim, 2)
    else:
        self._fixed_idx = torch.arange(1, dim, 2)
        self._transform_idx = torch.arange(0, dim, 2)

    n_fixed = len(self._fixed_idx)
    n_transform = len(self._transform_idx)

    self.net = _ConditionedNet(
        domain,
        n_fixed,
        n_transform,
        hidden_dim,
    )

forward

forward(x: Tensor, z: Tensor) -> tuple[Tensor, Tensor]

Forward pass: base -> target.

PARAMETER DESCRIPTION
x

Conditioning input.

TYPE: Tensor

z

Input vector. Shape (batch, dim).

TYPE: Tensor

RETURNS DESCRIPTION
z_out

Transformed vector. Shape (batch, dim).

TYPE: Tensor

log_det

Log-determinant of the Jacobian. Shape (batch,).

TYPE: Tensor

Source code in src/quivers/continuous/flows.py
def forward(
    self,
    x: torch.Tensor,
    z: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward pass: base -> target.

    Parameters
    ----------
    x : torch.Tensor
        Conditioning input.
    z : torch.Tensor
        Input vector. Shape (batch, dim).

    Returns
    -------
    z_out : torch.Tensor
        Transformed vector. Shape (batch, dim).
    log_det : torch.Tensor
        Log-determinant of the Jacobian. Shape (batch,).
    """
    z_fixed = z[..., self._fixed_idx]
    z_transform = z[..., self._transform_idx]

    shift, log_scale = self.net(x, z_fixed)
    z_transformed = z_transform * log_scale.exp() + shift

    z_out = torch.empty_like(z)
    z_out[..., self._fixed_idx] = z_fixed
    z_out[..., self._transform_idx] = z_transformed

    log_det = log_scale.sum(dim=-1)
    return z_out, log_det
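
A hedged usage sketch for the forward pass. It assumes the same FinSet conditioning space used in the ConditionalFlow example further below, and an import path mirroring the source file location; the exact conditioning format is whatever _ConditionedNet accepts.

import torch
from quivers import FinSet
from quivers.continuous.flows import AffineCouplingLayer  # path assumed from the source location above

A = FinSet(name="context", cardinality=10)
layer = AffineCouplingLayer(A, dim=4)

x = torch.tensor([0, 1, 2])     # conditioning inputs, batch of 3
z = torch.randn(3, 4)           # base samples, shape (batch, dim)
z_out, log_det = layer(x, z)    # z_out: shape (3, 4); log_det: shape (3,)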

inverse

inverse(x: Tensor, z_out: Tensor) -> tuple[Tensor, Tensor]

Inverse pass: target -> base.

PARAMETER DESCRIPTION
x

Conditioning input.

TYPE: Tensor

z_out

Transformed vector. Shape (batch, dim).

TYPE: Tensor

RETURNS DESCRIPTION
z

Original vector. Shape (batch, dim).

TYPE: Tensor

log_det

Log-determinant (negative of forward). Shape (batch,).

TYPE: Tensor

Source code in src/quivers/continuous/flows.py
def inverse(
    self,
    x: torch.Tensor,
    z_out: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Inverse pass: target -> base.

    Parameters
    ----------
    x : torch.Tensor
        Conditioning input.
    z_out : torch.Tensor
        Transformed vector. Shape (batch, dim).

    Returns
    -------
    z : torch.Tensor
        Original vector. Shape (batch, dim).
    log_det : torch.Tensor
        Log-determinant (negative of forward). Shape (batch,).
    """
    z_fixed = z_out[..., self._fixed_idx]
    z_transformed = z_out[..., self._transform_idx]

    shift, log_scale = self.net(x, z_fixed)
    z_original = (z_transformed - shift) * (-log_scale).exp()

    z = torch.empty_like(z_out)
    z[..., self._fixed_idx] = z_fixed
    z[..., self._transform_idx] = z_original

    log_det = -log_scale.sum(dim=-1)
    return z, log_det
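
A quick invertibility check, reusing layer, x, and z from the forward sketch above: applying inverse after forward should recover z, and the two log-determinants should cancel.

z_out, log_det_fwd = layer(x, z)
z_rec, log_det_inv = layer.inverse(x, z_out)

assert torch.allclose(z_rec, z, atol=1e-6)        # exact reconstruction
assert torch.allclose(log_det_fwd, -log_det_inv)  # log-determinants cancel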

ConditionalFlow

ConditionalFlow(domain: AnySpace, codomain: Euclidean, n_layers: int = 4, hidden_dim: int = 64)

Bases: ContinuousMorphism

Conditional normalizing flow as a continuous morphism.

Stacks multiple affine coupling layers to form a flexible invertible transformation from a standard normal base to the target distribution, conditioned on input x.

The flow supports exact log-density computation:

log p(y | x) = log N(f^{-1}(y; x); 0, I)
               + sum_k log |det df_k^{-1}/dz_k|

And efficient sampling:

z ~ N(0, I)
y = f_K(... f_2(f_1(z; x); x) ...; x)

PARAMETER DESCRIPTION
domain

Conditioning space.

TYPE: AnySpace (SetObject or ContinuousSpace)

codomain

Target continuous space.

TYPE: Euclidean

n_layers

Number of coupling layers. More layers yield a more expressive transformation.

TYPE: int DEFAULT: 4

hidden_dim

Hidden layer width for scale/shift networks.

TYPE: int DEFAULT: 64

Examples:

>>> import torch
>>> from quivers import FinSet
>>> from quivers.continuous.spaces import Euclidean
>>> A = FinSet(name="context", cardinality=10)
>>> Y = Euclidean(name="output", dim=4)
>>> flow = ConditionalFlow(A, Y, n_layers=6)
>>> x = torch.tensor([0, 1, 2])
>>> samples = flow.rsample(x)  # shape (3, 4)
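
A hedged training sketch extending the example above: fit the flow by maximizing exact log-likelihood. It assumes ContinuousMorphism is an nn.Module, so flow.parameters() is available, and that data_loader is a hypothetical iterable of (x, y) batches.

import torch

opt = torch.optim.Adam(flow.parameters(), lr=1e-3)

for x_batch, y_batch in data_loader:                # hypothetical (x, y) batches
    loss = -flow.log_prob(x_batch, y_batch).mean()  # negative log-likelihood
    opt.zero_grad()
    loss.backward()
    opt.step()
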
Source code in src/quivers/continuous/flows.py
def __init__(
    self,
    domain: AnySpace,
    codomain: Euclidean,
    n_layers: int = 4,
    hidden_dim: int = 64,
) -> None:
    super().__init__(domain, codomain)
    d = codomain.dim

    if d < 2:
        raise ValueError(
            f"ConditionalFlow requires codomain dim >= 2, got {d}. "
            "Use ConditionalNormal for 1-d targets."
        )

    self.layers = nn.ModuleList()

    for i in range(n_layers):
        self.layers.append(
            AffineCouplingLayer(
                domain,
                d,
                mask_even=(i % 2 == 0),
                hidden_dim=hidden_dim,
            )
        )

    self._d = d

log_prob

log_prob(x: Tensor, y: Tensor) -> Tensor

Exact log-density via change of variables.

PARAMETER DESCRIPTION
x

Conditioning inputs.

TYPE: Tensor

y

Target values. Shape (batch, d).

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Log-densities. Shape (batch,).

Source code in src/quivers/continuous/flows.py
def log_prob(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Exact log-density via change of variables.

    Parameters
    ----------
    x : torch.Tensor
        Conditioning inputs.
    y : torch.Tensor
        Target values. Shape (batch, d).

    Returns
    -------
    torch.Tensor
        Log-densities. Shape (batch,).
    """
    z = y
    total_log_det = torch.zeros(z.shape[0], device=z.device)

    # inverse pass through layers (reverse order)
    for layer in reversed(self.layers):
        z, log_det = cast(AffineCouplingLayer, layer).inverse(x, z)
        total_log_det = total_log_det + log_det

    # base distribution log-density (standard normal)
    log_base = -0.5 * z.pow(2).sum(dim=-1) - 0.5 * self._d * math.log(2 * math.pi)

    return log_base + total_log_det
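
A hedged usage sketch for evaluating exact log-densities, using the same setup as the class-level example; the observed targets here are random placeholders.

import torch
from quivers import FinSet
from quivers.continuous.spaces import Euclidean

A = FinSet(name="context", cardinality=10)
Y = Euclidean(name="output", dim=4)
flow = ConditionalFlow(A, Y)

x = torch.tensor([0, 1, 2])
y = torch.randn(3, 4)           # observed targets, shape (batch, d)
log_p = flow.log_prob(x, y)     # per-example log-densities, shape (3,)
nll = -log_p.mean()             # scalar negative log-likelihood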

rsample

rsample(x: Tensor, sample_shape: Size = Size()) -> Tensor

Sample via forward pass through the flow.

PARAMETER DESCRIPTION
x

Conditioning inputs.

TYPE: Tensor

sample_shape

Additional sample dimensions.

TYPE: Size DEFAULT: Size()

RETURNS DESCRIPTION
Tensor

Samples. Shape (*sample_shape, batch, d).

Source code in src/quivers/continuous/flows.py
def rsample(
    self,
    x: torch.Tensor,
    sample_shape: torch.Size = torch.Size(),
) -> torch.Tensor:
    """Sample via forward pass through the flow.

    Parameters
    ----------
    x : torch.Tensor
        Conditioning inputs.
    sample_shape : torch.Size
        Additional sample dimensions.

    Returns
    -------
    torch.Tensor
        Samples. Shape (*sample_shape, batch, d).
    """
    batch = x.shape[0]

    if len(sample_shape) > 0:
        n_extra = int(torch.Size(sample_shape).numel())
        total = n_extra * batch

        # replicate x for all samples
        x_rep = (
            x.unsqueeze(0)
            .expand(
                n_extra,
                *x.shape,
            )
            .reshape(total, *x.shape[1:])
            if x.dim() > 1
            else (x.unsqueeze(0).expand(n_extra, batch).reshape(total))
        )

        z = torch.randn(total, self._d, device=x.device)

        for layer in self.layers:
            z, _ = layer.forward(x_rep, z)

        if z.dim() > 1:
            return z.reshape(*sample_shape, batch, self._d)

        return z.reshape(*sample_shape, batch)

    else:
        z = torch.randn(batch, self._d, device=x.device)

        for layer in self.layers:
            z, _ = layer.forward(x, z)

        return z
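
A hedged sketch of drawing several samples per conditioning input via sample_shape, reusing flow and x from the log_prob sketch above:

samples = flow.rsample(x)                              # shape (3, 4)
many = flow.rsample(x, sample_shape=torch.Size([5]))   # shape (5, 3, 4)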