# HMM filtering (forwards algorithm)
### Install necessary libraries
```python
try:
    import jax
except ModuleNotFoundError:
    # For cuda version, see https://github.com/google/jax#installation
    %pip install --upgrade "jax[cpu]"
    import jax

try:
    import distrax
except ModuleNotFoundError:
    %pip install --upgrade distrax
    import distrax

try:
    import jsl
except ModuleNotFoundError:
    %pip install git+https://github.com/probml/jsl
    import jsl

# try:
#     import ssm_jax
# except ModuleNotFoundError:
#     %pip install git+https://github.com/probml/ssm-jax
#     import ssm_jax

try:
    import rich
except ModuleNotFoundError:
    %pip install rich
    import rich
```
### Import standard libraries
```python
import abc
from dataclasses import dataclass
import functools
import itertools
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.scipy.special import logit
from jax.nn import softmax
from functools import partial
from jax.random import PRNGKey, split

import inspect
import inspect as py_inspect
import rich
from rich import inspect as r_inspect
from rich import print as r_print

def print_source(fname):
    r_print(py_inspect.getsource(fname))
```
The Bayes filter is an algorithm for recursively computing the belief state \(p(\hidden_t|\obs_{1:t})\) given the prior belief from the previous step, \(p(\hidden_{t-1}|\obs_{1:t-1})\), the new observation \(\obs_t\), and the model. This can be done using sequential Bayesian updating. For a dynamical model, this reduces to the predict-update cycle described below.
The prediction step is just the Chapman-Kolmogorov equation:

$$
p(\hidden_t|\obs_{1:t-1}) = \int p(\hidden_t|\hidden_{t-1}) p(\hidden_{t-1}|\obs_{1:t-1}) d\hidden_{t-1}
$$

The prediction step computes the one-step-ahead predictive distribution for the latent state, which converts the posterior from the previous time step into the prior for the current step.
The update step is just Bayes rule:

$$
p(\hidden_t|\obs_{1:t}) = \frac{1}{Z_t} p(\obs_t|\hidden_t) p(\hidden_t|\obs_{1:t-1})
$$

where the normalization constant is

$$
Z_t = \int p(\obs_t|\hidden_t) p(\hidden_t|\obs_{1:t-1}) d\hidden_t = p(\obs_t|\obs_{1:t-1})
$$
When the latent states \(\hidden_t\) are discrete, as in an HMM, the above integrals become sums. In particular, suppose we define the belief state as \(\alpha_t(j) \defeq p(\hidden_t=j|\obs_{1:t})\), the local evidence as \(\lambda_t(j) \defeq p(\obs_t|\hidden_t=j)\), and the transition matrix as \(A(i,j) = p(\hidden_t=j|\hidden_{t-1}=i)\). Then the predict step becomes

$$
\alpha_{t|t-1}(j) \defeq p(\hidden_t=j|\obs_{1:t-1}) = \sum_i \alpha_{t-1}(i) A(i,j)
$$

and the update step becomes

$$
\alpha_t(j) = \frac{1}{Z_t} \lambda_t(j) \alpha_{t|t-1}(j)
$$

where the normalization constant for each time step is given by

$$
Z_t = p(\obs_t|\obs_{1:t-1}) = \sum_j \lambda_t(j) \alpha_{t|t-1}(j)
$$
Since all the quantities are finite length vectors and matrices, we can write the update equation in matrix-vector notation as follows:

$$
\alpha_t = \text{normalize}\left(\lambda_t \dotstar (A^{\top} \alpha_{t-1})\right)
$$

where \(\dotstar\) represents elementwise vector multiplication, and the \(\text{normalize}\) function just ensures its argument sums to one.
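To make the matrix-vector form concrete, here is a minimal JAX sketch of a single predict-update step. The helper names (`normalize`, `filter_step`) and the two-state example values (loosely modeled on the casino HMM) are illustrative, not taken from any library.

```python
import jax.numpy as jnp

def normalize(u):
    # Scale u so it sums to one; also return the normalization constant Z_t.
    c = u.sum()
    return u / c, c

def filter_step(alpha_prev, lambda_t, A):
    # Predict: alpha_{t|t-1}(j) = sum_i alpha_{t-1}(i) A(i,j)
    alpha_pred = A.T @ alpha_prev
    # Update: multiply by the local evidence and renormalize (Bayes rule).
    return normalize(lambda_t * alpha_pred)

# Illustrative two-state example (state 0 = fair die, state 1 = loaded die),
# where the current observation is a six.
A = jnp.array([[0.95, 0.05],
               [0.10, 0.90]])           # A(i, j) = p(state_t = j | state_{t-1} = i)
lambda_t = jnp.array([1 / 6, 1 / 2])    # p(obs = six | state)
alpha_prev = jnp.array([0.5, 0.5])      # belief after the previous step

alpha_t, Z_t = filter_step(alpha_prev, lambda_t, A)
print(alpha_t)   # new belief state, sums to one
```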
In {ref}`sec:casino-inference` we illustrate filtering for the casino HMM, applied to a random sequence \(\obs_{1:T}\) of length \(T=300\). In blue, we plot the probability that the die is in the loaded (vs fair) state, based on the evidence seen so far. The gray bars indicate the time intervals during which the generative process actually switched to the loaded die. We see that the inferred probability of being in the loaded state is generally high in the right places.
Here is a JAX implementation of the forwards algorithm.
```python
import jsl.hmm.hmm_lib as hmm_lib
print_source(hmm_lib.hmm_forwards_jax)
# https://github.com/probml/JSL/blob/main/jsl/hmm/hmm_lib.py#L189
```
```python
@jit
def hmm_forwards_jax(params, obs_seq, length=None):
    '''
    Calculates a belief state

    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model

    obs_seq: array(seq_len)
        History of observable events

    Returns
    -------
    * float
        The loglikelihood giving log(p(x|model))

    * array(seq_len, n_hidden) :
        All alpha values found for each sample
    '''
    seq_len = len(obs_seq)

    if length is None:
        length = seq_len

    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist
    trans_mat = jnp.array(trans_mat)
    obs_mat = jnp.array(obs_mat)
    init_dist = jnp.array(init_dist)

    n_states, n_obs = obs_mat.shape

    def scan_fn(carry, t):
        (alpha_prev, log_ll_prev) = carry
        alpha_n = jnp.where(t < length,
                            obs_mat[:, obs_seq[t]] * (alpha_prev[:, None] * trans_mat).sum(axis=0),
                            jnp.zeros_like(alpha_prev))

        alpha_n, cn = normalize(alpha_n)
        carry = (alpha_n, jnp.log(cn) + log_ll_prev)

        return carry, alpha_n

    # initial belief state
    alpha_0, c0 = normalize(init_dist * obs_mat[:, obs_seq[0]])

    # setup scan loop
    init_state = (alpha_0, jnp.log(c0))
    ts = jnp.arange(1, seq_len)
    carry, alpha_hist = lax.scan(scan_fn, init_state, ts)

    # post-process
    alpha_hist = jnp.vstack([alpha_0.reshape(1, n_states), alpha_hist])
    (alpha_final, log_ll) = carry

    return log_ll, alpha_hist
```
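As a usage sketch (not part of the text above), the function can be run on a small casino-style HMM. This assumes that `jsl.hmm.hmm_lib` provides the `HMMJax` container with the fields `trans_mat`, `obs_mat`, and `init_dist` read by `hmm_forwards_jax`, and that `normalize` is a helper defined in the same module; the numerical values below are illustrative.

```python
import jax.numpy as jnp
import jsl.hmm.hmm_lib as hmm_lib

# Illustrative casino-style parameters (assumption: HMMJax(trans_mat, obs_mat, init_dist)).
A = jnp.array([[0.95, 0.05],
               [0.10, 0.90]])                   # state transitions: fair <-> loaded
B = jnp.array([[1 / 6] * 6,
               [1 / 10] * 5 + [1 / 2]])         # obs_mat[state, symbol], symbols 0..5
pi = jnp.array([0.5, 0.5])                      # initial state distribution

params = hmm_lib.HMMJax(A, B, pi)
obs_seq = jnp.array([0, 5, 5, 2, 5, 5, 1, 3, 5, 5])  # observed die rolls, encoded 0..5

log_lik, alphas = hmm_lib.hmm_forwards_jax(params, obs_seq)
print(log_lik)        # log p(obs_seq | params)
print(alphas.shape)   # (seq_len, n_states): filtered beliefs alpha_t
```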