{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: distrax in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (0.0.1)\n", "Collecting distrax\n", " Downloading distrax-0.1.2-py3-none-any.whl (272 kB)\n", "\u001b[K |████████████████████████████████| 272 kB 6.9 MB/s eta 0:00:01\n", "\u001b[?25hRequirement already satisfied: jax>=0.1.55 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.2.11)\n", "Requirement already satisfied: absl-py>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.12.0)\n", "Requirement already satisfied: chex>=0.0.7 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.0.8)\n", "Requirement already satisfied: jaxlib>=0.1.67 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.1.70)\n", "Requirement already satisfied: numpy>=1.18.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (1.19.5)\n", "Collecting tensorflow-probability>=0.15.0\n", " Using cached tensorflow_probability-0.16.0-py2.py3-none-any.whl (6.3 MB)\n", "Requirement already satisfied: six in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from absl-py>=0.9.0->distrax) (1.15.0)\n", "Requirement already satisfied: dm-tree>=0.1.5 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.1.6)\n", "Requirement already satisfied: toolz>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.11.1)\n", "Requirement already satisfied: opt-einsum in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jax>=0.1.55->distrax) (3.3.0)\n", "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.12)\n", "Requirement already 
satisfied: scipy in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.6.3)\n", "Requirement already satisfied: cloudpickle>=1.3 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (1.6.0)\n", "Requirement already satisfied: decorator in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (4.4.2)\n", "Requirement already satisfied: gast>=0.3.2 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (0.4.0)\n", "Installing collected packages: tensorflow-probability, distrax\n", " Attempting uninstall: tensorflow-probability\n", " Found existing installation: tensorflow-probability 0.13.0\n", " Uninstalling tensorflow-probability-0.13.0:\n", " Successfully uninstalled tensorflow-probability-0.13.0\n", " Attempting uninstall: distrax\n", " Found existing installation: distrax 0.0.1\n", " Uninstalling distrax-0.0.1:\n", " Successfully uninstalled distrax-0.0.1\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "jsl 0.0.0 requires dataclasses, which is not installed.\u001b[0m\n", "Successfully installed distrax-0.1.2 tensorflow-probability-0.16.0\n", "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 22.0.4 is available.\n", "You should consider upgrading via the '/opt/anaconda3/envs/spyder-dev/bin/python -m pip install --upgrade pip' command.\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "# meta-data does not work yet in VScode\n", "# https://github.com/microsoft/vscode-jupyter/issues/1121\n", "\n", "{\n", " \"tags\": [\n", " \"hide-cell\"\n", " ]\n", "}\n", "\n", "\n", "### Install necessary libraries\n", "\n", "try:\n", " import jax\n", "except:\n", " # For cuda version, see https://github.com/google/jax#installation\n", " %pip install --upgrade \"jax[cpu]\" \n", " import jax\n", "\n", "try:\n", " import distrax\n", "except:\n", " %pip install --upgrade distrax\n", " import distrax\n", "\n", "try:\n", " import jsl\n", "except:\n", " %pip install git+https://github.com/probml/jsl\n", " import jsl\n", "\n", "try:\n", " import rich\n", "except:\n", " %pip install rich\n", " import rich\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "{\n", " \"tags\": [\n", " \"hide-cell\"\n", " ]\n", "}\n", "\n", "\n", "### Import standard libraries\n", "\n", "import abc\n", "from dataclasses import dataclass\n", "import functools\n", "import itertools\n", "\n", "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import lax, vmap, jit, grad\n", "from jax.scipy.special import logit\n", "from jax.nn import softmax\n", "from functools import partial\n", "from jax.random import PRNGKey, split\n", "\n", "import inspect\n", "import 
def print_source(fname):
    """Pretty-print the source code of a Python object using rich.

    Args:
        fname: Any object accepted by ``inspect.getsource`` (a function,
            method, class, ...). Despite the name, this is the object itself,
            not a file name.

    Raises:
        Whatever ``inspect.getsource`` raises when the source is unavailable.
    """
    source = py_inspect.getsource(fname)
    # rich.print() treats "[...]" in its input as console markup, so index
    # expressions such as "[state]" would be consumed as style tags and vanish
    # from the rendered code. Printing through the same global console with
    # markup disabled keeps the highlighting but shows the source verbatim.
    rich.get_console().print(source, markup=False)
"\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n", "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n", "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n", "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n", "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n", "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n", "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n", "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n", "\\newcommand{\\vsigmoid}{\\vsigma}\n", "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n", "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n", "\n", "\n", "% Lower Roman (Vectors)\n", "\\newcommand{\\va}{\\mathbf{a}}\n", "\\newcommand{\\vb}{\\mathbf{b}}\n", "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n", "\\newcommand{\\vc}{\\mathbf{c}}\n", "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n", "\\newcommand{\\vd}{\\mathbf{d}}\n", "\\newcommand{\\ve}{\\mathbf{e}}\n", "\\newcommand{\\vf}{\\mathbf{f}}\n", "\\newcommand{\\vg}{\\mathbf{g}}\n", "\\newcommand{\\vh}{\\mathbf{h}}\n", "%\\newcommand{\\myvh}{\\mathbf{h}}\n", "\\newcommand{\\vi}{\\mathbf{i}}\n", "\\newcommand{\\vj}{\\mathbf{j}}\n", "\\newcommand{\\vk}{\\mathbf{k}}\n", "\\newcommand{\\vl}{\\mathbf{l}}\n", "\\newcommand{\\vm}{\\mathbf{m}}\n", "\\newcommand{\\vn}{\\mathbf{n}}\n", "\\newcommand{\\vo}{\\mathbf{o}}\n", "\\newcommand{\\vp}{\\mathbf{p}}\n", "\\newcommand{\\vq}{\\mathbf{q}}\n", "\\newcommand{\\vr}{\\mathbf{r}}\n", "\\newcommand{\\vs}{\\mathbf{s}}\n", "\\newcommand{\\vt}{\\mathbf{t}}\n", "\\newcommand{\\vu}{\\mathbf{u}}\n", "\\newcommand{\\vv}{\\mathbf{v}}\n", "\\newcommand{\\vw}{\\mathbf{w}}\n", "\\newcommand{\\vws}{\\vw_s}\n", "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n", "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n", "\\newcommand{\\vwh}{\\hat{\\vw}}\n", "\\newcommand{\\vx}{\\mathbf{x}}\n", "%\\newcommand{\\vx}{\\mathbf{x}}\n", "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n", "\\newcommand{\\vy}{\\mathbf{y}}\n", "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n", "\\newcommand{\\vz}{\\mathbf{z}}\n", 
"%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n", "\n", "\n", "% Upper Roman (Matrices)\n", "\\newcommand{\\vA}{\\mathbf{A}}\n", "\\newcommand{\\vB}{\\mathbf{B}}\n", "\\newcommand{\\vC}{\\mathbf{C}}\n", "\\newcommand{\\vD}{\\mathbf{D}}\n", "\\newcommand{\\vE}{\\mathbf{E}}\n", "\\newcommand{\\vF}{\\mathbf{F}}\n", "\\newcommand{\\vG}{\\mathbf{G}}\n", "\\newcommand{\\vH}{\\mathbf{H}}\n", "\\newcommand{\\vI}{\\mathbf{I}}\n", "\\newcommand{\\vJ}{\\mathbf{J}}\n", "\\newcommand{\\vK}{\\mathbf{K}}\n", "\\newcommand{\\vL}{\\mathbf{L}}\n", "\\newcommand{\\vM}{\\mathbf{M}}\n", "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n", "\\newcommand{\\vN}{\\mathbf{N}}\n", "\\newcommand{\\vO}{\\mathbf{O}}\n", "\\newcommand{\\vP}{\\mathbf{P}}\n", "\\newcommand{\\vQ}{\\mathbf{Q}}\n", "\\newcommand{\\vR}{\\mathbf{R}}\n", "\\newcommand{\\vS}{\\mathbf{S}}\n", "\\newcommand{\\vT}{\\mathbf{T}}\n", "\\newcommand{\\vU}{\\mathbf{U}}\n", "\\newcommand{\\vV}{\\mathbf{V}}\n", "\\newcommand{\\vW}{\\mathbf{W}}\n", "\\newcommand{\\vX}{\\mathbf{X}}\n", "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n", "\\newcommand{\\vXs}{\\vX_{s}}\n", "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n", "\\newcommand{\\vY}{\\mathbf{Y}}\n", "\\newcommand{\\vZ}{\\mathbf{Z}}\n", "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n", "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n", "\n", "\n", "%%%%\n", "\\newcommand{\\hidden}{\\vz}\n", "\\newcommand{\\hid}{\\hidden}\n", "\\newcommand{\\observed}{\\vy}\n", "\\newcommand{\\obs}{\\observed}\n", "\\newcommand{\\inputs}{\\vu}\n", "\\newcommand{\\input}{\\inputs}\n", "\n", "\\newcommand{\\hmmTrans}{\\vA}\n", "\\newcommand{\\hmmObs}{\\vB}\n", "\\newcommand{\\hmmInit}{\\vpi}\n", "\\newcommand{\\hmmhid}{\\hidden}\n", "\\newcommand{\\hmmobs}{\\obs}\n", "\n", "\\newcommand{\\ldsDyn}{\\vA}\n", "\\newcommand{\\ldsObs}{\\vC}\n", "\\newcommand{\\ldsDynIn}{\\vB}\n", "\\newcommand{\\ldsObsIn}{\\vD}\n", "\\newcommand{\\ldsDynNoise}{\\vQ}\n", "\\newcommand{\\ldsObsNoise}{\\vR}\n", "\n", "\\newcommand{\\ssmDynFn}{f}\n", 
"\\newcommand{\\ssmObsFn}{h}\n", "\n", "\n", "%%%\n", "\\newcommand{\\gauss}{\\mathcal{N}}\n", "\n", "\\newcommand{\\diag}{\\mathrm{diag}}\n", "```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:ssm-intro)=\n", "# What are State Space Models?\n", "\n", "\n", "A state space model or SSM\n", "is a partially observed Markov model,\n", "in which the hidden state, $\\hidden_t$,\n", "evolves over time according to a Markov process,\n", "possibly conditional on external inputs or controls $\\input_t$,\n", "and each hidden state generates some\n", "observations $\\obs_t$ at each time step.\n", "(In this book, we mostly focus on discrete time systems,\n", "although we consider the continuous-time case in XXX.)\n", "We get to see the observations, but not the hidden state.\n", "Our main goal is to infer the hidden state given the observations.\n", "However, we can also use the model to predict future observations,\n", "by first predicting future hidden states, and then predicting\n", "what observations they might generate.\n", "By using a hidden state $\\hidden_t$\n", "to represent the past observations, $\\obs_{1:t-1}$,\n", "the model can have \"infinite\" memory,\n", "unlike a standard Markov model.\n", "\n", "Formally we can define an SSM \n", "as the following joint distribution:\n", "```{math}\n", ":label: eq:SSM-ar\n", "p(\\hmmobs_{1:T},\\hmmhid_{1:T}|\\inputs_{1:T})\n", " = \\left[ p(\\hmmhid_1|\\inputs_1) \\prod_{t=2}^{T}\n", " p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) \\right]\n", " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t, \\inputs_t, \\hmmobs_{t-1}) \\right]\n", "```\n", "where $p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t)$ is the\n", "transition model,\n", "$p(\\hmmobs_t|\\hmmhid_t, \\inputs_t, \\hmmobs_{t-1})$ is the\n", "observation model,\n", "and $\\inputs_{t}$ is an optional input or action.\n", "See {numref}`Figure %s <ssm-ar>` \n", "for an illustration of the corresponding graphical model.\n", "\n", "\n", "```{figure} 
/figures/SSM-AR-inputs.png\n", ":scale: 100%\n", ":name: ssm-ar\n", "\n", "Illustration of an SSM as a graphical model.\n", "```\n", "\n", "\n", "We often consider a simpler setting in which the\n", " observations are conditionally independent of each other\n", "(rather than having Markovian dependencies) given the hidden state.\n", "In this case the joint simplifies to \n", "```{math}\n", ":label: eq:SSM-input\n", "p(\\hmmobs_{1:T},\\hmmhid_{1:T}|\\inputs_{1:T})\n", " = \\left[ p(\\hmmhid_1|\\inputs_1) \\prod_{t=2}^{T}\n", " p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) \\right]\n", " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t, \\inputs_t) \\right]\n", "```\n", "Sometimes there are no external inputs, so the model further\n", "simplifies to the following unconditional generative model: \n", "```{math}\n", ":label: eq:SSM-no-input\n", "p(\\hmmobs_{1:T},\\hmmhid_{1:T})\n", " = \\left[ p(\\hmmhid_1) \\prod_{t=2}^{T}\n", " p(\\hmmhid_t|\\hmmhid_{t-1}) \\right]\n", " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t) \\right]\n", "```\n", "See {numref}`Figure %s <ssm-simplified>` \n", "for an illustration of the corresponding graphical model.\n", "\n", "\n", "```{figure} /figures/SSM-simplified.png\n", ":scale: 100%\n", ":name: ssm-simplified\n", "\n", "Illustration of a simplified SSM.\n", "```\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:hmm-intro)=\n", "# Hidden Markov Models\n", "\n", "In this section, we discuss the\n", "hidden Markov model or HMM,\n", "which is a state space model in which the hidden states\n", "are discrete, so $\\hmmhid_t \\in \\{1,\\ldots, K\\}$.\n", "The observations may be discrete,\n", "$\\hmmobs_t \\in \\{1,\\ldots, C\\}$,\n", "or continuous,\n", "$\\hmmobs_t \\in \\real^D$,\n", "or some combination,\n", "as we illustrate below.\n", "More details can be found in e.g., \n", "{cite}`Rabiner89,Fraser08,Cappe05`.\n", "For an interactive introduction,\n", "see https://nipunbatra.github.io/hmm/." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:casino)=\n", "## Example: Casino HMM\n", "\n", "To illustrate HMMs with a categorical observation model,\n", "we consider the \"Occasionally dishonest casino\" model from {cite}`Durbin98`.\n", "There are 2 hidden states, representing whether the die being used in the casino is fair or loaded.\n", "Each state defines a distribution over the 6 possible observations.\n", "\n", "The transition model is denoted by\n", "```{math}\n", "p(z_t=j|z_{t-1}=i) = \\hmmTrans_{ij}\n", "```\n", "Here the $i$'th row of $\\vA$ corresponds to the outgoing distribution from state $i$.\n", "This is a row stochastic matrix,\n", "meaning each row sums to one.\n", "We can visualize\n", "the non-zero entries in the transition matrix by creating a state transition diagram,\n", "as shown in \n", "{numref}`Figure %s <casino-fig>`\n", "%{ref}`casino-fig`.\n", "\n", "```{figure} /figures/casino.png\n", ":scale: 50%\n", ":name: casino-fig\n", "\n", "Illustration of the casino HMM.\n", "```\n", "\n", "The observation model\n", "$p(\\obs_t|\\hidden_t=j)$ has the form\n", "```{math}\n", "p(\\obs_t=k|\\hidden_t=j) = \\hmmObs_{jk} \n", "```\n", "This is represented by the histograms associated with each\n", "state in {ref}`casino-fig`.\n", "\n", "Finally,\n", "the initial state distribution is denoted by\n", "```{math}\n", "p(z_1=j) = \\hmmInit_j\n", "```\n", "\n", "Collectively we denote all the parameters by $\\vtheta=(\\hmmTrans, \\hmmObs, \\hmmInit)$.\n", "\n", "Now let us implement this model in code." 
# Transition probabilities: row i is the distribution over the next state
# when leaving state i (rows follow the same order as the rows of B).
A = np.array([[0.95, 0.05],
              [0.10, 0.90]])

# Emission probabilities over the six faces, one row per hidden state:
# row 0 is the fair die (uniform), row 1 the loaded die (half the mass on
# the last face).
B = np.stack([
    np.full(6, 1/6),
    np.append(np.full(5, 1/10), 5/10),
])

# Initial state distribution.
pi = np.array([0.5, 0.5])

# Number of hidden states and of distinct observation symbols.
nstates, nobs = B.shape
# Draw one trajectory of hidden states and observations from the casino HMM.
seed = 314
n_samples = 300
z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)

# Shift the sampled indices by one so they print as 1-based labels, glue them
# into a digit string, and keep only the first 60 characters for display.
z_hist_str = "".join(map(str, np.asarray(z_hist) + 1))[:60]
x_hist_str = "".join(map(str, np.asarray(x_hist) + 1))[:60]

print("Printing sample observed/latent...")
print(f"x: {x_hist_str}")
print(f"z: {z_hist_str}")
  def sample(self,\n",
       "             *,\n",
       "             seed: chex.PRNGKey,\n",
       "             seq_len: chex.Array) -> Tuple:\n",
       "    \"\"\"Sample from this HMM.\n",
       "\n",
       "    Samples an observation of given length according to this\n",
       "    Hidden Markov Model and gives the sequence of the hidden states\n",
       "    as well as the observation.\n",
       "\n",
       "    Args:\n",
       "      seed: Random key of shape (2,) and dtype uint32.\n",
       "      seq_len: The length of the observation sequence.\n",
       "\n",
       "    Returns:\n",
       "      Tuple of hidden state sequence, and observation sequence.\n",
       "    \"\"\"\n",
       "    rng_key, rng_init = jax.random.split(seed)\n",
       "    initial_state = self._init_dist.sample(seed=rng_init)\n",
       "\n",
       "    def draw_state(prev_state, key):\n",
       "      state = self._trans_dist.sample(seed=key)\n",
       "      return state, state\n",
       "\n",
       "    rng_state, rng_obs = jax.random.split(rng_key)\n",
       "    keys = jax.random.split(rng_state, seq_len - 1)\n",
       "    _, states = jax.lax.scan(draw_state, initial_state, keys)\n",
       "    states = jnp.append(initial_state, states)\n",
       "\n",
       "    def draw_obs(state, key):\n",
       "      return self._obs_dist.sample(seed=key)\n",
       "\n",
       "    keys = jax.random.split(rng_obs, seq_len)\n",
       "    obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(states, keys)\n",
       "\n",
       "    return states, obs_seq\n",
       "\n",
       "
\n" ], "text/plain": [ " def \u001b[1;35msample\u001b[0m\u001b[1m(\u001b[0mself,\n", " *,\n", " seed: chex.PRNGKey,\n", " seq_len: chex.Array\u001b[1m)\u001b[0m -> Tuple:\n", " \u001b[32m\"\"\u001b[0m\"Sample from this HMM.\n", "\n", " Samples an observation of given length according to this\n", " Hidden Markov Model and gives the sequence of the hidden states\n", " as well as the observation.\n", "\n", " Args:\n", " seed: Random key of shape \u001b[1m(\u001b[0m\u001b[1;36m2\u001b[0m,\u001b[1m)\u001b[0m and dtype uint32.\n", " seq_len: The length of the observation sequence.\n", "\n", " Returns:\n", " Tuple of hidden state sequence, and observation sequence.\n", " \u001b[32m\"\"\u001b[0m\"\n", " rng_key, rng_init = \u001b[1;35mjax.random.split\u001b[0m\u001b[1m(\u001b[0mseed\u001b[1m)\u001b[0m\n", " initial_state = \u001b[1;35mself._init_dist.sample\u001b[0m\u001b[1m(\u001b[0m\u001b[33mseed\u001b[0m=\u001b[35mrng_init\u001b[0m\u001b[1m)\u001b[0m\n", "\n", " def \u001b[1;35mdraw_state\u001b[0m\u001b[1m(\u001b[0mprev_state, key\u001b[1m)\u001b[0m:\n", " state = \u001b[1;35mself._trans_dist.sample\u001b[0m\u001b[1m(\u001b[0m\u001b[33mseed\u001b[0m=\u001b[35mkey\u001b[0m\u001b[1m)\u001b[0m\n", " return state, state\n", "\n", " rng_state, rng_obs = \u001b[1;35mjax.random.split\u001b[0m\u001b[1m(\u001b[0mrng_key\u001b[1m)\u001b[0m\n", " keys = \u001b[1;35mjax.random.split\u001b[0m\u001b[1m(\u001b[0mrng_state, seq_len - \u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\n", " _, states = \u001b[1;35mjax.lax.scan\u001b[0m\u001b[1m(\u001b[0mdraw_state, initial_state, keys\u001b[1m)\u001b[0m\n", " states = \u001b[1;35mjnp.append\u001b[0m\u001b[1m(\u001b[0minitial_state, states\u001b[1m)\u001b[0m\n", "\n", " def \u001b[1;35mdraw_obs\u001b[0m\u001b[1m(\u001b[0mstate, key\u001b[1m)\u001b[0m:\n", " return \u001b[1;35mself._obs_dist.sample\u001b[0m\u001b[1m(\u001b[0m\u001b[33mseed\u001b[0m=\u001b[35mkey\u001b[0m\u001b[1m)\u001b[0m\n", "\n", " keys = 
\u001b[1;35mjax.random.split\u001b[0m\u001b[1m(\u001b[0mrng_obs, seq_len\u001b[1m)\u001b[0m\n", " obs_seq = \u001b[1;35mjax.vmap\u001b[0m\u001b[1m(\u001b[0mdraw_obs, \u001b[33min_axes\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m0\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\u001b[1m(\u001b[0mstates, keys\u001b[1m)\u001b[0m\n", "\n", " return states, obs_seq\n", "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Here is the source code for the sampling algorithm.\n", "\n", "print_source(hmm.sample)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our primary goal will be to infer the latent state from the observations,\n", "so we can detect if the casino is being dishonest or not. This will\n", "affect how we choose to gamble our money.\n", "We discuss various ways to perform this inference below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:lillypad)=\n", "## Example: Lillypad HMM\n", "\n", "\n", "If $\\obs_t$ is continuous, it is common to use a Gaussian\n", "observation model:\n", "```{math}\n", "p(\\obs_t|\\hidden_t=j) = \\gauss(\\obs_t|\\vmu_j,\\vSigma_j)\n", "```\n", "This is sometimes called a Gaussian HMM.\n", "\n", "As a simple example, suppose we have an HMM with 3 hidden states,\n", "each of which generates a 2d Gaussian.\n", "We can represent these Gaussian distributions as 2d ellipses,\n", "as we show below.\n", "We call these \"lilly pads\", because of their shape.\n", "We can imagine a frog hopping from one lilly pad to another.\n", "(This analogy is due to the late Sam Roweis.)\n", "The frog will stay on a pad for a while (corresponding to remaining in the same\n", "discrete state $\\hidden_t$), and then jump to a new pad\n", "(corresponding to a transition to a new state).\n", "The data we see are just the 2d points (e.g., water droplets)\n", "coming from near the pad that the frog is currently on.\n", "Thus this model is like a Gaussian mixture model,\n", "in that 
# Let us create the model

# Initial distribution over the 3 hidden states
initial_probs = jnp.array([0.3, 0.2, 0.5])

# transition matrix (one row per current state)
A = jnp.array([
    [0.3, 0.4, 0.3],
    [0.1, 0.6, 0.3],
    [0.2, 0.3, 0.5]
])

# Observation model: one 2d Gaussian per hidden state.
# Row k of mu_collection is the mean of state k.
mu_collection = jnp.array([
    [0.3, 0.3],
    [0.8, 0.5],
    [0.3, 0.8]
])

# Per-state 2x2 covariance matrices, stacked and scaled down by 60.
S1 = jnp.array([[1.1, 0], [0, 0.3]])
S2 = jnp.array([[0.3, -0.5], [-0.5, 1.3]])
S3 = jnp.array([[0.8, 0.4], [0.4, 0.5]])
cov_collection = jnp.array([S1, S2, S3]) / 60


import tensorflow_probability as tfp

# The observation distribution is built with the TFP JAX substrate and wrapped
# with distrax.as_distribution. (An alternative based on
# distrax.MultivariateNormalFullCovariance used to sit behind an `if False:`
# guard and therefore never ran; that unreachable branch has been removed.)
hmm = HMM(trans_dist=distrax.Categorical(probs=A),
          init_dist=distrax.Categorical(probs=initial_probs),
          obs_dist=distrax.as_distribution(
              tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
                  loc=mu_collection,
                  covariance_matrix=cov_collection)))

print(hmm)
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "\n", "# Let's plot the observed data in 2d\n", "xmin, xmax = 0, 1\n", "ymin, ymax = 0, 1.2\n", "colors = [\"tab:green\", \"tab:blue\", \"tab:red\"]\n", "\n", "def plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax, step=1e-2):\n", " obs_dist = hmm.obs_dist\n", " color_sample = [colors[i] for i in samples_state]\n", "\n", " xs = jnp.arange(xmin, xmax, step)\n", " ys = jnp.arange(ymin, ymax, step)\n", "\n", " v_prob = vmap(lambda x, y: obs_dist.prob(jnp.array([x, y])), in_axes=(None, 0))\n", " z = vmap(v_prob, in_axes=(0, None))(xs, ys)\n", "\n", " grid = np.mgrid[xmin:xmax:step, ymin:ymax:step]\n", "\n", " for k, color in enumerate(colors):\n", " ax.contour(*grid, z[:, :, k], levels=[1], colors=color, linewidths=3)\n", " ax.text(*(obs_dist.mean()[k] + 0.13), f\"$k$={k + 1}\", fontsize=13, horizontalalignment=\"right\")\n", "\n", " ax.plot(*samples_obs.T, c=\"black\", alpha=0.3, zorder=1)\n", " ax.scatter(*samples_obs.T, c=color_sample, s=30, zorder=2, alpha=0.8)\n", "\n", " return ax, color_sample\n", "\n", "\n", "fig, ax = plt.subplots()\n", "_, color_sample = plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax)\n" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Let's plot the hidden state sequence\n", "\n", "fig, ax = plt.subplots()\n", "ax.step(range(n_samples), samples_state, where=\"post\", c=\"black\", linewidth=1, alpha=0.3)\n", "ax.scatter(range(n_samples), samples_state, c=color_sample, zorder=3)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:lds-intro)=\n", "# Linear Gaussian SSMs\n", "\n", "\n", "Consider the state space model in \n", "{eq}`eq:SSM-ar`\n", "where we assume the observations are conditionally iid given the\n", "hidden states and inputs (i.e. there are no auto-regressive dependencies\n", "between the observables).\n", "We can rewrite this model as \n", "a stochastic nonlinear dynamical system (NLDS)\n", "by defining the distribution of the next hidden state \n", "as a deterministic function of the past state\n", "plus random process noise $\\vepsilon_t$ \n", "\\begin{align}\n", "\\hmmhid_t &= \\ssmDynFn(\\hmmhid_{t-1}, \\inputs_t, \\vepsilon_t) \n", "\\end{align}\n", "where $\\vepsilon_t$ is drawn from the distribution such\n", "that the induced distribution\n", "on $\\hmmhid_t$ matches $p(\\hmmhid_t|\\hmmhid_{t-1}, \\inputs_t)$.\n", "Similarly we can rewrite the observation distributions\n", "as a deterministic function of the hidden state\n", "plus observation noise $\\veta_t$:\n", "\\begin{align}\n", "\\hmmobs_t &= \\ssmObsFn(\\hmmhid_{t}, \\inputs_t, \\veta_t)\n", "\\end{align}\n", "\n", "\n", "If we assume additive Gaussian noise,\n", "the model becomes\n", "\\begin{align}\n", "\\hmmhid_t &= \\ssmDynFn(\\hmmhid_{t-1}, \\inputs_t) + \\vepsilon_t \\\\\n", "\\hmmobs_t &= \\ssmObsFn(\\hmmhid_{t}, \\inputs_t) + \\veta_t\n", "\\end{align}\n", "where $\\vepsilon_t \\sim \\gauss(\\vzero,\\vQ_t)$\n", "and $\\veta_t \\sim \\gauss(\\vzero,\\vR_t)$.\n", "We will call these Gaussian SSMs.\n", "\n", "If we additionally assume\n", "the transition function $\\ssmDynFn$\n", "and the 
observation function $\\ssmObsFn$ are both linear,\n", "then we can rewrite the model as follows:\n", "\\begin{align}\n", "p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) &= \\gauss(\\hmmhid_t|\\ldsDyn_t \\hmmhid_{t-1}\n", "+ \\ldsDynIn_t \\inputs_t, \\vQ_t)\n", "\\\\\n", "p(\\hmmobs_t|\\hmmhid_t,\\inputs_t) &= \\gauss(\\hmmobs_t|\\ldsObs_t \\hmmhid_{t}\n", "+ \\ldsObsIn_t \\inputs_t, \\vR_t)\n", "\\end{align}\n", "This is called a \n", "linear-Gaussian state space model\n", "(LG-SSM),\n", "or a\n", "linear dynamical system (LDS).\n", "We usually assume the parameters are independent of time, in which case\n", "the model is said to be time-invariant or homogeneous.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:tracking-lds)=\n", "(sec:kalman-tracking)=\n", "## Example: tracking a 2d point\n", "\n", "\n", "\n", "% Sarkkar p43\n", "Consider an object moving in $\\real^2$.\n", "Let the state be\n", "the position and velocity of the object,\n", "$$\\vz_t =\\begin{pmatrix} u_t & \\dot{u}_t & v_t & \\dot{v}_t \\end{pmatrix}$$.\n", "(We use $u$ and $v$ for the two coordinates,\n", "to avoid confusion with the state and observation variables.)\n", "If we use Euler discretization,\n", "the dynamics become\n", "\\begin{align}\n", "\\underbrace{\\begin{pmatrix} u_t\\\\ \\dot{u}_t \\\\ v_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\vz_t}\n", " = \n", "\\underbrace{\n", "\\begin{pmatrix}\n", "1 & \\Delta & 0 & 0 \\\\\n", "0 & 1 & 0 & 0 \\\\\n", "0 & 0 & 1 & \\Delta \\\\\n", "0 & 0 & 0 & 1\n", "\\end{pmatrix}\n", "}_{\\ldsDyn}\n", "\\\n", "\\underbrace{\\begin{pmatrix} u_{t-1} \\\\ \\dot{u}_{t-1} \\\\ v_{t-1} \\\\ \\dot{v}_{t-1} \\end{pmatrix}}_{\\vz_{t-1}}\n", "+ \\vepsilon_t\n", "\\end{align}\n", "where $\\vepsilon_t \\sim \\gauss(\\vzero,\\vQ)$ is\n", "the process noise.\n", "\n", "Let us assume\n", "that the process noise is \n", "a white noise process added to the velocity components\n", "of the state, but not to the location.\n", "(This is known as a random 
accelerations model.)\n", "We can approximate the resulting process in discrete time by assuming\n", "$\\vQ = \\diag(0, q, 0, q)$.\n", "(See {cite}`Sarkka13` p60 for a more accurate way\n", "to convert the continuous time process to discrete time.)\n", "\n", "\n", "Now suppose that at each discrete time point we\n", "observe the location,\n", "corrupted by Gaussian noise.\n", "Thus the observation model becomes\n", "\\begin{align}\n", "\\underbrace{\\begin{pmatrix} y_{1,t} \\\\ y_{2,t} \\end{pmatrix}}_{\\vy_t}\n", " &=\n", " \\underbrace{\n", " \\begin{pmatrix}\n", "1 & 0 & 0 & 0 \\\\\n", "0 & 0 & 1 & 0\n", " \\end{pmatrix}\n", " }_{\\ldsObs}\n", " \\\n", "\\underbrace{\\begin{pmatrix} u_t\\\\ \\dot{u}_t \\\\ v_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\vz_t} \n", " + \\veta_t\n", "\\end{align}\n", "where $\\veta_t \\sim \\gauss(\\vzero,\\vR)$ is the observation noise.\n", "We see that the observation matrix $\\ldsObs$ simply \"extracts\" the\n", "relevant parts of the state vector.\n", "\n", "Suppose we sample a trajectory and corresponding set\n", "of noisy observations from this model,\n", "$(\\vz_{1:T}, \\vy_{1:T}) \\sim p(\\vz,\\vy|\\vtheta)$.\n", "(We use diagonal observation noise,\n", "$\\vR = \\diag(\\sigma_1^2, \\sigma_2^2)$.)\n", "The results are shown below. \n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LDS(A=DeviceArray([[1., 0., 1., 0.],\n", " [0., 1., 0., 1.],\n", " [0., 0., 1., 0.],\n", " [0., 0., 0., 1.]], dtype=float32), C=DeviceArray([[1, 0, 0, 0],\n", " [0, 1, 0, 0]], dtype=int32), Q=DeviceArray([[0.001, 0. , 0. , 0. ],\n", " [0. , 0.001, 0. , 0. ],\n", " [0. , 0. , 0.001, 0. ],\n", " [0. , 0. , 0. 
# Project-local JSL helpers, grouped at the top so the cell's dependencies are visible.
# NOTE(review): importing `filter` shadows the Python builtin of the same name for
# every later cell; `smooth` and `filter` are not used in the cells visible here.
from jsl.lds.kalman_filter import LDS, smooth, filter
from jsl.demos.plot_utils import plot_ellipse

key = jax.random.PRNGKey(314)
timesteps = 15
delta = 1.0

# Transition matrix. Rows 0-1 add `delta` times components 2-3, so the state is
# ordered positions first, then velocities: (u, v, u_dot, v_dot).
A = jnp.array([
    [1, 0, delta, 0],
    [0, 1, 0, delta],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
])

# Observation matrix: selects the first two state components (the position).
C = jnp.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0]
])

state_size, _ = A.shape
observation_size, _ = C.shape

# NOTE(review): the accompanying text describes process noise on the velocity
# components only (Q = diag(0, q, 0, q)); this uses isotropic noise on all four
# components -- confirm which is intended.
Q = jnp.eye(state_size) * 0.001
R = jnp.eye(observation_size) * 1.0

# Prior over the initial state. Float literals give a float32 array directly;
# `.astype(float)` asks for float64, which JAX truncates with a UserWarning
# unless jax_enable_x64 is set.
mu0 = jnp.array([8.0, 10.0, 1.0, 0.0])
Sigma0 = jnp.eye(state_size) * 1.0

lds = LDS(A, C, Q, R, mu0, Sigma0)
print(lds)


def plot_tracking_values(observed, filtered, cov_hist, signal_label, ax):
    """Plot 2d observations against an estimated trajectory with covariance ellipses.

    Parameters
    ----------
    observed : 2-d array
        Observations; columns 0 and 1 are plotted as the x and y coordinates.
    filtered : 2-d array
        Estimated states, one row per timestep; only the first two columns are drawn.
    cov_hist : sequence of 2-d arrays
        Per-timestep covariances; the top-left 2x2 block of each is passed to
        ``plot_ellipse``.
    signal_label : str
        Legend label for the estimated trajectory.
    ax : matplotlib Axes
        Axes to draw on; modified in place.
    """
    n_steps, _ = observed.shape

    ax.plot(observed[:, 0], observed[:, 1], marker="o", linewidth=0,
            markerfacecolor="none", markeredgewidth=2, markersize=8,
            label="observed", c="tab:green")
    ax.plot(*filtered[:, :2].T, label=signal_label, c="tab:red", marker="x", linewidth=2)

    # One ellipse per timestep, centred on the estimate for that step.
    for t in range(n_steps):
        position_cov = cov_hist[t][:2, :2]
        plot_ellipse(position_cov, filtered[t, :2], ax, n_std=2.0, plot_center=False)

    ax.axis("equal")
    ax.legend()
}, { "data": { "image/png": "", "text/plain": [ "
# Sample a state trajectory and its noisy observations from the LDS.
# LDS-specific names are used on purpose: later cells pass the globals `z_hist` and
# `x_hist` to `hmm.forward_backward`, `hmm.viterbi` and `plot_inference` as the casino
# HMM data, so rebinding them here would break those cells when the notebook is run
# top to bottom.
lds_states, lds_obs = lds.sample(key, timesteps)

fig_truth, axs = plt.subplots()

# Noisy observations, drawn as hollow markers.
axs.plot(lds_obs[:, 0], lds_obs[:, 1],
         marker="o", linewidth=0, markerfacecolor="none",
         markeredgewidth=2, markersize=8,
         label="observed", c="tab:green")

# First two components of the sampled latent state.
axs.plot(lds_states[:, 0], lds_states[:, 1],
         linewidth=2, label="truth",
         marker="s", markersize=8)

axs.legend()
axs.axis("equal");  # trailing semicolon suppresses the axis-limits tuple repr
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:pendulum)=\n", "## Example: tracking a 1d pendulum\n", "\n", "```{figure} /figures/pendulum.png\n", ":scale: 100%\n", ":name: fig:pendulum\n", "\n", "Illustration of a pendulum swinging.\n", "$g$ is the force of gravity,\n", "$w(t)$ is a random external force,\n", "and $\\alpha$ is the angle wrt the vertical.\n", "Based on {cite}`Sarkka13` fig 3.10.\n", "\n", "```\n", "\n", "\n", "% Sarkka p45, p74\n", "Consider a simple pendulum of unit mass and length swinging from\n", "a fixed attachment, as in {ref}`fig:pendulum`.\n", "Such an object is in principle entirely deterministic in its behavior.\n", "However, in the real world, there are often unknown forces at work\n", "(e.g., air turbulence, friction).\n", "We will model these by a continuous time random Gaussian noise process $w(t)$.\n", "This gives rise to the following differential equation:\n", "\\begin{align}\n", "\\frac{d^2 \\alpha}{d t^2}\n", "= -g \\sin(\\alpha) + w(t)\n", "\\end{align}\n", "We can write this as a nonlinear SSM by defining the state to be\n", "$z_1(t) = \\alpha(t)$ and $z_2(t) = d\\alpha(t)/dt$.\n", "Thus\n", "\\begin{align}\n", "\\frac{d \\vz}{dt}\n", "= \\begin{pmatrix} z_2 \\\\ -g \\sin(z_1) \\end{pmatrix}\n", "+ \\begin{pmatrix} 0 \\\\ 1 \\end{pmatrix} w(t)\n", "\\end{align}\n", "If we discretize this with step size $\\Delta$,\n", "we get the following\n", "formulation {cite}`Sarkka13` p74:\n", "\\begin{align}\n", "\\underbrace{\n", " \\begin{pmatrix} z_{1,t} \\\\ z_{2,t} \\end{pmatrix}\n", " }_{\\hmmhid_t}\n", "=\n", "\\underbrace{\n", " \\begin{pmatrix} z_{1,t-1} + z_{2,t-1} \\Delta \\\\\n", " z_{2,t-1} -g \\sin(z_{1,t-1}) \\Delta \\end{pmatrix}\n", " }_{\\vf(\\hmmhid_{t-1})}\n", "+\\vq_{t-1}\n", "\\end{align}\n", "where $\\vq_{t-1} \\sim \\gauss(\\vzero,\\vQ)$ with\n", "\\begin{align}\n", "\\vQ = q^c \\begin{pmatrix}\n", " \\frac{\\Delta^3}{3} & \\frac{\\Delta^2}{2} \\\\\n", " \\frac{\\Delta^2}{2} & \\Delta\n", " 
\\end{pmatrix}\n", " \\end{align}\n", "where $q^c$ is the spectral density (continuous time variance)\n", "of the continuous-time noise process.\n", "\n", "\n", "If we observe the angular position, we\n", "get the linear observation model\n", "\\begin{align}\n", "y_t = \\alpha_t + r_t = h(\\hmmhid_t) + r_t\n", "\\end{align}\n", "where $h(\\hmmhid_t) = z_{1,t}$\n", "and $r_t$ is the observation noise.\n", "If we only observe the horizontal position,\n", "we get the nonlinear observation model\n", "\\begin{align}\n", "y_t = \\sin(\\alpha_t) + r_t = h(\\hmmhid_t) + r_t\n", "\\end{align}\n", "where $h(\\hmmhid_t) = \\sin(z_{1,t})$.\n", "\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:inference)=\n", "# Inferential goals\n", "\n", "```{figure} /figures/inference-problems-tikz.png\n", ":scale: 100%\n", ":name: fig:dbn-inference\n", "\n", "Illustration of the different kinds of inference in an SSM.\n", " The main kinds of inference for state-space models.\n", " The shaded region is the interval for which we have data.\n", " The arrow represents the time step at which we want to perform inference.\n", " $t$ is the current time, $T$ is the sequence length,\n", "$\\ell$ is the lag and $h$ is the prediction horizon.\n", "```\n", "\n", "\n", "\n", "Given the sequence of observations, and a known model,\n", "one of the main tasks with SSMs\n", "is to perform posterior inference\n", "about the hidden states; this is also called\n", "state estimation.\n", "At each time step $t$,\n", "there are multiple forms of posterior we may be interested in computing,\n", "including the following:\n", "- the filtering distribution\n", "$p(\\hmmhid_t|\\hmmobs_{1:t})$\n", "- the smoothing distribution\n", "$p(\\hmmhid_t|\\hmmobs_{1:T})$ (note that this conditions on future data $T>t$)\n", "- the fixed-lag smoothing distribution\n", "$p(\\hmmhid_{t-\\ell}|\\hmmobs_{1:t})$ (note that this\n", "infers $\\ell$ steps in the past given data up to the 
present).\n", "\n", "We may also want to compute the\n", "predictive distribution $h$ steps into the future:\n", "\\begin{align}\n", "p(\\hmmobs_{t+h}|\\hmmobs_{1:t})\n", "&= \\sum_{\\hmmhid_{t+h}} p(\\hmmobs_{t+h}|\\hmmhid_{t+h}) p(\\hmmhid_{t+h}|\\hmmobs_{1:t})\n", "\\end{align}\n", "where the hidden state predictive distribution is\n", "\\begin{align}\n", "p(\\hmmhid_{t+h}|\\hmmobs_{1:t})\n", "&= \\sum_{\\hmmhid_{t:t+h-1}}\n", " p(\\hmmhid_t|\\hmmobs_{1:t}) \n", " p(\\hmmhid_{t+1}|\\hmmhid_{t})\n", " p(\\hmmhid_{t+2}|\\hmmhid_{t+1})\n", "\\cdots\n", " p(\\hmmhid_{t+h}|\\hmmhid_{t+h-1})\n", "\\end{align}\n", "See {ref}`fig:dbn-inference` for a summary of these distributions.\n", "\n", "In addition to computing posterior marginals,\n", "we may want to compute the most probable hidden sequence,\n", "i.e., the joint MAP estimate\n", "```{math}\n", "\\arg \\max_{\\hmmhid_{1:T}} p(\\hmmhid_{1:T}|\\hmmobs_{1:T})\n", "```\n", "or sample sequences from the posterior\n", "```{math}\n", "\\hmmhid_{1:T} \\sim p(\\hmmhid_{1:T}|\\hmmobs_{1:T})\n", "```\n", "\n", "Algorithms for all these tasks are discussed in the following chapters,\n", "since the details depend on the form of the SSM.\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example: inference in the casino HMM\n", "\n", "We now illustrate filtering, smoothing and MAP decoding applied\n", "to the casino HMM from {ref}`sec:casino`. \n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:5256: UserWarning: Explicitly requested dtype requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. 
def find_dishonest_intervals(z_hist):
    """Return the maximal runs of timesteps during which the state equals 1.

    Parameters
    ----------
    z_hist : sequence of int
        State at each timestep.

    Returns
    -------
    list of (int, int)
        One ``(start, end)`` pair of 0-based, inclusive indices per run, in
        chronological order. A run that is still open at the last timestep is
        closed at ``len(z_hist) - 1`` rather than being dropped. An empty
        sequence yields an empty list.
    """
    spans = []
    run_start = None  # index at which the current run of state 1 began, if any
    for t, state in enumerate(z_hist):
        if state == 1:
            if run_start is None:
                run_start = t
        elif run_start is not None:
            spans.append((run_start, t - 1))
            run_start = None
    # Close a run that lasts until the end of the sequence.
    if run_start is not None:
        spans.append((run_start, len(z_hist) - 1))
    return spans


def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False):
    """Plot per-observation inference results and shade the runs where ``z_hist`` is 1.

    Parameters
    ----------
    inference_values : array
        If ``map_estimate`` is False, a 2-d array whose column ``state`` is plotted.
        If True, a 1-d sequence drawn as a step function.
    z_hist : sequence of int
        State sequence used to locate the shaded intervals.
    ax : matplotlib Axes
        Axes to draw on; modified in place.
    state : int, default 1
        Column of ``inference_values`` to plot when ``map_estimate`` is False.
    map_estimate : bool, default False
        Selects the step-plot mode described above.
    """
    n_samples = len(inference_values)
    # Observations are numbered from 1 on the x axis.
    xspan = np.arange(1, n_samples + 1)
    spans = find_dishonest_intervals(z_hist)

    if map_estimate:
        ax.step(xspan, inference_values, where="post")
    else:
        ax.plot(xspan, inference_values[:, state])

    for start, end in spans:
        # Spans hold 0-based indices; shift by one to match the 1-based x axis.
        ax.axvspan(start + 1, end + 1, alpha=0.5, facecolor="tab:gray", edgecolor="none")

    ax.set_xlim(1, n_samples)
    # Small margin so that values of exactly 0 and 1 stay visible.
    ax.set_ylim(-0.1, 1.1)
    ax.set_xlabel("Observation number")
# Filtering: plot column 1 of the filtered distribution (labelled "p(loaded)"),
# with the runs where z_hist equals 1 shaded by plot_inference.
fig, ax = plt.subplots()
plot_inference(filtered_dist, z_hist, ax)
ax.set_ylabel("p(loaded)")
ax.set_title("Filtered");  # trailing semicolon suppresses the Text(...) repr
# Smoothing: plot column 1 of the smoothed distribution (labelled "p(loaded)").
fig, ax = plt.subplots()
plot_inference(smoothed_dist, z_hist, ax)
ax.set_ylabel("p(loaded)")
ax.set_title("Smoothed");  # trailing semicolon suppresses the Text(...) repr

# MAP estimation: the path returned by hmm.viterbi, drawn as a step function.
fig, ax = plt.subplots()
plot_inference(map_path, z_hist, ax, map_estimate=True)
ax.set_ylabel("MAP state")
ax.set_title("Viterbi");

# TODO: posterior samples