{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:forwards)=\n", "# HMM filtering (forwards algorithm)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "### Import standard libraries\n", "\n", "import abc\n", "from dataclasses import dataclass\n", "import functools\n", "from functools import partial\n", "import itertools\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import lax, vmap, jit, grad\n", "#from jax.scipy.special import logit\n", "#from jax.nn import softmax\n", "import jax.random as jr\n", "\n", "\n", "\n", "import distrax\n", "import optax\n", "\n", "import jsl\n", "import ssm_jax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Introduction\n", "\n", "\n", "The $\\keyword{Bayes filter}$ is an algorithm for recursively computing\n", "the belief state\n", "$p(\\hidden_t|\\obs_{1:t})$ given\n", "the prior belief from the previous step,\n", "$p(\\hidden_{t-1}|\\obs_{1:t-1})$,\n", "the new observation $\\obs_t$,\n", "and the model.\n", "This can be done using $\\keyword{sequential Bayesian updating}$.\n", "For a dynamical model, this reduces to the\n", "$\\keyword{predict-update}$ cycle described below.\n", "\n", "The $\\keyword{prediction step}$ is just the $\\keyword{Chapman-Kolmogorov equation}$:\n", "\\begin{align}\n", "p(\\hidden_t|\\obs_{1:t-1})\n", "= \\int p(\\hidden_t|\\hidden_{t-1}) p(\\hidden_{t-1}|\\obs_{1:t-1}) d\\hidden_{t-1}\n", "\\end{align}\n", "The prediction step computes\n", "the $\\keyword{one-step-ahead predictive distribution}$\n", "for the latent state, which converts\n", "the posterior from the previous time step to become the prior\n", "for the current step.\n", "\n", "\n", "The $\\keyword{update step}$\n", "is just Bayes rule:\n", "\\begin{align}\n", "p(\\hidden_t|\\obs_{1:t}) = 
\\frac{1}{Z_t}\n", "p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1})\n", "\\end{align}\n", "where the normalization constant is\n", "\\begin{align}\n", "Z_t = \\int p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1}) d\\hidden_{t}\n", "= p(\\obs_t|\\obs_{1:t-1})\n", "\\end{align}\n", "\n", "Note that we can derive the log marginal likelihood from these normalization constants\n", "as follows:\n", "```{math}\n", ":label: eqn:logZ\n", "\n", "\\log p(\\obs_{1:T})\n", "= \\sum_{t=1}^{T} \\log p(\\obs_t|\\obs_{1:t-1})\n", "= \\sum_{t=1}^{T} \\log Z_t\n", "```\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "When the latent states $\\hidden_t$ are discrete, as in an HMM,\n", "the above integrals become sums.\n", "In particular, suppose we define\n", "the $\\keyword{belief state}$ as $\\alpha_t(j) \\defeq p(\\hidden_t=j|\\obs_{1:t})$,\n", "the $\\keyword{local evidence}$ (or $\\keyword{local likelihood}$)\n", "as $\\lambda_t(j) \\defeq p(\\obs_t|\\hidden_t=j)$,\n", "and the transition matrix as\n", "$\\hmmTrans(i,j) = p(\\hidden_t=j|\\hidden_{t-1}=i)$.\n", "Then the predict step becomes\n", "```{math}\n", ":label: eqn:predictiveHMM\n", "\\alpha_{t|t-1}(j) \\defeq p(\\hidden_t=j|\\obs_{1:t-1})\n", " = \\sum_i \\alpha_{t-1}(i) \\hmmTrans(i,j)\n", "```\n", "and the update step becomes\n", "```{math}\n", ":label: eqn:fwdsEqn\n", "\\alpha_t(j)\n", "= \\frac{1}{Z_t} \\lambda_t(j) \\alpha_{t|t-1}(j)\n", "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) \\hmmTrans(i,j) \\right]\n", "```\n", "where\n", "the normalization constant for each time step is given by\n", "```{math}\n", ":label: eqn:HMMZ\n", "\\begin{align}\n", "Z_t \\defeq p(\\obs_t|\\obs_{1:t-1})\n", "&= \\sum_{j=1}^K p(\\obs_t|\\hidden_t=j) p(\\hidden_t=j|\\obs_{1:t-1}) \\\\\n", "&= \\sum_{j=1}^K \\lambda_t(j) \\alpha_{t|t-1}(j)\n", "\\end{align}\n", "```\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "Since all the quantities are finite-length vectors and 
matrices,\n", "we can implement the whole procedure using matrix-vector multiplication:\n", "```{math}\n", ":label: eqn:fwdsAlgoMatrixForm\n", "\\valpha_t =\\text{normalize}\\left(\n", "\\vlambda_t \\dotstar (\\hmmTrans^{\\trans} \\valpha_{t-1}) \\right)\n", "```\n", "where $\\dotstar$ represents\n", "elementwise vector multiplication,\n", "and the $\\text{normalize}$ function rescales its argument so that it sums to one.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example\n", "\n", "In {ref}`sec:casino-inference`\n", "we illustrate\n", "filtering for the casino HMM,\n", "applied to a random sequence $\\obs_{1:T}$ of length $T=300$.\n", "In blue, we plot the probability that the die is in the loaded (vs fair) state,\n", "based on the evidence seen so far.\n", "The gray bars indicate time intervals during which the generative\n", "process actually switched to the loaded die.\n", "We see that the probability generally increases in the right places." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Normalization constants\n", "\n", "In most publications on HMMs,\n", "such as {cite}`Rabiner89`,\n", "the forwards message is defined\n", "as the following unnormalized joint probability:\n", "```{math}\n", "\\alpha'_t(j) = p(\\hidden_t=j,\\obs_{1:t}) \n", "= \\lambda_t(j) \\left[\\sum_i \\alpha'_{t-1}(i) \\hmmTrans(i,j) \\right]\n", "```\n", "In this book we define the forwards message as the normalized\n", "conditional probability\n", "```{math}\n", "\\alpha_t(j) = p(\\hidden_t=j|\\obs_{1:t}) \n", "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) \\hmmTrans(i,j) \\right]\n", "```\n", "where $Z_t = p(\\obs_t|\\obs_{1:t-1})$.\n", "\n", "The \"traditional\" unnormalized form has several problems.\n", "First, it rapidly suffers from numerical underflow,\n", "since the probability of\n", "the joint event $(\\hidden_t=j,\\obs_{1:t})$\n", "becomes vanishingly small as $t$ grows. 
\n", "To see why, suppose the observations are independent of the states and of each other.\n", "In this case, the unnormalized joint has the form\n", "\\begin{align}\n", "p(\\hidden_t=j,\\obs_{1:t}) = p(\\hidden_t=j)\\prod_{i=1}^t p(\\obs_i)\n", "\\end{align}\n", "which becomes exponentially small with $t$, because we multiply\n", "many probabilities which are less than one.\n", "Second, the unnormalized probability is less interpretable,\n", "since it is a joint distribution over states and observations,\n", "rather than a conditional probability of states given observations.\n", "Third, the unnormalized joint form is harder to approximate\n", "than the normalized form.\n", "Of course,\n", "the two definitions only differ by a\n", "multiplicative constant\n", "{cite}`Devijver85`,\n", "so the algorithmic difference is just\n", "one line of code (namely the presence or absence of a call to the `normalize` function).\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Naive implementation\n", "\n", "Below we give a simple NumPy implementation of the forwards algorithm.\n", "We assume the HMM uses categorical observations, for simplicity.\n", "\n" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "\n", "def normalize_np(u, axis=0, eps=1e-15):\n", "    # Clip tiny positive entries up to eps, leaving exact zeros untouched\n", "    u = np.where(u == 0, 0, np.where(u < eps, eps, u))\n", "    c = u.sum(axis=axis)\n", "    # Guard against division by zero when all entries are zero\n", "    c = np.where(c == 0, 1, c)\n", "    return u / c, c\n", "\n", "def hmm_forwards_np(trans_mat, obs_mat, init_dist, obs_seq):\n", "    n_states, n_obs = obs_mat.shape\n", "    seq_len = len(obs_seq)\n", "\n", "    alpha_hist = np.zeros((seq_len, n_states))\n", "\n", "    # Initial step: condition the initial distribution on the first observation\n", "    alpha_n = init_dist * obs_mat[:, obs_seq[0]]\n", "    alpha_n, cn = normalize_np(alpha_n)\n", "\n", "    alpha_hist[0] = alpha_n\n", "    log_normalizer = np.log(cn)\n", "\n", "    for t in range(1, seq_len):\n", "        # Predict with the transition matrix, then weight by the local evidence\n", "        alpha_n = obs_mat[:, obs_seq[t]] * (alpha_n[:, None] * trans_mat).sum(axis=0)\n", "        alpha_n, zn = normalize_np(alpha_n)\n", "\n", "        alpha_hist[t] = alpha_n\n", "        log_normalizer = np.log(zn) + log_normalizer\n", "\n", "    return log_normalizer, alpha_hist" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Numerically stable implementation \n", "\n", "\n", "\n", "In practice it is more numerically stable to compute\n", "the log likelihoods $\\ell_t(j) = \\log p(\\obs_t|\\hidden_t=j)$,\n", "rather than the likelihoods $\\lambda_t(j) = p(\\obs_t|\\hidden_t=j)$.\n", "In this case, we can perform the posterior updating in a numerically stable way as follows.\n", "Define $L_t = \\max_j \\ell_t(j)$ and\n", "\\begin{align}\n", "\\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n", "&\\defeq p(\\hidden_t=j|\\obs_{1:t-1}) p(\\obs_t|\\hidden_t=j) e^{-L_t} \\\\\n", " &= p(\\hidden_t=j|\\obs_{1:t-1}) e^{\\ell_t(j) - L_t}\n", "\\end{align}\n", "Then we have\n", "\\begin{align}\n", "p(\\hidden_t=j|\\obs_t,\\obs_{1:t-1})\n", " &= \\frac{1}{\\tilde{Z}_t} \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1}) \\\\\n", "\\tilde{Z}_t &= \\sum_j \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n", "= p(\\obs_t|\\obs_{1:t-1}) e^{-L_t} \\\\\n", "\\log Z_t &= \\log p(\\obs_t|\\obs_{1:t-1}) = \\log \\tilde{Z}_t + L_t\n", "\\end{align}\n", "\n", "Below we show some JAX code that implements this core operation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "def _condition_on(probs, ll):\n", "    # Subtract the max log-likelihood before exponentiating, for stability\n", "    ll_max = ll.max()\n", "    new_probs = probs * jnp.exp(ll - ll_max)\n", "    norm = new_probs.sum()\n", "    new_probs /= norm\n", "    # Add ll_max back to recover log Z_t\n", "    log_norm = jnp.log(norm) + ll_max\n", "    return new_probs, log_norm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the above function, we can implement a more numerically stable version of the forwards filter\n", "that works for any likelihood function, as shown below. 
It takes in the prior predictive distribution,\n", "$\\alpha_{t|t-1}$,\n", "stored in `predicted_probs`, and conditions it on the log-likelihood vector $\\ell_t$ at each time step to get the\n", "posterior, $\\alpha_t$, stored in `filtered_probs`,\n", "which is then converted to the prediction for the next time step, $\\alpha_{t+1|t}$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _predict(probs, A):\n", "    return A.T @ probs\n", "\n", "\n", "def hmm_filter(initial_distribution,\n", "               transition_matrix,\n", "               log_likelihoods):\n", "    def _step(carry, t):\n", "        log_normalizer, predicted_probs = carry\n", "\n", "        # Get parameters for time t\n", "        get = lambda x: x[t] if x.ndim == 3 else x\n", "        A = get(transition_matrix)\n", "        ll = log_likelihoods[t]\n", "\n", "        # Condition on emissions at time t, being careful not to overflow\n", "        filtered_probs, log_norm = _condition_on(predicted_probs, ll)\n", "        # Update the log normalizer\n", "        log_normalizer += log_norm\n", "        # Predict the next state\n", "        predicted_probs = _predict(filtered_probs, A)\n", "\n", "        return (log_normalizer, predicted_probs), (filtered_probs, predicted_probs)\n", "\n", "    num_timesteps = len(log_likelihoods)\n", "    carry = (0.0, initial_distribution)\n", "    (log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(\n", "        _step, carry, jnp.arange(num_timesteps))\n", "    return log_normalizer, filtered_probs, predicted_probs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "TODO: check equivalence of these two implementations!" 
] } ], "metadata": { "interpreter": { "hash": "6407c60499271029b671b4ff687c4ed4626355c45fd34c44476827f4be42c4d7" }, "kernelspec": { "display_name": "Python 3.9.2 ('spyder-dev')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 4 }