Deep Markov Model

Overview

The deep Markov model of Krishnan, Shalit, and Sontag (2017) is a state-space model with nonlinear, neural-network-parameterized transition and emission kernels:

\[ s_t = f_\theta(s_{t-1}) + \varepsilon_t, \quad \varepsilon_t \sim \mathcal{N}(0, \sigma_s^2 I) \]

\[ o_t = g_\phi(s_t) + \eta_t, \quad \eta_t \sim \mathcal{N}(0, \sigma_o^2 I) \]

The transition and emission means are MLPs, and the additive per-step Normal noise keeps every conditional density tractable. The QVR source below additionally feeds an exogenous driver u_t into the transition. The companion recognition network q_\psi(o_t, s_{t-1}) -> s_t carries the variational posterior; scan threads it across the sequence, amortizing inference over the whole latent trajectory. The combinator surface mirrors the linear-Gaussian SSM: only the per-step cells change.
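
To make "tractable density" concrete, the log-joint of a trajectory is a sum of per-step Normal log-densities. The following is a minimal plain-Python sketch for a scalar state, not the quivers implementation; the helper names and the stand-in mean functions f and g are assumptions chosen for illustration.

```python
import math

def normal_logpdf(x, mean, sigma):
    """Log-density of a scalar Normal(mean, sigma^2) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi * sigma**2) - 0.5 * ((x - mean) / sigma) ** 2

def trajectory_log_joint(s0, states, obs, f, g, sigma_s, sigma_o):
    """Sum the per-step transition and emission log-densities.

    states[t] and obs[t] hold s_{t+1} and o_{t+1}; f and g are
    hypothetical transition and emission mean functions.
    """
    total, prev = 0.0, s0
    for s_t, o_t in zip(states, obs):
        total += normal_logpdf(s_t, f(prev), sigma_s)  # log p(s_t | s_{t-1})
        total += normal_logpdf(o_t, g(s_t), sigma_o)   # log p(o_t | s_t)
        prev = s_t
    return total

# Scalar stand-ins for the MLP means (assumptions, for illustration only).
f = math.tanh
g = lambda s: 2.0 * s

# Every residual is zero here, so each of the four terms is -0.5 * log(2 * pi).
log_p = trajectory_log_joint(0.0, [0.0, 0.0], [0.0, 0.0], f, g, 1.0, 1.0)
print(log_p)
```

In the vector case each term becomes a sum over coordinates, because the noise covariance is a multiple of the identity.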

QVR Source

object Driver : Real 4
object Hidden : Real 32
object State : Real 8
object Obs : Real 4

morphism trans_mlp_1 : Driver * State -> Hidden [role=kernel, scale=0.5] ~ Normal
morphism trans_mlp_2 : Hidden -> State [role=kernel, scale=0.1] ~ Normal
morphism emit_mlp_1 : State -> Hidden [role=kernel, scale=0.5] ~ Normal
morphism emit_mlp_2 : Hidden -> Obs [role=kernel, scale=0.1] ~ Normal
morphism infer_cell : Obs * State -> State [role=kernel, scale=0.1] ~ Normal

let transition_cell = trans_mlp_1 >> trans_mlp_2
let emission = emit_mlp_1 >> emit_mlp_2
let generate = scan(transition_cell) >> emission
let recognize = scan(infer_cell)

export recognize

Walkthrough

The transition stack trans_mlp_1 >> trans_mlp_2 is a two-layer MLP that maps (u_t, s_{t-1}) through a hidden width of 32 down to the 8-d state; the emission stack emit_mlp_1 >> emit_mlp_2 is the mirror-image decoder from the 8-d state through the same hidden width to the 4-d observation. Both stacks are Kleisli compositions of Gaussian kernels. Each layer is Gaussian conditional on its input, but the hidden sample is marginalized out, so the composite per-step kernel is reparameterisable without being Gaussian in general.

scan(transition_cell) >> emission is the generative pipeline; scan(infer_cell) is the variational-autoencoder-style recognition network, which threads the previous belief and the new observation through infer_cell to produce the next belief. The width of Driver sets the dimension of the exogenous input u_t; a non-driven model can declare Driver : Real 1 and feed a zero vector.
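
The threading that scan performs can be sketched as an ordinary fold that carries a state forward. This is a plain-Python illustration of the idea only: the scan function and toy_infer_cell below are hypothetical stand-ins, not the quivers combinator or the learned cell, and the toy cell is deterministic where the real one returns a Normal sample.

```python
def scan(cell, init, inputs):
    """Thread `cell` across `inputs`, carrying the state forward.

    cell(x_t, state_prev) -> state_t; returns the list of all states.
    """
    states, state = [], init
    for x_t in inputs:
        state = cell(x_t, state)
        states.append(state)
    return states

def toy_infer_cell(o_t, belief_prev):
    """Hypothetical cell: blend the previous belief with the new observation."""
    return 0.5 * belief_prev + 0.5 * o_t

beliefs = scan(toy_infer_cell, 0.0, [2.0, 2.0, 2.0])
print(beliefs)  # [1.0, 1.5, 1.75]
```

Each belief depends on the whole observation prefix through the carried state, which is what lets a single cell amortize inference over sequences of any length.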

Try it

The SVI step counts and NUTS warmup, sample, and chain budgets in the snippets below are illustrative: each block is sized to run in tens of seconds and demonstrate the API surface. Production fits typically need 10x to 100x more SVI steps, longer NUTS warmup, and multiple chains to actually converge to the data-generating parameters.

Generating synthetic data

Pick concrete ground-truth nonlinear dynamics (tanh recurrence on the latent, tanh decoder to observations) and forward-sample a single trajectory of length T. The latent dimension matches State = Real 8; the observation dimension matches Obs = Real 4.

import torch
from quivers.dsl import load

torch.manual_seed(0)
prog = load("docs/examples/source/deep_markov.qvr")
recognize = prog.morphism

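T = 32
state_dim, obs_dim = 8, 4  # match State = Real 8 and Obs = Real 4
# Fixed ground-truth weights for the tanh recurrence and the tanh decoder.
W_s = 0.5 * torch.randn(state_dim, state_dim)
W_o = 0.3 * torch.randn(obs_dim, state_dim)
s = torch.zeros(T + 1, state_dim)  # s[0] is the zero initial state
o = torch.zeros(T, obs_dim)
for t in range(T):
    # Latent step, then emission, each with additive Normal noise of scale 0.1.
    s[t + 1] = torch.tanh(s[t] @ W_s.T) + 0.1 * torch.randn(state_dim)
    o[t] = torch.tanh(s[t + 1] @ W_o.T) + 0.1 * torch.randn(obs_dim)
# Add a leading batch dimension of size 1.
o_seq = o.unsqueeze(0)          # shape (1, T, obs_dim)
state_seq = s[1:].unsqueeze(0)  # shape (1, T, state_dim)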

SVI fit

The exported recognize is a ScanMorphism whose MLP weights are [role=kernel] parameters without explicit priors; bayesian_lift_parameters lifts each leaf into a unit-Normal sample site so AutoNormalGuide can build a mean-field surrogate. The thin DictWrap adapter exposes log_joint(x, obs_dict) over the scan's positional state-trajectory argument.

from quivers.inference import AutoNormalGuide, ELBO, SVI, bayesian_lift_parameters

torch.manual_seed(1)
prog = load("docs/examples/source/deep_markov.qvr")
inner = prog.morphism
model, x_lift, obs_lift = bayesian_lift_parameters(
    inner, o_seq, {"h": state_seq}, prior_scale=1.0,
)

guide = AutoNormalGuide(model, observed_names={"h"})
optim = torch.optim.Adam(
    list(model.parameters()) + list(guide.parameters()), lr=1e-2,
)
svi = SVI(model, guide, optim, ELBO())

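loss0 = svi.step(x_lift, obs_lift)
losses = [svi.step(x_lift, obs_lift) for _ in range(300)]
# Average the last 20 steps to smooth step-to-step noise in the loss.
loss_final = sum(losses[-20:]) / 20.0
# Reference value: log-joint of the un-lifted model at the true latent trajectory.
oracle_ll = inner.log_joint(o_seq, state_seq).item()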
print(f"initial ELBO loss: {loss0:.1f}")
print(f"final ELBO loss:   {loss_final:.1f}")
print(f"oracle -log p(h):  {-oracle_ll:.1f}")
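
To show what a mean-field Normal guide optimizes, here is a worked closed-form ELBO for a one-parameter conjugate model, written in plain Python. It is a sketch under stated assumptions (prior w ~ N(0, 1), likelihood x | w ~ N(w, 1), guide q(w) = N(m, s^2)); the function names are hypothetical and nothing here calls the quivers API.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)

def elbo(x, m, s):
    """Closed-form ELBO for prior w ~ N(0, 1), likelihood x | w ~ N(w, 1),
    and mean-field guide q(w) = N(m, s^2)."""
    exp_log_lik = -0.5 * LOG_2PI - 0.5 * ((x - m) ** 2 + s**2)  # E_q[log p(x | w)]
    exp_log_prior = -0.5 * LOG_2PI - 0.5 * (m**2 + s**2)        # E_q[log p(w)]
    entropy = 0.5 * (LOG_2PI + 1.0) + math.log(s)               # H[q]
    return exp_log_lik + exp_log_prior + entropy

def log_evidence(x):
    """Exact log p(x): marginally x ~ N(0, 2)."""
    return -0.5 * math.log(2.0 * math.pi * 2.0) - x**2 / 4.0

x = 1.0
at_prior = elbo(x, m=0.0, s=1.0)                     # guide initialised at the prior
at_posterior = elbo(x, m=x / 2.0, s=math.sqrt(0.5))  # exact posterior N(x/2, 1/2)

print(f"ELBO at prior:     {at_prior:.4f}")
print(f"ELBO at posterior: {at_posterior:.4f}")
print(f"log evidence:      {log_evidence(x):.4f}")
```

The ELBO never exceeds the log evidence and meets it only when the guide equals the exact posterior. In the deep Markov model the posterior over weights is not Normal, so a mean-field guide leaves a gap however long it trains.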

NUTS posterior

The lifted model carries one Normal sample site per leaf parameter; NUTSKernel samples them directly.

from quivers.inference import MCMC, NUTSKernel, bayesian_lift_parameters

torch.manual_seed(2)
prog = load("docs/examples/source/deep_markov.qvr")
model, x_lift, obs_lift = bayesian_lift_parameters(
    prog.morphism, o_seq, {"h": state_seq}, prior_scale=1.0,
)

kernel = NUTSKernel(step_size=0.05, max_tree_depth=3, target_accept=0.8)
mc = MCMC(kernel, num_warmup=15, num_samples=15, num_chains=1)
result = mc.run(model, x_lift, obs_lift)

print("acceptance:", float(result.acceptance_rates.mean()))
print("divergences:", int(result.divergence_counts.sum()))
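
For intuition about step_size and the reported acceptance rate, the sketch below runs the leapfrog integrator that NUTS-style samplers build on, against a standard Normal target in plain Python. It is an illustration under assumptions, not the NUTSKernel internals: the function names are hypothetical, and the choice of 8 steps assumes the usual doubling scheme, under which max_tree_depth=3 bounds a trajectory at roughly 2^3 leapfrog steps.

```python
def leapfrog(q, p, grad_u, step_size, num_steps):
    """Leapfrog-integrate Hamiltonian dynamics for a potential with gradient grad_u."""
    p = p - 0.5 * step_size * grad_u(q)    # initial half step for momentum
    for i in range(num_steps):
        q = q + step_size * p              # full step for position
        if i < num_steps - 1:
            p = p - step_size * grad_u(q)  # full step for momentum
    p = p - 0.5 * step_size * grad_u(q)    # final half step for momentum
    return q, p

def energy(q, p):
    """Hamiltonian for a standard Normal target: U(q) = q^2 / 2, K(p) = p^2 / 2."""
    return 0.5 * q * q + 0.5 * p * p

grad_u = lambda q: q  # gradient of U(q) = q^2 / 2
q0, p0 = 1.0, 0.5
q1, p1 = leapfrog(q0, p0, grad_u, step_size=0.05, num_steps=8)
energy_error = abs(energy(q1, p1) - energy(q0, p0))
print(f"energy error after 8 steps: {energy_error:.2e}")
```

A small energy error means a Metropolis-style acceptance probability close to 1; a larger step_size moves further per gradient evaluation but raises the error, and divergences correspond to the error blowing up.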

Categorical Perspective

The transition stack is the Kleisli composition of two Gaussian kernels. Because the second kernel's mean is a nonlinear function of the sample drawn from the first, the composite per-step kernel is no longer Gaussian; it is only a reparameterisable density. scan realizes the iterated Kleisli composition over the time index, so the full trajectory kernel is the right Kan extension of the per-step cell along the time projection.
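
The loss of Gaussianity under composition can be checked numerically. The plain-Python sketch below composes two scalar Gaussian kernels and compares the sample skewness for a linear mean against a ReLU mean; the ReLU, the noise scales, and the helper names are assumptions picked for illustration, not the learned cells.

```python
import random

def skewness(xs):
    """Sample skewness: third central moment over variance^(3/2)."""
    n = len(xs)
    mean = sum(xs) / n
    m2 = sum((x - mean) ** 2 for x in xs) / n
    m3 = sum((x - mean) ** 3 for x in xs) / n
    return m3 / m2**1.5

def compose(mean_fn, n, rng, sigma1=1.0, sigma2=0.1):
    """Draw n samples of z ~ N(mean_fn(y), sigma2^2) with y ~ N(0, sigma1^2)."""
    return [rng.gauss(mean_fn(rng.gauss(0.0, sigma1)), sigma2) for _ in range(n)]

rng = random.Random(0)
linear = compose(lambda y: 2.0 * y, 20000, rng)   # linear mean: marginal stays Gaussian
relu = compose(lambda y: max(0.0, y), 20000, rng) # nonlinear mean: marginal is skewed

print(f"skewness, linear mean: {skewness(linear):+.2f}")
print(f"skewness, ReLU mean:   {skewness(relu):+.2f}")
```

A Gaussian has zero skewness, so the clearly positive value in the ReLU case shows the composite is not Gaussian, even though both samples are still produced by reparameterised draws.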

The recognizer is a directed inverse of the generative kernel: where the prior is a forward chain s_0 -> s_1 -> ... -> s_T -> o_{1:T}, the recognizer is the encoder side of an amortized variational posterior, mapping observations back to latent states. The two share the latent space State but are Kleisli morphisms pointing in opposite directions; SVI tunes them jointly against an ELBO.

flowchart LR
    s__t_1_["s_{t-1}"] --> trans_mlp_1["trans_mlp_1"]
    u_t["u_t"] --> trans_mlp_1["trans_mlp_1"]
    trans_mlp_1["trans_mlp_1"] --> h_trans["h_trans"]
    h_trans["h_trans"] --> trans_mlp_2["trans_mlp_2"]
    trans_mlp_2["trans_mlp_2"] --> s_t["s_t"]
    s_t["s_t"] --> emit_mlp_1["emit_mlp_1"]
    emit_mlp_1["emit_mlp_1"] --> h_emit["h_emit"]
    h_emit["h_emit"] --> emit_mlp_2["emit_mlp_2"]
    emit_mlp_2["emit_mlp_2"] --> o_t["o_t"]

References

  • Diederik P. Kingma and Max Welling. 2013. Auto-Encoding Variational Bayes. arXiv preprint arXiv:1312.6114.
  • Rahul G. Krishnan, Uri Shalit, and David Sontag. 2017. Structured inference networks for nonlinear state space models. In Proceedings of the Thirty-First AAAI Conference on Artificial Intelligence (AAAI '17), pages 2101–2109. AAAI Press.