Linear Gaussian SSMs
### Install necessary libraries

try:
    import jax
except ImportError:
    # For the CUDA version, see https://github.com/google/jax#installation
    %pip install --upgrade "jax[cpu]"
    import jax

try:
    import distrax
except ImportError:
    %pip install --upgrade distrax
    import distrax

try:
    import jsl
except ImportError:
    %pip install git+https://github.com/probml/jsl
    import jsl

try:
    import rich
except ImportError:
    %pip install rich
    import rich
### Import standard libraries
import abc
from dataclasses import dataclass
import functools
import itertools
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.scipy.special import logit
from jax.nn import softmax
from functools import partial
from jax.random import PRNGKey, split

import inspect
import inspect as py_inspect
import rich
from rich import inspect as r_inspect
from rich import print as r_print


def print_source(fname):
    """Pretty-print the source code of a function or class using rich."""
    r_print(py_inspect.getsource(fname))
Linear Gaussian SSMs
Consider the state space model in (2), where we assume the observations are conditionally independent given the hidden states and inputs (i.e., there are no auto-regressive dependencies between the observables). We can rewrite this model as a stochastic nonlinear dynamical system (NLDS) by defining the next hidden state as a deterministic function of the past state, the input, and random process noise \(\vepsilon_t\):
\[
\hmmhid_t = \ssmDynFn(\hmmhid_{t-1}, \inputs_t, \vepsilon_t)
\]
where \(\vepsilon_t\) is drawn from a distribution chosen so that the induced distribution on \(\hmmhid_t\) matches \(p(\hmmhid_t|\hmmhid_{t-1}, \inputs_t)\). Similarly, we can rewrite the observation distribution as a deterministic function of the hidden state, the input, and observation noise \(\veta_t\):
\[
\hmmobs_t = \ssmObsFn(\hmmhid_t, \inputs_t, \veta_t)
\]
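If we assume additive Gaussian noise, the model becomes
\[
\begin{aligned}
\hmmhid_t &= \ssmDynFn(\hmmhid_{t-1}, \inputs_t) + \vepsilon_t \\
\hmmobs_t &= \ssmObsFn(\hmmhid_t, \inputs_t) + \veta_t
\end{aligned}
\]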
where \(\vepsilon_t \sim \gauss(\vzero,\vQ_t)\) and \(\veta_t \sim \gauss(\vzero,\vR_t)\). We will call these Gaussian SSMs.
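As a minimal sketch of how one might simulate a Gaussian SSM in JAX, the function below draws all noise terms up front and then unrolls the dynamics with `lax.scan`. The names `sample_gaussian_ssm`, `f_dyn`, and `h_obs` are hypothetical, and the sketch assumes (these are our choices, not the text's) a known initial state `z1`, time-invariant positive-definite \(\vQ\) and \(\vR\), and an `inputs` array of shape `(T, input_dim)`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def sample_gaussian_ssm(key, f, h, Q, R, z1, inputs):
    """Sample z_{1:T}, y_{1:T} with z_t = f(z_{t-1}, x_t) + eps_t, y_t = h(z_t, x_t) + eta_t.

    Illustrative assumptions: z1 is given, Q and R are time-invariant and
    positive definite (default Cholesky sampling), inputs has shape (T, input_dim).
    """
    T = inputs.shape[0]
    state_dim, obs_dim = Q.shape[0], R.shape[0]
    key_eps, key_eta = jax.random.split(key)
    # Pre-draw all noise terms: eps_t ~ N(0, Q) for t = 2..T, eta_t ~ N(0, R) for t = 1..T.
    eps = jax.random.multivariate_normal(key_eps, jnp.zeros(state_dim), Q, shape=(T - 1,))
    eta = jax.random.multivariate_normal(key_eta, jnp.zeros(obs_dim), R, shape=(T,))

    def step(z_prev, xs):
        x_t, eps_t = xs
        z_t = f(z_prev, x_t) + eps_t  # deterministic dynamics plus additive process noise
        return z_t, z_t

    _, z_rest = lax.scan(step, z1, (inputs[1:], eps))
    zs = jnp.vstack([z1[None, :], z_rest])
    # Observations: y_t = h(z_t, x_t) + eta_t, vectorized over time with vmap.
    ys = jax.vmap(h)(zs, inputs) + eta
    return zs, ys


# Toy usage with a hypothetical nonlinear f and h (2D state and observations, zero inputs).
f_dyn = lambda z, x: z + 0.1 * jnp.sin(z) + x
h_obs = lambda z, x: jnp.tanh(z)
T = 20
zs, ys = sample_gaussian_ssm(jax.random.PRNGKey(0), f_dyn, h_obs,
                             Q=0.01 * jnp.eye(2), R=0.1 * jnp.eye(2),
                             z1=jnp.zeros(2), inputs=jnp.zeros((T, 2)))
print(zs.shape, ys.shape)  # (20, 2) (20, 2)
```

If \(\vQ\) is singular (as in a noise model that perturbs only some state components), the Cholesky-based sampler would produce NaNs; `jax.random.multivariate_normal` also accepts `method="svd"` for that case.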
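If we additionally assume that the transition function \(\ssmDynFn\) and the observation function \(\ssmObsFn\) are both linear, we can rewrite the model as follows:
\[
\begin{aligned}
p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) &= \gauss(\hmmhid_t|\ldsDyn_t \hmmhid_{t-1} + \ldsDynIn_t \inputs_t, \vQ_t) \\
p(\hmmobs_t|\hmmhid_t,\inputs_t) &= \gauss(\hmmobs_t|\ldsObs_t \hmmhid_t + \ldsObsIn_t \inputs_t, \vR_t)
\end{aligned}
\]
where \(\ldsDyn_t\) is the dynamics matrix, \(\ldsObs_t\) is the observation matrix, and \(\ldsDynIn_t\) and \(\ldsObsIn_t\) map the inputs into the state and observation spaces.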
This is called a linear-Gaussian state space model (LG-SSM), or a linear dynamical system (LDS). We usually assume the parameters are independent of time, in which case the model is said to be time-invariant or homogeneous.
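To make the two Gaussian conditionals concrete, here is a hedged sketch of the joint log-density \(\log p(\vz_{1:T}, \vy_{1:T})\) of a time-invariant LG-SSM, built from `jax.scipy.stats.multivariate_normal.logpdf`. It assumes no inputs and an initial state \(\vz_1 \sim \gauss(\mu_0, \Sigma_0)\); `lgssm_log_joint` is a hypothetical helper name.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def lgssm_log_joint(zs, ys, A, C, Q, R, mu0, Sigma0):
    """log p(z_{1:T}, y_{1:T}) for a time-invariant LG-SSM without inputs.

    Assumes z_1 ~ N(mu0, Sigma0), z_t | z_{t-1} ~ N(A z_{t-1}, Q),
    y_t | z_t ~ N(C z_t, R). zs has shape (T, state_dim), ys has shape (T, obs_dim).
    """
    # Initial-state term: log N(z_1 | mu0, Sigma0).
    lp = multivariate_normal.logpdf(zs[0], mu0, Sigma0)
    # Transition terms: each row of zs[:-1] @ A.T is the mean A z_{t-1}.
    trans_means = zs[:-1] @ A.T
    lp += jnp.sum(multivariate_normal.logpdf(zs[1:], trans_means, Q))
    # Emission terms: each row of zs @ C.T is the mean C z_t.
    obs_means = zs @ C.T
    lp += jnp.sum(multivariate_normal.logpdf(ys, obs_means, R))
    return lp
```

Because the parameters are time-invariant, a single `A`, `C`, `Q`, `R` is shared across all time steps, so the per-step terms can be evaluated in one batched call each.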
Example: tracking a 2d point
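Consider an object moving in \(\real^2\). Let the state be the position and velocity of the object, \(\vz_t =\begin{pmatrix} u_t & v_t & \dot{u}_t & \dot{v}_t \end{pmatrix}\). (We use \(u\) and \(v\) for the two coordinates, to avoid confusion with the state and observation variables.) If we use Euler discretization with step size \(\Delta\), the dynamics become
\[
\underbrace{\begin{pmatrix} u_t \\ v_t \\ \dot{u}_t \\ \dot{v}_t \end{pmatrix}}_{\vz_t}
=
\underbrace{\begin{pmatrix}
1 & 0 & \Delta & 0 \\
0 & 1 & 0 & \Delta \\
0 & 0 & 1 & 0 \\
0 & 0 & 0 & 1
\end{pmatrix}}_{\ldsDyn}
\underbrace{\begin{pmatrix} u_{t-1} \\ v_{t-1} \\ \dot{u}_{t-1} \\ \dot{v}_{t-1} \end{pmatrix}}_{\vz_{t-1}}
+ \vepsilon_t
\]
where \(\vepsilon_t \sim \gauss(\vzero,\vQ)\) is the process noise.
Let us assume that the process noise is a white noise process added to the velocity components of the state, but not to the position. (This is known as a random acceleration model.) We can approximate the resulting process in discrete time by assuming \(\vQ = \diag(0, 0, q, q)\). (See [Sar13] p60 for a more accurate way to convert the continuous-time process to discrete time.) For simplicity, the code below instead uses a small isotropic process noise, \(\vQ = 0.001\,\mathbf{I}\).
Now suppose that at each discrete time point we observe the position, corrupted by Gaussian noise. Thus the observation model becomes
\[
\vy_t =
\underbrace{\begin{pmatrix}
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0
\end{pmatrix}}_{\ldsObs}
\vz_t + \veta_t
\]
where \(\veta_t \sim \gauss(\vzero,\vR)\) is the observation noise. We see that the observation matrix \(\ldsObs\) simply "extracts" the position components of the state vector.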
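The tracking model can be assembled in a few lines. This is a minimal sketch assuming the state ordering \((u, v, \dot{u}, \dot{v})\) used by the `A` and `C` matrices in the code below, together with the diagonal random-acceleration noise \(\vQ = \diag(0, 0, q, q)\); `make_tracking_model` is a hypothetical helper, not part of jsl.

```python
import jax.numpy as jnp


def make_tracking_model(delta=1.0, q=0.001, r=1.0):
    """Hypothetical helper returning (A, Q, C, R) for 2D point tracking.

    Assumes the state ordering z = (u, v, u_dot, v_dot).
    """
    # Euler step: position += delta * velocity; velocity is unchanged.
    A = jnp.array([
        [1.0, 0.0, delta, 0.0],
        [0.0, 1.0, 0.0, delta],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ])
    # Random-acceleration approximation: noise enters only the velocities.
    Q = jnp.diag(jnp.array([0.0, 0.0, q, q]))
    # Observation matrix [[1, 0, 0, 0], [0, 1, 0, 0]] picks out the position (u, v).
    C = jnp.eye(2, 4)
    R = r * jnp.eye(2)
    return A, Q, C, R


A_trk, Q_trk, C_trk, R_trk = make_tracking_model(delta=0.5)
z_demo = jnp.array([8.0, 10.0, 1.0, 0.0])  # at (8, 10), moving in +u at unit speed
print(A_trk @ z_demo)  # noiseless step: u advances by delta * u_dot
print(C_trk @ z_demo)  # observed (noise-free) position
```

Note that this \(\vQ\) is singular, which is exactly the "noise only on the velocities" assumption; sampling from \(\gauss(\vzero, \vQ)\) then needs a method that tolerates singular covariances.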
Suppose we sample a trajectory and a corresponding set of noisy observations from this model, \((\vz_{1:T}, \vy_{1:T}) \sim p(\vz,\vy|\vtheta)\). (We use diagonal observation noise, \(\vR = \diag(\sigma_1^2, \sigma_2^2)\); the code below sets \(\sigma_1 = \sigma_2 = 1\) and \(T = 15\).) The results are shown below.
key = jax.random.PRNGKey(314)
timesteps = 15
delta = 1.0  # Euler step size

# Dynamics matrix for the state z = (u, v, u_dot, v_dot)
A = jnp.array([
    [1, 0, delta, 0],
    [0, 1, 0, delta],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
])

# Observation matrix: extracts the position (u, v)
C = jnp.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0]
])

state_size, _ = A.shape
observation_size, _ = C.shape

# Simplification: small isotropic process noise on all four state components
Q = jnp.eye(state_size) * 0.001
R = jnp.eye(observation_size) * 1.0

# Prior distribution over the initial state
mu0 = jnp.array([8, 10, 1, 0]).astype(float)
Sigma0 = jnp.eye(state_size) * 1.0

from jsl.lds.kalman_filter import LDS, smooth, filter
lds = LDS(A, C, Q, R, mu0, Sigma0)
print(lds)
LDS(A=DeviceArray([[1., 0., 1., 0.],
[0., 1., 0., 1.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]], dtype=float32), C=DeviceArray([[1, 0, 0, 0],
[0, 1, 0, 0]], dtype=int32), Q=DeviceArray([[0.001, 0. , 0. , 0. ],
[0. , 0.001, 0. , 0. ],
[0. , 0. , 0.001, 0. ],
[0. , 0. , 0. , 0.001]], dtype=float32), R=DeviceArray([[1., 0.],
[0., 1.]], dtype=float32), mu=DeviceArray([ 8., 10., 1., 0.], dtype=float32), Sigma=DeviceArray([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]], dtype=float32), state_offset=None, obs_offset=None, nstates=4, nobs=2)
from jsl.demos.plot_utils import plot_ellipse


def plot_tracking_values(observed, filtered, cov_hist, signal_label, ax):
    """Plot observations, an estimated trajectory, and 2-sigma position ellipses."""
    timesteps, _ = observed.shape
    ax.plot(observed[:, 0], observed[:, 1], marker="o", linewidth=0,
            markerfacecolor="none", markeredgewidth=2, markersize=8,
            label="observed", c="tab:green")
    ax.plot(*filtered[:, :2].T, label=signal_label, c="tab:red", marker="x", linewidth=2)
    for t in range(timesteps):
        # Covariance of the position components (u, v) at time t
        covn = cov_hist[t][:2, :2]
        plot_ellipse(covn, filtered[t, :2], ax, n_std=2.0, plot_center=False)
    ax.axis("equal")
    ax.legend()
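# Sample hidden states (z_hist) and noisy observations (x_hist) from the LDS
z_hist, x_hist = lds.sample(key, timesteps)

fig_truth, axs = plt.subplots()
# Noisy observations as hollow circles
axs.plot(x_hist[:, 0], x_hist[:, 1],
         marker="o", linewidth=0, markerfacecolor="none",
         markeredgewidth=2, markersize=8,
         label="observed", c="tab:green")
# True positions (u, v), i.e., the first two state components
axs.plot(z_hist[:, 0], z_hist[:, 1],
         linewidth=2, label="truth",
         marker="s", markersize=8)
axs.legend()
axs.axis("equal")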
[Figure: sampled true trajectory ("truth") and noisy observations ("observed") of the 2D point]
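Because every conditional in the model is linear-Gaussian, the marginal distribution of each state is also Gaussian, and its moments can be propagated forward in closed form: if \(\vz_{t-1} \sim \gauss(\mathbf{m}, \mathbf{P})\), then \(\vz_t \sim \gauss(\mathbf{A}\mathbf{m}, \mathbf{A}\mathbf{P}\mathbf{A}^\top + \vQ)\). The sketch below does this with `lax.scan` for the numbers used above. It assumes \(\vz_1 \sim \gauss(\mu_0, \Sigma_0)\), which may not match the time-indexing convention of jsl's `lds.sample`; `marginal_moments` and the `*_trk` names are hypothetical.

```python
import jax.numpy as jnp
from jax import lax


def marginal_moments(A, Q, mu0, Sigma0, T):
    """Prior marginals N(m_t, P_t) of z_1, ..., z_T in an LG-SSM without inputs.

    Assumes z_1 ~ N(mu0, Sigma0) and time-invariant A, Q.
    Returns means of shape (T, D) and covariances of shape (T, D, D).
    """
    def step(carry, _):
        m, P = carry
        m_next = A @ m            # the mean follows the noiseless dynamics
        P_next = A @ P @ A.T + Q  # covariance is transformed by A, then Q is added
        return (m_next, P_next), (m_next, P_next)

    _, (ms, Ps) = lax.scan(step, (mu0, Sigma0), None, length=T - 1)
    ms = jnp.concatenate([mu0[None], ms])
    Ps = jnp.concatenate([Sigma0[None], Ps])
    return ms, Ps


# Same settings as the tracking code above: delta = 1, Q = 0.001 I, T = 15.
delta_trk = 1.0
A_trk = jnp.array([[1.0, 0.0, delta_trk, 0.0],
                   [0.0, 1.0, 0.0, delta_trk],
                   [0.0, 0.0, 1.0, 0.0],
                   [0.0, 0.0, 0.0, 1.0]])
Q_trk = 0.001 * jnp.eye(4)
mu0_trk = jnp.array([8.0, 10.0, 1.0, 0.0])
Sigma0_trk = jnp.eye(4)
ms_trk, Ps_trk = marginal_moments(A_trk, Q_trk, mu0_trk, Sigma0_trk, T=15)
# Mean position moves one unit in u per step; the position variance keeps growing.
print(ms_trk[-1, :2], jnp.diag(Ps_trk[-1])[:2])
```

These prior moments describe where trajectories sampled from the model tend to lie before any observation is taken into account; conditioning on the observations is the inference problem discussed next.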
The main task is to infer the hidden states given the noisy observations, i.e., to compute \(p(\vz_{1:T}|\vy_{1:T},\vtheta)\). We discuss inference in Inferential goals.