Variational Guides
Variational guide distributions for approximate inference. The shipped guides (AutoNormalGuide, AutoDeltaGuide, AutoMultivariateNormalGuide, AutoLowRankMultivariateNormalGuide, AutoLaplaceApproximation, AutoNormalizingFlow, AutoIAFGuide, AutoNeuralSplineGuide, AutoMixtureGuide) live as submodules of quivers.inference.guides and share the Guide ABC and the LatentRegistry introspection layer.
guides
Variational guide families.
Public surface (re-exported by the parent quivers.inference
package): one ABC (Guide) plus a zoo of concrete
Auto*Guide subclasses spanning the standard variational-family
ladder from mean-field Normal to normalizing-flow stacks and
hierarchical / mixture / structured guides.
Every concrete guide is built against a single
quivers.inference.registry.LatentRegistry and obeys the
shape contract documented on Guide.
Guide
Bases: Module, ABC
Abstract variational guide.
Subclasses MUST implement rsample and log_prob
and expose latent_names. They MAY override
registry if they construct their registry lazily, but
the default implementation expects self._registry to be
set in __init__.
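The contract above can be pictured with a small stand-in. The following is only a sketch: `ToyGuide` and `ToyNormalGuide` are hypothetical names, plain floats stand in for tensors, and the `Module` base and the registry are left out entirely, so this is not the quivers `Guide` class itself.

```python
from abc import ABC, abstractmethod
import math
import random


class ToyGuide(ABC):
    """Hypothetical stand-in for the Guide contract; not the quivers class."""

    @property
    @abstractmethod
    def latent_names(self) -> list[str]:
        """Names of the latent sites this guide covers."""

    @abstractmethod
    def rsample(self, x: float) -> dict[str, float]:
        """Draw one value per latent site."""

    @abstractmethod
    def log_prob(self, x: float, sites: dict[str, float]) -> float:
        """Score the supplied site values under the guide."""


class ToyNormalGuide(ToyGuide):
    """Independent Normal per site, with floats standing in for tensors."""

    def __init__(self, names: list[str], scale: float = 0.1, seed: int = 0):
        self._names = list(names)
        self._loc = {name: 0.0 for name in names}
        self._scale = scale
        self._rng = random.Random(seed)

    @property
    def latent_names(self) -> list[str]:
        return list(self._names)

    def rsample(self, x: float) -> dict[str, float]:
        # Reparameterized draw: loc + scale * eps with eps ~ N(0, 1).
        return {
            n: self._loc[n] + self._scale * self._rng.gauss(0.0, 1.0)
            for n in self._names
        }

    def log_prob(self, x: float, sites: dict[str, float]) -> float:
        total = 0.0
        for n in self._names:
            z = (sites[n] - self._loc[n]) / self._scale
            total += -0.5 * z * z - math.log(self._scale) - 0.5 * math.log(2 * math.pi)
        return total


guide = ToyNormalGuide(["mu", "tau"])
draw = guide.rsample(x=0.0)          # one value per latent name
score = guide.log_prob(0.0, draw)    # scalar log-density of that draw
```

Leaving any of the three abstract members unimplemented makes the subclass impossible to instantiate, which is the practical meaning of "MUST implement" here.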
latent_names
abstractmethod
property
latent_names: list[str]
Names of latent variables this guide covers.
build_registry
classmethod
build_registry(model: MonadicProgram, observed_names: set[str] | frozenset[str]) -> LatentRegistry
Convenience wrapper around
LatentRegistry.from_model so guide constructors
can do self._registry = self.build_registry(model, obs)
without an extra import.
Source code in src/quivers/inference/guides/base.py
rsample
abstractmethod
rsample(x: Tensor) -> dict[str, Tensor]
Reparameterized sample from :math:`q_\phi(z \mid x)`.
| PARAMETER | DESCRIPTION |
|---|---|
| x | Program input. Shape … TYPE: Tensor |

| RETURNS | DESCRIPTION |
|---|---|
| dict[str, Tensor] | Per-site constrained samples shaped to match the model's trace-side convention. |
Source code in src/quivers/inference/guides/base.py
log_prob
abstractmethod
log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor
Log density of sites under :math:`q_\phi(z \mid x)`,
with the change-of-variables Jacobian correction baked in.
| RETURNS | DESCRIPTION |
|---|---|
| Tensor | Shape … |
Source code in src/quivers/inference/guides/base.py
AutoDeltaGuide
AutoDeltaGuide(model: MonadicProgram, observed_names: set[str], init_value: float = 0.0)
Bases: Guide
Dirac-delta MAP guide with per-site constrained bijector.
| PARAMETER | DESCRIPTION |
|---|---|
| model | Generative model. TYPE: MonadicProgram |
| observed_names | Variable names treated as observations. TYPE: set[str] |
| init_value | Initial unconstrained-space coordinate for every latent. Default 0.0. TYPE: float |
Source code in src/quivers/inference/guides/delta.py
rsample
rsample(x: Tensor) -> dict[str, Tensor]
Return the learned point estimates in the prior's support.
Source code in src/quivers/inference/guides/delta.py
log_prob
log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor
Delta log-density: zero everywhere (the delta term and its Jacobian cancel in the ELBO), so maximizing the ELBO with this guide reduces to maximizing the log joint at the point estimate.
Source code in src/quivers/inference/guides/delta.py
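To make the point-estimate behaviour concrete, here is a minimal sketch under stated assumptions: `ToyDeltaGuide` is a hypothetical single-site stand-in, the latent is assumed to have positive support with `exp` as its bijector, and the log joint is a toy Gamma kernel chosen so the optimum is known in closed form. It illustrates the idea, not the shipped implementation.

```python
import math


class ToyDeltaGuide:
    """Hypothetical point-estimate guide for one positive-valued latent."""

    def __init__(self, init_value: float = 0.0):
        # The learnable coordinate lives in unconstrained space (all of R).
        self.unconstrained = init_value

    def rsample(self) -> dict[str, float]:
        # exp maps R onto (0, inf), the assumed support of the prior.
        return {"sigma": math.exp(self.unconstrained)}

    def log_prob(self, sites: dict[str, float]) -> float:
        # Delta convention: the guide contributes nothing to the ELBO.
        return 0.0


def log_joint(sigma: float) -> float:
    # Toy unnormalized log joint: Gamma(shape=4, rate=2) kernel, mode at 1.5.
    return 3.0 * math.log(sigma) - 2.0 * sigma


guide = ToyDeltaGuide()
# Gradient ascent on the ELBO, which here is log_joint(exp(u)) - 0.
for _ in range(500):
    sigma = guide.rsample()["sigma"]
    grad_sigma = 3.0 / sigma - 2.0                      # d log_joint / d sigma
    guide.unconstrained += 0.05 * grad_sigma * sigma    # chain rule: d sigma / d u = sigma
map_sigma = guide.rsample()["sigma"]
```

With `init_value=0.0` the first returned estimate is `exp(0) = 1.0`, already inside the support, and the optimizer then moves it to the mode of the toy joint.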
AutoIAFGuide
AutoIAFGuide(model: MonadicProgram, observed_names: set[str], num_flows: int = 4, hidden_dim: int | None = None, num_hidden_layers: int = 2)
Bases: AutoNormalizingFlow
Inverse-autoregressive-flow guide.
Default normalizing-flow guide for variational inference
(Kingma-Salimans-Jozefowicz et al. 2016). Stack of
InverseAutoregressiveTransform layers, each separated
by a reverse permutation so successive layers have different
autoregressive orderings.
Sampling is parallel (one MLP forward pass per layer), whereas evaluating the density of an arbitrary point is sequential (one coordinate at a time per layer). The guide therefore suits objectives that mainly score its own fresh samples (ELBO, IWAE) rather than externally supplied points.
| PARAMETER | DESCRIPTION |
|---|---|
| model | Generative model. TYPE: MonadicProgram |
| observed_names | Variable names treated as observations. TYPE: set[str] |
| num_flows | Number of IAF blocks in the stack. Default 4. TYPE: int |
| hidden_dim | Hidden width of every MADE inside the stack. Default None. TYPE: int \| None |
| num_hidden_layers | Number of hidden layers in each MADE. Default 2. TYPE: int |
Source code in src/quivers/inference/guides/flow.py
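The asymmetry between sampling and scoring in an inverse autoregressive flow can be seen in a few lines. This sketch uses a single hand-written affine layer; `conditioner` is a hypothetical toy standing in for a MADE, and nothing here is taken from the quivers source.

```python
import math


def conditioner(prefix: list[float]) -> tuple[float, float]:
    """Toy autoregressive conditioner (a stand-in for a MADE).

    Returns (shift, log_scale) for the next coordinate from earlier ones only.
    """
    s = sum(prefix)
    return 0.5 * s, 0.1 * math.tanh(s)


def iaf_forward(z: list[float]) -> tuple[list[float], float]:
    """Sampling direction: every coordinate depends only on the known noise z."""
    x, log_det = [], 0.0
    for i in range(len(z)):      # iterations do not depend on one another
        shift, log_scale = conditioner(z[:i])
        x.append(shift + math.exp(log_scale) * z[i])
        log_det += log_scale
    return x, log_det


def iaf_inverse(x: list[float]) -> tuple[list[float], float]:
    """Scoring direction: z[i] needs z[:i], so coordinates resolve one by one."""
    z, log_det = [], 0.0
    for i in range(len(x)):      # iteration i must wait for all iterations < i
        shift, log_scale = conditioner(z)
        z.append((x[i] - shift) * math.exp(-log_scale))
        log_det += log_scale
    return z, log_det


def standard_normal_log_prob(z: list[float]) -> float:
    return sum(-0.5 * v * v - 0.5 * math.log(2 * math.pi) for v in z)


noise = [0.3, -1.2, 0.8]
sample, fwd_log_det = iaf_forward(noise)
# Density of the sample just produced: reuse the noise, no inversion needed.
log_q_cheap = standard_normal_log_prob(noise) - fwd_log_det
# Density of an arbitrary point: the sequential inverse has to run first.
recovered, inv_log_det = iaf_inverse(sample)
log_q_general = standard_normal_log_prob(recovered) - inv_log_det
```

Both routes give the same log-density for a self-generated sample; only the second works for a point that did not come from the flow, and it is the one whose cost grows with the number of coordinates.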
AutoNeuralSplineGuide
AutoNeuralSplineGuide(model: MonadicProgram, observed_names: set[str], num_flows: int = 4, num_bins: int = 8, tail_bound: float = 3.0, hidden_dim: int | None = None, num_hidden_layers: int = 2)
Bases: AutoNormalizingFlow
Neural-spline-flow guide (Durkan-Bekasov-Murray-Papamakarios 2019).
Stack of monotone rational-quadratic spline coupling layers
(NeuralSplineCouplingTransform) with alternating
half-masks. Sharper than IAF for posteriors with bounded
support or sharp modes; comparable runtime.
| PARAMETER | DESCRIPTION |
|---|---|
| model | Generative model. TYPE: MonadicProgram |
| observed_names | Variable names treated as observations. TYPE: set[str] |
| num_flows | Number of coupling layers. Default 4. TYPE: int |
| num_bins | Number of spline bins per coordinate. Default 8. TYPE: int |
| tail_bound | Inputs outside … TYPE: float |
| hidden_dim | Hidden width of the coupling MLPs. Default None. TYPE: int \| None |
| num_hidden_layers | Hidden layers in each coupling MLP. Default 2. TYPE: int |
Source code in src/quivers/inference/guides/flow.py
AutoNormalizingFlow
AutoNormalizingFlow(model: MonadicProgram, observed_names: set[str], transforms: list[TransformModule])
Bases: Guide
Normalising-flow variational guide over the flat latent vector.
| PARAMETER | DESCRIPTION |
|---|---|
| model | Generative model to build a guide for. TYPE: MonadicProgram |
| observed_names | Variable names treated as observations. TYPE: set[str] |
| transforms | Flow stack applied to the standard-Normal base. Each … TYPE: list[TransformModule] |
Source code in src/quivers/inference/guides/flow.py
rsample
rsample(x: Tensor) -> dict[str, Tensor]
One flow draw, unflattened and bijected to constrained space.
Source code in src/quivers/inference/guides/flow.py
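The "unflattened and bijected" step can be sketched as follows. The site layout and bijector table are hypothetical, invented for illustration; the real layout comes from the `LatentRegistry`, whose interface is not shown on this page.

```python
import math

# Hypothetical registry layout: (site name, number of scalars, bijector name).
LAYOUT = [("loc", 2, "identity"), ("scale", 1, "exp"), ("prob", 1, "sigmoid")]

BIJECTORS = {
    "identity": lambda u: u,
    "exp": math.exp,                                   # R -> (0, inf)
    "sigmoid": lambda u: 1.0 / (1.0 + math.exp(-u)),   # R -> (0, 1)
}


def unflatten_and_biject(flat: list[float]) -> dict[str, list[float]]:
    """Slice the flat unconstrained vector per site, then map each slice
    into that site's constrained support."""
    expected = sum(size for _, size, _ in LAYOUT)
    if len(flat) != expected:
        raise ValueError(f"expected {expected} values, got {len(flat)}")
    sites, offset = {}, 0
    for name, size, bijector in LAYOUT:
        chunk = flat[offset:offset + size]
        sites[name] = [BIJECTORS[bijector](u) for u in chunk]
        offset += size
    return sites


flow_output = [0.5, -1.0, 0.0, 0.0]   # stands in for one draw from the flow
sites = unflatten_and_biject(flow_output)
```

The flow itself only ever sees the flat vector; the per-site names and supports are restored at the very end, which is why any stack of transforms can be dropped in without knowing about the model's sites.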
log_prob
log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor
Log-density at the supplied constrained sites.
Source code in src/quivers/inference/guides/flow.py
AutoLaplaceApproximation
AutoLaplaceApproximation(model: MonadicProgram, observed_names: set[str], init_value: float = 0.0)
Bases: Guide
Laplace-approximation guide.
| PARAMETER | DESCRIPTION |
|---|---|
| model | Generative model. TYPE: MonadicProgram |
| observed_names | Variable names treated as observations. TYPE: set[str] |
| init_value | Initial unconstrained-space MAP estimate. Default 0.0. TYPE: float |
Source code in src/quivers/inference/guides/laplace.py
fit_hessian
fit_hessian(model: MonadicProgram, x: Tensor, observations: dict[str, Tensor], *, jitter: float = 0.0001) -> None
Compute and cache the Hessian-derived Cholesky factor.
Solves the eigenproblem of the negative-log-joint Hessian at
the current MAP, projects negative eigenvalues to jitter
(so the resulting Gaussian is always positive-definite), and
stores the matching lower-triangular Cholesky factor of the
inverse Hessian as the posterior scale_tril.
Call this after MAP optimisation has converged. Subsequent
rsample / log_prob calls use
:math:`\mathcal{N}(z^\star, H^{-1})`.
Source code in src/quivers/inference/guides/laplace.py
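The eigenvalue repair described for `fit_hessian` can be worked through by hand on a 2x2 Hessian. This is a sketch only: the helper names are hypothetical, the closed-form 2x2 eigendecomposition replaces whatever solver the library uses, and clamping with `max(eigenvalue, jitter)` is an assumption that also lifts tiny positive eigenvalues.

```python
import math


def eig_sym_2x2(a: float, b: float, c: float):
    """Eigenpairs of the symmetric matrix [[a, b], [b, c]]."""
    mean = 0.5 * (a + c)
    radius = math.hypot(0.5 * (a - c), b)
    theta = 0.5 * math.atan2(2.0 * b, a - c)   # angle of the principal axis
    v_hi = (math.cos(theta), math.sin(theta))
    v_lo = (-math.sin(theta), math.cos(theta))
    return (mean + radius, v_hi), (mean - radius, v_lo)


def laplace_scale_tril(hessian, jitter: float = 1e-4):
    """Lower-triangular Cholesky factor of the repaired inverse Hessian."""
    (a, b), (_, c) = hessian
    covariance = [[0.0, 0.0], [0.0, 0.0]]
    for eigenvalue, vec in eig_sym_2x2(a, b, c):
        # Assumption: clamp from below, so negative eigenvalues become jitter.
        clipped = max(eigenvalue, jitter)
        for i in range(2):
            for j in range(2):
                covariance[i][j] += vec[i] * vec[j] / clipped
    # Closed-form Cholesky factorization of a 2x2 positive-definite matrix.
    l11 = math.sqrt(covariance[0][0])
    l21 = covariance[1][0] / l11
    l22 = math.sqrt(covariance[1][1] - l21 * l21)
    return [[l11, 0.0], [l21, l22]]


pd_hessian = [[2.0, 1.0], [1.0, 3.0]]            # eigenvalues both positive
indefinite_hessian = [[1.0, 2.0], [2.0, 1.0]]    # eigenvalues 3 and -1
scale_tril = laplace_scale_tril(pd_hessian)
repaired = laplace_scale_tril(indefinite_hessian)
```

For the positive-definite input the factor reproduces the exact inverse Hessian. For the indefinite input the direction with the negative eigenvalue receives variance `1 / jitter`, so the resulting Gaussian is valid but very wide along that direction, a sign that the MAP optimisation had not reached a proper mode.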
rsample
rsample(x: Tensor) -> dict[str, Tensor]
Sample from the Laplace posterior, unflatten, and biject.
Source code in src/quivers/inference/guides/laplace.py
log_prob
log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor
Log-density at the supplied constrained sites.
Returns zero before fit_hessian (MAP-phase delta
convention); after fit_hessian returns the Gaussian
log-density plus the per-site bijector Jacobian correction.
Source code in src/quivers/inference/guides/laplace.py
AutoMixtureGuide
AutoMixtureGuide(components: list[Guide], init_temperature: float = 1.0)
Bases: Guide
Finite mixture variational guide.
| PARAMETER | DESCRIPTION |
|---|---|
| components | Component guides. All components must share the same … TYPE: list[Guide] |
| init_temperature | Initial Gumbel-Softmax temperature. Default 1.0. TYPE: float |
Source code in src/quivers/inference/guides/mixture.py
set_temperature
set_temperature(value: float) -> None
Anneal the Gumbel-Softmax temperature.
Source code in src/quivers/inference/guides/mixture.py
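One way to drive `set_temperature` from a training loop is an exponential decay with a floor. The schedule and its constants are assumptions chosen for illustration, and `TemperatureHolder` is a hypothetical stand-in that only mimics the `set_temperature(value)` signature shown above.

```python
class TemperatureHolder:
    """Hypothetical stand-in exposing set_temperature(value)."""

    def __init__(self, init_temperature: float = 1.0):
        self.temperature = init_temperature

    def set_temperature(self, value: float) -> None:
        if value <= 0.0:
            raise ValueError("temperature must be positive")
        self.temperature = value


def annealed_temperature(step: int, start: float = 1.0,
                         floor: float = 0.1, decay: float = 0.99) -> float:
    """Exponential decay from start, never dropping below floor."""
    return max(floor, start * decay ** step)


guide = TemperatureHolder()
history = []
for step in range(0, 500, 100):
    guide.set_temperature(annealed_temperature(step))
    history.append(guide.temperature)
```

The floor keeps the relaxed weights from collapsing to an exact one-hot vector, where gradients through the softmax become very small.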
rsample
rsample(x: Tensor) -> dict[str, Tensor]
Reparameterized mixture draw via Gumbel-Softmax.
Each call samples a Gumbel-Softmax weight vector
:math:`w \in \Delta^{K-1}` and returns
:math:`\sum_k w_k \cdot v^{(k)}` per site, where
:math:`v^{(k)}` is component k's constrained-space
sample. Because the constrained-space sites' supports are
not in general convex (e.g. a Cholesky factor on
torch.distributions.constraints.corr_cholesky), the
soft mixture can drift outside any single component's
support during training; the categorical-pick fallback in
hard_rsample returns a single component's sample
for use at inference time.
Source code in src/quivers/inference/guides/mixture.py
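The soft mixture draw, and the way it can leave a non-convex support, can be sketched in plain Python. The function names are hypothetical and scalars stand in for tensors; the unit-norm example relies on the fact that every row of a correlation Cholesky factor has norm 1.

```python
import math
import random


def gumbel_softmax_weights(logits, temperature, rng):
    """Relaxed categorical sample: softmax((logits + Gumbel noise) / temperature)."""
    gumbels = []
    for _ in logits:
        # Clamp the uniform draw away from 0 and 1 so both logs stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbels.append(-math.log(-math.log(u)))
    scores = [(l + g) / temperature for l, g in zip(logits, gumbels)]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def soft_mixture_draw(component_draws, weights):
    """Per-site convex combination sum_k w_k * v^(k) of the component samples."""
    names = component_draws[0].keys()
    return {n: sum(w * d[n] for w, d in zip(weights, component_draws)) for n in names}


rng = random.Random(0)
component_draws = [{"theta": -2.0}, {"theta": 0.5}, {"theta": 3.0}]
weights = gumbel_softmax_weights([0.0, 0.0, 0.0], temperature=1.0, rng=rng)
mixed = soft_mixture_draw(component_draws, weights)

# Non-convex support: blending two unit-norm rows gives a row that is not unit-norm.
row_a, row_b = (1.0, 0.0), (0.0, 1.0)
blend = tuple(0.5 * a + 0.5 * b for a, b in zip(row_a, row_b))
blend_norm = math.hypot(*blend)
```

For an interval-valued site such as `theta` the blend always stays between the component samples. For the unit-norm rows the blend has norm `sqrt(0.5)`, so it is no longer a valid row of a correlation Cholesky factor.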
hard_rsample
hard_rsample(x: Tensor) -> dict[str, Tensor]
Categorical-pick variant: sample a component index and return that component's draw verbatim. Use at inference time when soft-mixture interpolation would violate a support constraint.
Source code in src/quivers/inference/guides/mixture.py
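The categorical pick amounts to choosing an index and passing that component's draw through untouched. A minimal sketch, with a hypothetical `hard_pick` helper and mixing weights that are assumed to be already normalized:

```python
import random


def hard_pick(component_draws, weights, rng):
    """Sample a component index, then return that component's draw unchanged."""
    index = rng.choices(range(len(component_draws)), weights=weights, k=1)[0]
    return index, component_draws[index]


rng = random.Random(0)
component_draws = [{"theta": -2.0}, {"theta": 0.5}, {"theta": 3.0}]
index, picked = hard_pick(component_draws, [0.2, 0.5, 0.3], rng)
```

Because no interpolation takes place, the returned sites lie in the support of the chosen component by construction.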
log_prob
log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor
Mixture log-density via logsumexp over components.
Source code in src/quivers/inference/guides/mixture.py
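The logsumexp form matters numerically as well as notationally. The sketch below assumes normalized log mixing weights and univariate Normal components; those choices and all names are illustrative, not taken from the library.

```python
import math


def logsumexp(values):
    """Stable log(sum(exp(v))) obtained by factoring out the largest term."""
    top = max(values)
    if top == -math.inf:
        return -math.inf
    return top + math.log(sum(math.exp(v - top) for v in values))


def normal_log_prob(v, loc, scale):
    z = (v - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def mixture_log_prob(v, log_weights, components):
    """log q(v) = logsumexp_k(log w_k + log q_k(v))."""
    return logsumexp([
        lw + normal_log_prob(v, loc, scale)
        for lw, (loc, scale) in zip(log_weights, components)
    ])


log_weights = [math.log(0.3), math.log(0.7)]
components = [(-2.0, 0.5), (3.0, 1.0)]       # (loc, scale) per component
density_at_point = mixture_log_prob(0.7, log_weights, components)
far_tail = mixture_log_prob(200.0, log_weights, components)
```

At `v = 200.0` every component density underflows to zero in linear space, so a naive `log(sum(w_k * q_k))` would evaluate `log(0)`; the logsumexp route still returns a finite value.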
AutoLowRankMultivariateNormalGuide
AutoLowRankMultivariateNormalGuide(model: MonadicProgram, observed_names: set[str], rank: int = 5, init_scale: float = 0.1)
Bases: _MVNCommon
Low-rank-plus-diagonal multivariate-Normal guide.
Covariance :math:`\Sigma = W W^\top + \mathrm{diag}(\sigma^2)`
with W of shape :math:`(D, r)` and
:math:`\sigma \in \mathbb{R}^{D}_{>0}`. Memory :math:`O(Dr)`; sampling and
log-density via Woodbury / matrix-determinant lemma in
torch.distributions.LowRankMultivariateNormal.
Captures the dominant r posterior correlation directions
while remaining tractable for D in the hundreds-to-thousands
range, where full-rank is infeasible.
| PARAMETER | DESCRIPTION |
|---|---|
| model | Generative model to build a guide for. TYPE: MonadicProgram |
| observed_names | Variable names treated as observations. TYPE: set[str] |
| rank | Number of correlated directions. Default 5. TYPE: int |
| init_scale | Initial diagonal scale. TYPE: float |
Source code in src/quivers/inference/guides/multivariate_normal.py
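The Woodbury and matrix-determinant shortcuts mentioned above can be checked by hand in the rank-1 case. This sketch restricts itself to a single column `w` (r = 1) and compares against a dense 2x2 computation; it is a worked illustration, not the code path inside `torch.distributions.LowRankMultivariateNormal`.

```python
import math


def lowrank_log_prob(x, loc, w, sigma):
    """log N(x; loc, w w^T + diag(sigma^2)) for one low-rank column w."""
    d = [s * s for s in sigma]
    diff = [xi - mi for xi, mi in zip(x, loc)]
    # Capacitance 1 + w^T D^-1 w is a scalar when r = 1.
    cap = 1.0 + sum(wi * wi / di for wi, di in zip(w, d))
    # Matrix-determinant lemma: log det(Sigma) = log det(D) + log(cap).
    log_det = sum(math.log(di) for di in d) + math.log(cap)
    # Woodbury: diff^T Sigma^-1 diff = diff^T D^-1 diff - (w^T D^-1 diff)^2 / cap.
    proj = sum(wi * ei / di for wi, ei, di in zip(w, diff, d))
    maha = sum(ei * ei / di for ei, di in zip(diff, d)) - proj * proj / cap
    return -0.5 * (len(x) * math.log(2 * math.pi) + log_det + maha)


def dense_log_prob_2d(x, loc, w, sigma):
    """Reference: build the 2x2 covariance explicitly and invert it."""
    s00 = w[0] * w[0] + sigma[0] ** 2
    s01 = w[0] * w[1]
    s11 = w[1] * w[1] + sigma[1] ** 2
    det = s00 * s11 - s01 * s01
    e0, e1 = x[0] - loc[0], x[1] - loc[1]
    maha = (s11 * e0 * e0 - 2.0 * s01 * e0 * e1 + s00 * e1 * e1) / det
    return -0.5 * (2 * math.log(2 * math.pi) + math.log(det) + maha)


x, loc = [0.3, -0.4], [0.0, 0.1]
w, sigma = [0.8, -0.5], [0.6, 0.9]
fast = lowrank_log_prob(x, loc, w, sigma)
dense = dense_log_prob_2d(x, loc, w, sigma)
```

The low-rank route never forms the D x D covariance: it touches only the diagonal and the factor, which is where the :math:`O(Dr)` memory figure comes from.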
AutoMultivariateNormalGuide
AutoMultivariateNormalGuide(model: MonadicProgram, observed_names: set[str], init_scale: float = 0.1)
Bases: _MVNCommon
Full-rank multivariate-Normal variational guide.
Parameterises a joint Gaussian over the registry's flat unconstrained vector with a learnable lower-triangular Cholesky factor. It captures every pairwise posterior correlation across every latent site, which makes it the right choice when posterior couplings are strong (hierarchical regression with crossed random effects, parameter pairs with multiplicative interaction).
| PARAMETER | DESCRIPTION |
|---|---|
| model | Generative model to build a guide for. TYPE: MonadicProgram |
| observed_names | Variable names treated as observations. TYPE: set[str] |
| init_scale | Initial diagonal of the Cholesky factor. Default 0.1. TYPE: float |
Source code in src/quivers/inference/guides/multivariate_normal.py
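A Cholesky-parameterized Gaussian draw takes only a few lines. This sketch uses nested lists instead of tensors, and the zero off-diagonal initialization is an assumption; the page states only that `init_scale` sets the initial diagonal of the factor.

```python
def mvn_rsample(loc, scale_tril, eps):
    """Reparameterized draw z = loc + L @ eps with L lower-triangular."""
    return [
        loc[i] + sum(scale_tril[i][j] * eps[j] for j in range(i + 1))
        for i in range(len(loc))
    ]


def covariance(scale_tril):
    """Implied covariance L @ L^T."""
    n = len(scale_tril)
    return [
        [sum(scale_tril[i][k] * scale_tril[j][k] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]


def num_free_params(dim):
    """Entries on and below the diagonal of a dim x dim Cholesky factor."""
    return dim * (dim + 1) // 2


init_scale = 0.1
initial_tril = [[init_scale, 0.0], [0.0, init_scale]]   # assumed init: init_scale * I
learned_tril = [[0.1, 0.0], [0.05, 0.1]]                # off-diagonal entry couples the sites
draw = mvn_rsample([0.0, 1.0], learned_tril, eps=[1.0, -1.0])
cov = covariance(learned_tril)
```

The factor has `D * (D + 1) / 2` free entries, so the parameter count grows quadratically with the number of latent coordinates; that growth is the price of modelling every pairwise correlation.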
AutoNormalGuide
AutoNormalGuide(model: MonadicProgram, observed_names: set[str], init_scale: float = 0.1)
Bases: Guide
Mean-field Normal guide with per-site constrained-support bijector.
| PARAMETER | DESCRIPTION |
|---|---|
| model | Generative model to build a guide for. TYPE: MonadicProgram |
| observed_names | Variable names treated as observations (skipped in the guide; their values flow through the conditioning data dict at trace time). TYPE: set[str] |
| init_scale | Initial scale (in unconstrained space) of every latent. Default 0.1. TYPE: float |
Source code in src/quivers/inference/guides/normal.py
rsample
rsample(x: Tensor) -> dict[str, Tensor]
Reparameterized mean-field Normal-then-bijector sample.
Source code in src/quivers/inference/guides/normal.py
log_prob
log_prob(x: Tensor, sites: dict[str, Tensor]) -> Tensor
Pushforward log-density at constrained values sites.
Uses the change-of-variables identity:
log q(v) = log Normal(z; loc, scale) + log|det J_{T^{-1}}(v)|
where T is the site's bijector and z = bijector.inv(v). The plate / scalar shape
dispatch matches rsample's convention.
Source code in src/quivers/inference/guides/normal.py
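The change-of-variables identity can be checked numerically for a single positive-valued site. The sketch assumes T = exp as the bijector and uses a midpoint-rule integral as a sanity check; function names are illustrative.

```python
import math


def normal_log_prob(z, loc, scale):
    u = (z - loc) / scale
    return -0.5 * u * u - math.log(scale) - 0.5 * math.log(2 * math.pi)


def pushforward_log_prob(v, loc, scale):
    """log q(v) for v = exp(z), z ~ Normal(loc, scale), at a positive value v."""
    z = math.log(v)                    # z = T^{-1}(v)
    log_abs_det_inv = -math.log(v)     # |d T^{-1} / dv| = 1 / v
    return normal_log_prob(z, loc, scale) + log_abs_det_inv


# Midpoint-rule check: a correct density integrates to 1 over (0, inf).
dv = 1e-3
grid = [(k + 0.5) * dv for k in range(20000)]      # covers (0, 20)
mass = sum(math.exp(pushforward_log_prob(v, 0.0, 0.5)) * dv for v in grid)
# Dropping the Jacobian term leaves something that is not a density.
mass_without_jacobian = sum(
    math.exp(normal_log_prob(math.log(v), 0.0, 0.5)) * dv for v in grid
)
```

With the Jacobian term the total mass comes out at 1 to within the integration error. Without it the integral equals `exp(scale**2 / 2)`, about 1.13 for `scale = 0.5`, so scores computed that way would be biased.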