Browse Source

add macros and more content

Kevin P Murphy 3 years ago
parent
commit
d323ba98a2

+ 1 - 1
_toc.yml

@@ -4,7 +4,6 @@
 format: jb-book
 root: root
 chapters:
-- file: chapters/scratch
 
 - file: chapters/ssm/ssm_index
   sections:
@@ -13,6 +12,7 @@ chapters:
   - file: chapters/ssm/lds
   - file: chapters/ssm/nlds
   - file: chapters/ssm/inference
+  - file: chapters/ssm/learning
 
 - file: chapters/hmm/hmm_index
   sections:

+ 0 - 59
chapters/blank.ipynb

@@ -1,59 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "(chap:my-chap)=\n",
-    "# Chapter title\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "42\n"
-     ]
-    }
-   ],
-   "source": [
-    "print(42)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": []
-  }
- ],
- "metadata": {
-  "interpreter": {
-   "hash": "6407c60499271029b671b4ff687c4ed4626355c45fd34c44476827f4be42c4d7"
-  },
-  "kernelspec": {
-   "display_name": "Python 3.9.2 ('spyder-dev')",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.9.2"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}

+ 0 - 612
chapters/hmm/hmm.ipynb

@@ -1,612 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "(sec:hmm-ex)=\n",
-    "# Hidden Markov Models\n",
-    "\n",
-    "In this section, we introduce Hidden Markov Models (HMMs)."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Boilerplate"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Collecting jax[cpu]\n",
-      "  Downloading jax-0.3.5.tar.gz (946 kB)\n",
-      "\u001b[K     |████████████████████████████████| 946 kB 2.7 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting absl-py\n",
-      "  Downloading absl_py-1.0.0-py3-none-any.whl (126 kB)\n",
-      "\u001b[K     |████████████████████████████████| 126 kB 47.7 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting numpy>=1.19\n",
-      "  Downloading numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl (17.6 MB)\n",
-      "\u001b[K     |████████████████████████████████| 17.6 MB 47.5 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting opt_einsum\n",
-      "  Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)\n",
-      "Collecting scipy>=1.2.1\n",
-      "  Downloading scipy-1.8.0-cp38-cp38-macosx_12_0_universal2.macosx_10_9_x86_64.whl (55.3 MB)\n",
-      "\u001b[K     |████████████████████████████████| 55.3 MB 73.1 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting typing_extensions\n",
-      "  Using cached typing_extensions-4.1.1-py3-none-any.whl (26 kB)\n",
-      "Collecting jaxlib==0.3.5\n",
-      "  Downloading jaxlib-0.3.5-cp38-none-macosx_10_9_x86_64.whl (70.5 MB)\n",
-      "\u001b[K     |████████████████████████████████| 70.5 MB 723 kB/s  eta 0:00:01\n",
-      "\u001b[?25hCollecting flatbuffers<3.0,>=1.12\n",
-      "  Using cached flatbuffers-2.0-py2.py3-none-any.whl (26 kB)\n",
-      "Requirement already satisfied: six in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from absl-py->jax[cpu]) (1.16.0)\n",
-      "Building wheels for collected packages: jax\n",
-      "  Building wheel for jax (setup.py) ... \u001b[?25ldone\n",
-      "\u001b[?25h  Created wheel for jax: filename=jax-0.3.5-py3-none-any.whl size=1095861 sha256=6886baa70817bbac3b5797b3720dcb81e49097f61cee7d2e1255823ea32ccad8\n",
-      "  Stored in directory: /Users/kpmurphy/Library/Caches/pip/wheels/05/30/aa/908988293721511b4b29e0aadf9b5d133d0f14f6c0a188e764\n",
-      "Successfully built jax\n",
-      "Installing collected packages: numpy, typing-extensions, scipy, opt-einsum, flatbuffers, absl-py, jaxlib, jax\n",
-      "Successfully installed absl-py-1.0.0 flatbuffers-2.0 jax-0.3.5 jaxlib-0.3.5 numpy-1.22.3 opt-einsum-3.3.0 scipy-1.8.0 typing-extensions-4.1.1\n",
-      "Note: you may need to restart the kernel to use updated packages.\n",
-      "Collecting git+https://github.com/probml/jsl\n",
-      "  Cloning https://github.com/probml/jsl to /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-req-build-i8seqdiw\n",
-      "  Running command git clone -q https://github.com/probml/jsl /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-req-build-i8seqdiw\n",
-      "Collecting chex\n",
-      "  Downloading chex-0.1.2-py3-none-any.whl (72 kB)\n",
-      "\u001b[K     |████████████████████████████████| 72 kB 1.3 MB/s eta 0:00:011\n",
-      "\u001b[?25hCollecting dataclasses\n",
-      "  Using cached dataclasses-0.6-py3-none-any.whl (14 kB)\n",
-      "Requirement already satisfied: jaxlib in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jsl==0.0.0) (0.3.5)\n",
-      "Requirement already satisfied: jax in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jsl==0.0.0) (0.3.5)\n",
-      "Collecting matplotlib\n",
-      "  Downloading matplotlib-3.5.1-cp38-cp38-macosx_10_9_x86_64.whl (7.3 MB)\n",
-      "\u001b[K     |████████████████████████████████| 7.3 MB 3.9 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting tensorflow_probability\n",
-      "  Using cached tensorflow_probability-0.16.0-py2.py3-none-any.whl (6.3 MB)\n",
-      "Collecting dm-tree>=0.1.5\n",
-      "  Using cached dm_tree-0.1.6-cp38-cp38-macosx_10_14_x86_64.whl (95 kB)\n",
-      "Requirement already satisfied: absl-py>=0.9.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from chex->jsl==0.0.0) (1.0.0)\n",
-      "Requirement already satisfied: numpy>=1.18.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from chex->jsl==0.0.0) (1.22.3)\n",
-      "Collecting toolz>=0.9.0\n",
-      "  Downloading toolz-0.11.2-py3-none-any.whl (55 kB)\n",
-      "\u001b[K     |████████████████████████████████| 55 kB 11.3 MB/s eta 0:00:01\n",
-      "\u001b[?25hRequirement already satisfied: six in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from absl-py>=0.9.0->chex->jsl==0.0.0) (1.16.0)\n",
-      "Requirement already satisfied: typing-extensions in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (4.1.1)\n",
-      "Requirement already satisfied: scipy>=1.2.1 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (1.8.0)\n",
-      "Requirement already satisfied: opt-einsum in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (3.3.0)\n",
-      "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jaxlib->jsl==0.0.0) (2.0)\n",
-      "Collecting cycler>=0.10\n",
-      "  Downloading cycler-0.11.0-py3-none-any.whl (6.4 kB)\n",
-      "Collecting kiwisolver>=1.0.1\n",
-      "  Downloading kiwisolver-1.4.2-cp38-cp38-macosx_10_9_x86_64.whl (65 kB)\n",
-      "\u001b[K     |████████████████████████████████| 65 kB 8.5 MB/s  eta 0:00:01\n",
-      "\u001b[?25hCollecting fonttools>=4.22.0\n",
-      "  Downloading fonttools-4.32.0-py3-none-any.whl (900 kB)\n",
-      "\u001b[K     |████████████████████████████████| 900 kB 35.8 MB/s eta 0:00:01\n",
-      "\u001b[?25hRequirement already satisfied: pyparsing>=2.2.1 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (3.0.7)\n",
-      "Requirement already satisfied: packaging>=20.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (21.3)\n",
-      "Requirement already satisfied: python-dateutil>=2.7 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (2.8.2)\n",
-      "Collecting pillow>=6.2.0\n",
-      "  Downloading Pillow-9.1.0-cp38-cp38-macosx_10_9_x86_64.whl (3.1 MB)\n",
-      "\u001b[K     |████████████████████████████████| 3.1 MB 76.6 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting gast>=0.3.2\n",
-      "  Downloading gast-0.5.3-py3-none-any.whl (19 kB)\n",
-      "Requirement already satisfied: decorator in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from tensorflow_probability->jsl==0.0.0) (5.1.1)\n",
-      "Collecting cloudpickle>=1.3\n",
-      "  Downloading cloudpickle-2.0.0-py3-none-any.whl (25 kB)\n",
-      "Building wheels for collected packages: jsl\n",
-      "  Building wheel for jsl (setup.py) ... \u001b[?25ldone\n",
-      "\u001b[?25h  Created wheel for jsl: filename=jsl-0.0.0-py3-none-any.whl size=77852 sha256=e7365293dc97e2b3e72bf42cc19db7d7e355abec312fc4d87961fa2044fa06f0\n",
-      "  Stored in directory: /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-ephem-wheel-cache-63vxzlng/wheels/ed/8b/bf/0105dc839fecf1fc8db14f7267a6ce5ee876324b58565b359f\n",
-      "Successfully built jsl\n",
-      "Installing collected packages: toolz, pillow, kiwisolver, gast, fonttools, dm-tree, cycler, cloudpickle, tensorflow-probability, matplotlib, dataclasses, chex, jsl\n",
-      "Successfully installed chex-0.1.2 cloudpickle-2.0.0 cycler-0.11.0 dataclasses-0.6 dm-tree-0.1.6 fonttools-4.32.0 gast-0.5.3 jsl-0.0.0 kiwisolver-1.4.2 matplotlib-3.5.1 pillow-9.1.0 tensorflow-probability-0.16.0 toolz-0.11.2\n",
-      "Note: you may need to restart the kernel to use updated packages.\n",
-      "Collecting rich\n",
-      "  Downloading rich-12.2.0-py3-none-any.whl (229 kB)\n",
-      "\u001b[K     |████████████████████████████████| 229 kB 2.9 MB/s eta 0:00:01\n",
-      "\u001b[?25hRequirement already satisfied: typing-extensions<5.0,>=4.0.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from rich) (4.1.1)\n",
-      "Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from rich) (2.11.2)\n",
-      "Collecting commonmark<0.10.0,>=0.9.0\n",
-      "  Using cached commonmark-0.9.1-py2.py3-none-any.whl (51 kB)\n",
-      "Installing collected packages: commonmark, rich\n",
-      "Successfully installed commonmark-0.9.1 rich-12.2.0\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "# Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Import standard libraries\n",
-    "\n",
-    "import abc\n",
-    "from dataclasses import dataclass\n",
-    "import functools\n",
-    "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
-    "import matplotlib.pyplot as plt\n",
-    "import numpy as np\n",
-    "\n",
-    "\n",
-    "import jax\n",
-    "import jax.numpy as jnp\n",
-    "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
-    "\n",
-    "import inspect\n",
-    "import inspect as py_inspect\n",
-    "from rich import inspect as r_inspect\n",
-    "from rich import print as r_print\n",
-    "\n",
-    "def print_source(fname):\n",
-    "    r_print(py_inspect.getsource(fname))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Utility code"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "\n",
-    "\n",
-    "def normalize(u, axis=0, eps=1e-15):\n",
-    "    '''\n",
-    "    Normalizes the values within the axis in a way that they sum up to 1.\n",
-    "    Parameters\n",
-    "    ----------\n",
-    "    u : array\n",
-    "    axis : int\n",
-    "    eps : float\n",
-    "        Threshold for the alpha values\n",
-    "    Returns\n",
-    "    -------\n",
-    "    * array\n",
-    "        Normalized version of the given matrix\n",
-    "    * array(seq_len, n_hidden) :\n",
-    "        The values of the normalizer\n",
-    "    '''\n",
-    "    u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))\n",
-    "    c = u.sum(axis=axis)\n",
-    "    c = jnp.where(c == 0, 1, c)\n",
-    "    return u / c, c"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "(sec:casino-ex)=\n",
-    "## Example: Casino HMM\n",
-    "\n",
-    "We first create the \"Ocassionally dishonest casino\" model from {cite}`Durbin98`.\n",
-    "\n",
-    "```{figure} /figures/casino.png\n",
-    ":scale: 50%\n",
-    ":name: casino-fig\n",
-    "\n",
-    "Illustration of the casino HMM.\n",
-    "```\n",
-    "\n",
-    "There are 2 hidden states, each of which emit 6 possible observations."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
-     ]
-    }
-   ],
-   "source": [
-    "# state transition matrix\n",
-    "A = np.array([\n",
-    "    [0.95, 0.05],\n",
-    "    [0.10, 0.90]\n",
-    "])\n",
-    "\n",
-    "# observation matrix\n",
-    "B = np.array([\n",
-    "    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die\n",
-    "    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die\n",
-    "])\n",
-    "\n",
-    "pi, _ = normalize(np.array([1, 1]))\n",
-    "pi = np.array(pi)\n",
-    "\n",
-    "\n",
-    "(nstates, nobs) = np.shape(B)\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Let's make a little data structure to store all the parameters.\n",
-    "We use NamedTuple rather than dataclass, since we assume these are immutable.\n",
-    "(Also, standard python dataclass does not work well with JAX, which requires parameters to be\n",
-    "pytrees, as discussed in https://github.com/google/jax/issues/2371)."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 74,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "HMM(trans_mat=array([[0.95, 0.05],\n",
-      "       [0.1 , 0.9 ]]), obs_mat=array([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,\n",
-      "        0.16666667],\n",
-      "       [0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,\n",
-      "        0.5       ]]), init_dist=array([0.5, 0.5], dtype=float32))\n",
-      "<class 'numpy.ndarray'>\n",
-      "HMM(trans_mat=DeviceArray([[0.95, 0.05],\n",
-      "             [0.1 , 0.9 ]], dtype=float32), obs_mat=DeviceArray([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,\n",
-      "              0.16666667],\n",
-      "             [0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,\n",
-      "              0.5       ]], dtype=float32), init_dist=DeviceArray([0.5, 0.5], dtype=float32))\n",
-      "<class 'jaxlib.xla_extension.DeviceArray'>\n"
-     ]
-    }
-   ],
-   "source": [
-    "Array = Union[np.array, jnp.array]\n",
-    "\n",
-    "class HMM(NamedTuple):\n",
-    "    trans_mat: Array  # A : (n_states, n_states)\n",
-    "    obs_mat: Array  # B : (n_states, n_obs)\n",
-    "    init_dist: Array  # pi : (n_states)\n",
-    "\n",
-    "params_np = HMM(A, B, pi)\n",
-    "print(params_np)\n",
-    "print(type(params_np.trans_mat))\n",
-    "\n",
-    "\n",
-    "params = jax.tree_map(lambda x: jnp.array(x), params_np)\n",
-    "print(params)\n",
-    "print(type(params.trans_mat))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Sampling from the joint\n",
-    "\n",
-    "Let's write code to sample from this model. \n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Numpy version\n",
-    "\n",
-    "First we code it in numpy using a for loop."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 30,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def hmm_sample_np(params, seq_len, random_state=0):\n",
-    "    np.random.seed(random_state)\n",
-    "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-    "    n_states, n_obs = obs_mat.shape\n",
-    "    state_seq = np.zeros(seq_len, dtype=int)\n",
-    "    obs_seq = np.zeros(seq_len, dtype=int)\n",
-    "    for t in range(seq_len):\n",
-    "        if t==0:\n",
-    "            zt = np.random.choice(n_states, p=init_dist)\n",
-    "        else:\n",
-    "            zt = np.random.choice(n_states, p=trans_mat[zt])\n",
-    "        yt = np.random.choice(n_obs, p=obs_mat[zt])\n",
-    "        state_seq[t] = zt\n",
-    "        obs_seq[t] = yt\n",
-    "\n",
-    "    return state_seq, obs_seq"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 75,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0\n",
-      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
-      " 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
-      "[4 1 0 2 3 4 5 4 3 1 5 4 5 0 5 2 5 3 5 4 5 5 4 2 1 4 1 0 0 4 2 2 3 3 3 0 4\n",
-      " 0 2 4 3 2 5 5 3 5 3 1 3 3 3 2 3 5 5 0 4 4 5 0 0 1 3 5 1 5 0 1 2 4 0 0 0 4\n",
-      " 0 5 1 4 3 5 4 5 0 2 3 5 2 4 1 2 1 0 4 3 5 0 4 5 1 5]\n"
-     ]
-    }
-   ],
-   "source": [
-    "seq_len = 100\n",
-    "state_seq, obs_seq = hmm_sample_np(params_np, seq_len, random_state=1)\n",
-    "print(state_seq)\n",
-    "print(obs_seq)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### JAX version\n",
-    "\n",
-    "Now let's write a JAX version using jax.lax.scan (for the inter-dependent states) and vmap (for the observations).\n",
-    "This is harder to read than the numpy version, but faster."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 91,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#@partial(jit, static_argnums=(1,))\n",
-    "def markov_chain_sample(rng_key, init_dist, trans_mat, seq_len):\n",
-    "    n_states = len(init_dist)\n",
-    "\n",
-    "    def draw_state(prev_state, key):\n",
-    "        state = jax.random.choice(key, n_states, p=trans_mat[prev_state])\n",
-    "        return state, state\n",
-    "\n",
-    "    rng_key, rng_state = jax.random.split(rng_key, 2)\n",
-    "    keys = jax.random.split(rng_state, seq_len - 1)\n",
-    "    initial_state = jax.random.choice(rng_key, n_states, p=init_dist)\n",
-    "    final_state, states = jax.lax.scan(draw_state, initial_state, keys)\n",
-    "    state_seq = jnp.append(jnp.array([initial_state]), states)\n",
-    "\n",
-    "    return state_seq"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 90,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#@partial(jit, static_argnums=(1,))\n",
-    "def hmm_sample(rng_key, params, seq_len):\n",
-    "\n",
-    "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-    "    n_states, n_obs = obs_mat.shape\n",
-    "    rng_key, rng_obs = jax.random.split(rng_key, 2)\n",
-    "    state_seq = markov_chain_sample(rng_key, init_dist, trans_mat, seq_len)\n",
-    "\n",
-    "    def draw_obs(z, key):\n",
-    "        obs = jax.random.choice(key, n_obs, p=obs_mat[z])\n",
-    "        return obs\n",
-    "\n",
-    "    keys = jax.random.split(rng_obs, seq_len)\n",
-    "    obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)\n",
-    "    \n",
-    "    return state_seq, obs_seq"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 70,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#@partial(jit, static_argnums=(1,))\n",
-    "def hmm_sample2(rng_key, params, seq_len):\n",
-    "\n",
-    "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-    "    n_states, n_obs = obs_mat.shape\n",
-    "\n",
-    "    def draw_state(prev_state, key):\n",
-    "        state = jax.random.choice(key, n_states, p=trans_mat[prev_state])\n",
-    "        return state, state\n",
-    "\n",
-    "    rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3)\n",
-    "    keys = jax.random.split(rng_state, seq_len - 1)\n",
-    "    initial_state = jax.random.choice(rng_key, n_states, p=init_dist)\n",
-    "    final_state, states = jax.lax.scan(draw_state, initial_state, keys)\n",
-    "    state_seq = jnp.append(jnp.array([initial_state]), states)\n",
-    "\n",
-    "    def draw_obs(z, key):\n",
-    "        obs = jax.random.choice(key, n_obs, p=obs_mat[z])\n",
-    "        return obs\n",
-    "\n",
-    "    keys = jax.random.split(rng_obs, seq_len)\n",
-    "    obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)\n",
-    "\n",
-    "    return state_seq, obs_seq"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 93,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
-      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1\n",
-      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
-      "[5 5 2 2 0 0 0 1 3 3 2 2 5 1 5 1 0 2 2 4 2 5 1 5 5 0 0 4 2 4 3 2 3 4 1 0 5\n",
-      " 2 2 2 1 4 3 2 2 2 4 1 0 3 5 2 5 1 4 2 5 2 5 0 5 4 4 4 2 2 0 4 5 2 2 0 1 5\n",
-      " 1 3 4 5 1 5 0 5 1 5 1 2 4 5 3 4 5 4 0 4 0 2 4 5 3 3]\n"
-     ]
-    }
-   ],
-   "source": [
-    "\n",
-    "key = PRNGKey(2)\n",
-    "seq_len = 100\n",
-    "\n",
-    "state_seq, obs_seq = hmm_sample(key, params, seq_len)\n",
-    "print(state_seq)\n",
-    "print(obs_seq)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Check correctness by computing empirical pairwise statistics\n",
-    "\n",
-    "We will compute the number of i->j transitions, and check that it is close to the true \n",
-    "A[i,j] transition probabilites."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 107,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[0 0 1 1 1 1 0 0 1 1 1 0 1 0 0 1 1 1 1 0 1 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 1\n",
-      " 1 0 0 0 0 1 0 1 0 0 0 0 1 0 0 1 1 0 1 1 0 1 1 0 1 1 1 0 0 1 1 0 1 0 0 1 0\n",
-      " 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 1 1 1 0 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0\n",
-      " 0 0 1 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 1 1 0 1 1 0 0 0 0 0 0 1 0\n",
-      " 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 1 1 1 0 1 1 0 0 0 0 0 1 0 0 0 0\n",
-      " 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 1 1 1 0 0 0 1 1 0 0 0\n",
-      " 0 0 0 1 1 1 0 0 0 0 1 0 0 1 1 1 0 1 1 1 1 1 0 1 1 0 0 0 1 1 0 1 0 0 1 0 0\n",
-      " 0 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 1 0 0 1 1 1 1\n",
-      " 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 1 0 1 0 0 0 1\n",
-      " 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0\n",
-      " 0 0 0 0 0 0 0 1 0 0 1 1 1 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 1 0\n",
-      " 1 0 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0\n",
-      " 0 0 0 0 1 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 1 1 0 0 1 1 0 0 1\n",
-      " 1 0 0 0 0 0 0 0 1 0 0 1 1 0 0 0 0 1 1]\n",
-      "[[244.  93.]\n",
-      " [ 92.  70.]]\n",
-      "[[0.7240356  0.27596438]\n",
-      " [0.56790125 0.43209878]]\n"
-     ]
-    }
-   ],
-   "source": [
-    "import collections\n",
-    "def compute_counts(state_seq, nstates):\n",
-    "    wseq = np.array(state_seq)\n",
-    "    word_pairs = [pair for pair in zip(wseq[:-1], wseq[1:])]\n",
-    "    counter_pairs = collections.Counter(word_pairs)\n",
-    "    counts = np.zeros((nstates, nstates))\n",
-    "    for (k,v) in counter_pairs.items():\n",
-    "        counts[k[0], k[1]] = v\n",
-    "    return counts\n",
-    "\n",
-    "def normalize_counts(counts):\n",
-    "    ncounts = vmap(lambda v: normalize(v)[0], in_axes=0)(counts)\n",
-    "    return ncounts\n",
-    "\n",
-    "init_dist = jnp.array([1.0, 0.0])\n",
-    "trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]])\n",
-    "rng_key = jax.random.PRNGKey(0)\n",
-    "seq_len = 500\n",
-    "state_seq = markov_chain_sample(rng_key, init_dist, trans_mat, seq_len)\n",
-    "print(state_seq)\n",
-    "\n",
-    "counts = compute_counts(state_seq, nstates=2)\n",
-    "print(counts)\n",
-    "\n",
-    "trans_mat_empirical = normalize_counts(counts)\n",
-    "print(trans_mat_empirical)\n",
-    "\n",
-    "assert jnp.allclose(trans_mat, trans_mat_empirical, atol=1e-1)\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.8"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}

+ 253 - 405
chapters/hmm/hmm_filter.ipynb

@@ -4,171 +4,8 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand{\\defeq}{\\triangleq}\n",
-    "\\newcommand{\\trans}{{\\mkern-1.5mu\\mathsf{T}}}\n",
-    "\\newcommand{\\transpose}[1]{{#1}^{\\trans}}\n",
-    "\n",
-    "\\newcommand{\\inv}[1]{{#1}^{-1}}\n",
-    "\\DeclareMathOperator{\\dotstar}{\\odot}\n",
-    "\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
+    "(sec:forwards)=\n",
+    "# HMM filtering (forwards algorithm)"
    ]
   },
   {
@@ -177,141 +14,85 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "#try:\n",
-    "#    import ssm_jax\n",
-    "##except:\n",
-    "#    %pip install git+https://github.com/probml/ssm-jax\n",
-    "#    import ssm_jax\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
     "### Import standard libraries\n",
     "\n",
     "import abc\n",
     "from dataclasses import dataclass\n",
     "import functools\n",
+    "from functools import partial\n",
     "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
     "\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
     "\n",
-    "import inspect\n",
-    "import inspect as py_inspect\n",
-    "import rich\n",
-    "from rich import inspect as r_inspect\n",
-    "from rich import print as r_print\n",
     "\n",
-    "def print_source(fname):\n",
-    "    r_print(py_inspect.getsource(fname))"
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "(sec:forwards)=\n",
-    "# HMM filtering (forwards algorithm)\n",
+    "\n",
+    "## Introduction\n",
     "\n",
     "\n",
-    "The  **Bayes filter** is an algorithm for recursively computing\n",
+    "The  $\\keyword{Bayes filter}$ is an algorithm for recursively computing\n",
     "the belief state\n",
     "$p(\\hidden_t|\\obs_{1:t})$ given\n",
     "the prior belief from the previous step,\n",
     "$p(\\hidden_{t-1}|\\obs_{1:t-1})$,\n",
     "the new observation $\\obs_t$,\n",
     "and the model.\n",
-    "This can be done using **sequential Bayesian updating**.\n",
+    "This can be done using $\\keyword{sequential Bayesian updating}$.\n",
     "For a dynamical model, this reduces to the\n",
-    "**predict-update** cycle described below.\n",
+    "$\\keyword{predict-update}$ cycle described below.\n",
     "\n",
-    "\n",
-    "The **prediction step** is just the **Chapman-Kolmogorov equation**:\n",
-    "```{math}\n",
+    "The $\\keyword{prediction step}$ is just the $\\keyword{Chapman-Kolmogorov equation}$:\n",
+    "\\begin{align}\n",
     "p(\\hidden_t|\\obs_{1:t-1})\n",
     "= \\int p(\\hidden_t|\\hidden_{t-1}) p(\\hidden_{t-1}|\\obs_{1:t-1}) d\\hidden_{t-1}\n",
-    "```\n",
+    "\\end{align}\n",
     "The prediction step computes\n",
-    "the one-step-ahead predictive distribution\n",
-    "for the latent state, which updates\n",
-    "the posterior from the previous time step into the prior\n",
+    "the $\\keyword{one-step-ahead predictive distribution}$\n",
+    "for the latent state, which converts\n",
+    "the posterior from the previous time step to become the prior\n",
     "for the current step.\n",
     "\n",
     "\n",
-    "The **update step**\n",
+    "The $\\keyword{update step}$\n",
     "is just Bayes rule:\n",
-    "```{math}\n",
+    "\\begin{align}\n",
     "p(\\hidden_t|\\obs_{1:t}) = \\frac{1}{Z_t}\n",
     "p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1})\n",
-    "```\n",
+    "\\end{align}\n",
     "where the normalization constant is\n",
-    "```{math}\n",
+    "\\begin{align}\n",
     "Z_t = \\int p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1}) d\\hidden_{t}\n",
     "= p(\\obs_t|\\obs_{1:t-1})\n",
-    "```\n",
+    "\\end{align}\n",
     "\n",
+    "Note that we can derive the log marginal likelihood from these normalization constants\n",
+    "as follows:\n",
+    "```{math}\n",
+    ":label: eqn:logZ\n",
     "\n",
+    "\\log p(\\obs_{1:T})\n",
+    "= \\sum_{t=1}^{T} \\log p(\\obs_t|\\obs_{1:t-1})\n",
+    "= \\sum_{t=1}^{T} \\log Z_t\n",
+    "```\n",
     "\n"
    ]
   },
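To make the predict-update cycle concrete, here is a minimal numpy sketch of a single cycle for a two-state discrete model (illustrative numbers loosely based on the casino HMM; the names `A`, `alpha_prev`, and `lam_t` are our own, not from the book's library):

```python
import numpy as np

# Hypothetical two-state model (illustrative numbers only).
A = np.array([[0.95, 0.05],        # A[i, j] = p(z_t = j | z_{t-1} = i)
              [0.10, 0.90]])
alpha_prev = np.array([0.6, 0.4])  # p(z_{t-1} | y_{1:t-1}): previous belief state
lam_t = np.array([1/6, 5/10])      # lambda_t(j) = p(y_t | z_t = j), e.g. observing a six

# Predict step (Chapman-Kolmogorov; the integral becomes a sum over states):
alpha_pred = A.T @ alpha_prev      # p(z_t | y_{1:t-1}) = [0.61, 0.39]

# Update step (Bayes rule):
unnorm = lam_t * alpha_pred
Z_t = unnorm.sum()                 # Z_t = p(y_t | y_{1:t-1}), about 0.297
alpha_t = unnorm / Z_t             # p(z_t | y_{1:t}), about [0.343, 0.657]
print(alpha_pred, Z_t, alpha_t)
```

Summing `np.log(Z_t)` over all time steps yields the log marginal likelihood $\log p(y_{1:T})$ derived above.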
@@ -323,10 +104,11 @@
     "When the latent states $\\hidden_t$ are discrete, as in HMM,\n",
     "the above integrals become sums.\n",
     "In particular, suppose we define\n",
-    "the belief state as $\\alpha_t(j) \\defeq p(\\hidden_t=j|\\obs_{1:t})$,\n",
-    "the local evidence as $\\lambda_t(j) \\defeq p(\\obs_t|\\hidden_t=j)$,\n",
-    "and the transition matrix\n",
-    "$A(i,j)  = p(\\hidden_t=j|\\hidden_{t-1}=i)$.\n",
+    "the $\\keyword{belief state}$ as $\\alpha_t(j) \\defeq p(\\hidden_t=j|\\obs_{1:t})$,\n",
+    "the  $\\keyword{local evidence}$ (or $\\keyword{local likelihood}$)\n",
+    "as $\\lambda_t(j) \\defeq p(\\obs_t|\\hidden_t=j)$,\n",
+    "and the transition matrix as\n",
+    "$\\hmmTrans(i,j)  = p(\\hidden_t=j|\\hidden_{t-1}=i)$.\n",
     "Then the predict step becomes\n",
     "```{math}\n",
     ":label: eqn:predictiveHMM\n",
@@ -338,7 +120,7 @@
     ":label: eqn:fwdsEqn\n",
     "\\alpha_t(j)\n",
     "= \\frac{1}{Z_t} \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
-    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) A(i,j)  \\right]\n",
+    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) \\hmmTrans(i,j)  \\right]\n",
     "```\n",
     "where\n",
     "the  normalization constant for each time step is given by\n",
@@ -350,20 +132,33 @@
     "&=  \\sum_{j=1}^K \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
     "\\end{align}\n",
     "```\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
     "\n",
     "Since all the quantities are finite length vectors and matrices,\n",
-    "we can write the update equation\n",
-    "in matrix-vector notation as follows:\n",
+    "we can implement the whole procedure using matrix vector multoplication:\n",
     "```{math}\n",
+    ":label: eqn:fwdsAlgoMatrixForm\n",
     "\\valpha_t =\\text{normalize}\\left(\n",
-    "\\vlambda_t \\dotstar  (\\vA^{\\trans} \\valpha_{t-1}) \\right)\n",
-    "\\label{eqn:fwdsAlgoMatrixForm}\n",
+    "\\vlambda_t \\dotstar  (\\hmmTrans^{\\trans} \\valpha_{t-1}) \\right)\n",
     "```\n",
     "where $\\dotstar$ represents\n",
     "elementwise vector multiplication,\n",
-    "and the $\\text{normalize}$ function just ensures its argument sums to one.\n",
+    "and the $\\text{normalize}$ function just ensures its argument sums to one.\n"
+   ]
+  },
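The whole recursion is only a few lines of JAX; here is a sketch (our own function names, not the book's library) that also accumulates the log normalizers $\log Z_t$ to return the log marginal likelihood:

```python
import jax.numpy as jnp
from jax import lax

def forwards_filter(init_dist, trans_mat, likelihoods):
    """likelihoods[t, j] = p(y_t | z_t = j). Returns log p(y_{1:T}) and all alpha_t."""
    def step(carry, lam_t):
        alpha_prev, log_Z = carry
        u = lam_t * (trans_mat.T @ alpha_prev)  # lambda_t .* (A^T alpha_{t-1})
        Z_t = u.sum()                           # Z_t = p(y_t | y_{1:t-1})
        return (u / Z_t, log_Z + jnp.log(Z_t)), u / Z_t

    u0 = likelihoods[0] * init_dist             # base case: condition the prior pi on y_1
    alpha0 = u0 / u0.sum()
    (_, log_Z), alphas = lax.scan(step, (alpha0, jnp.log(u0.sum())), likelihoods[1:])
    return log_Z, jnp.vstack([alpha0, alphas])
```

For example, assuming the casino HMM parameters `A`, `B`, `pi` are in scope, `forwards_filter(jnp.array(pi), jnp.array(A), jnp.array(B)[:, obs_seq].T)` filters an observation sequence `obs_seq`.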
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Example\n",
     "\n",
-    "In {ref}(sec:casino-inference)\n",
+    "In {ref}`sec:casino-inference`\n",
     "we illustrate\n",
     "filtering for the casino HMM,\n",
     "applied to a random sequence $\\obs_{1:T}$ of length $T=300$.\n",
@@ -371,156 +166,209 @@
     "based on the evidence seen so far.\n",
     "The gray bars indicate time intervals during which the generative\n",
     "process actually switched to the loaded dice.\n",
-    "We see that the probability generally increases in the right places.\n"
+    "We see that the probability generally increases in the right places."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Normalization constants\n",
+    "\n",
+    "In most publications on HMMs,\n",
+    "such as {cite}`Rabiner89`,\n",
+    "the forwards message is defined\n",
+    "as the following unnormalized joint probability:\n",
+    "```{math}\n",
+    "\\alpha'_t(j) = p(\\hidden_t=j,\\obs_{1:t}) \n",
+    "= \\lambda_t(j) \\left[\\sum_i \\alpha'_{t-1}(i) A(i,j)  \\right]\n",
+    "```\n",
+    "In this book we define the forwards message   as the normalized\n",
+    "conditional probability\n",
+    "```{math}\n",
+    "\\alpha_t(j) = p(\\hidden_t=j|\\obs_{1:t}) \n",
+    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) A(i,j)  \\right]\n",
+    "```\n",
+    "where $Z_t = p(\\obs_t|\\obs_{1:t-1})$.\n",
+    "\n",
+    "The \"traditional\" unnormalized form has several problems.\n",
+    "First, it rapidly suffers from numerical underflow,\n",
+    "since the probability of\n",
+    "the joint event that $(\\hidden_t=j,\\obs_{1:t})$\n",
+    "is vanishingly small. \n",
+    "To see why, suppose the observations are independent of the states.\n",
+    "In this case, the unnormalized joint has the form\n",
+    "\\begin{align}\n",
+    "p(\\hidden_t=j,\\obs_{1:t}) = p(\\hidden_t=j)\\prod_{i=1}^t p(\\obs_i)\n",
+    "\\end{align}\n",
+    "which becomes exponentially small with $t$, because we multiply\n",
+    "many probabilities which are less than one.\n",
+    "Second, the unnormalized probability is less interpretable,\n",
+    "since it is a joint distribution over states and observations,\n",
+    "rather than a conditional probability of states given observations.\n",
+    "Third, the unnormalized joint form is harder to approximate\n",
+    "than the normalized form.\n",
+    "Of course,\n",
+    "the two definitions only differ by a\n",
+    "multiplicative constant\n",
+    "{cite}`Devijver85`,\n",
+    "so the algorithmic difference is just\n",
+    "one line of code (namely the presence or absence of a call to the `normalize` function).\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n"
    ]
   },
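The underflow problem is easy to demonstrate numerically. In the sketch below (illustrative numbers only), a product of 500 likelihood terms underflows to zero in double precision, while the equivalent sum of logs stays perfectly well scaled:

```python
import numpy as np

rng = np.random.default_rng(0)
lam = rng.uniform(0.05, 0.2, size=500)  # stand-ins for per-step likelihood terms

print(np.prod(lam))         # 0.0 -- the unnormalized joint has underflowed
print(np.sum(np.log(lam)))  # finite (around -1100), so working in log space is safe
```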
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Here is a JAX implementation of the forwards algorithm."
+    "## Naive implementation\n",
+    "\n",
+    "Below we give a simple numpy implementation of the forwards algorithm.\n",
+    "We assume the HMM uses categorical observations, for simplicity.\n",
+    "\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">@jit\n",
-       "def <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">hmm_forwards_jax</span><span style=\"font-weight: bold\">(</span>params, obs_seq, <span style=\"color: #808000; text-decoration-color: #808000\">length</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span><span style=\"font-weight: bold\">)</span>:\n",
-       "    <span style=\"color: #008000; text-decoration-color: #008000\">''</span>'\n",
-       "    Calculates a belief state\n",
-       "\n",
-       "    Parameters\n",
-       "    ----------\n",
-       "    params : HMMJax\n",
-       "        Hidden Markov Model\n",
-       "\n",
-       "    obs_seq: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">array</span><span style=\"font-weight: bold\">(</span>seq_len<span style=\"font-weight: bold\">)</span>\n",
-       "        History of observable events\n",
-       "\n",
-       "    Returns\n",
-       "    -------\n",
-       "    * float\n",
-       "        The loglikelihood giving <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">log</span><span style=\"font-weight: bold\">(</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">p</span><span style=\"font-weight: bold\">(</span>x|model<span style=\"font-weight: bold\">))</span>\n",
-       "\n",
-       "    * <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">array</span><span style=\"font-weight: bold\">(</span>seq_len, n_hidden<span style=\"font-weight: bold\">)</span> :\n",
-       "        All alpha values found for each sample\n",
-       "    <span style=\"color: #008000; text-decoration-color: #008000\">''</span>'\n",
-       "    seq_len = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">len</span><span style=\"font-weight: bold\">(</span>obs_seq<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    if length is <span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>:\n",
-       "        length = seq_len\n",
-       "\n",
-       "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-       "\n",
-       "    trans_mat = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>trans_mat<span style=\"font-weight: bold\">)</span>\n",
-       "    obs_mat = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>obs_mat<span style=\"font-weight: bold\">)</span>\n",
-       "    init_dist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>init_dist<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    n_states, n_obs = obs_mat.shape\n",
-       "\n",
-       "    def <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">scan_fn</span><span style=\"font-weight: bold\">(</span>carry, t<span style=\"font-weight: bold\">)</span>:\n",
-       "        <span style=\"font-weight: bold\">(</span>alpha_prev, log_ll_prev<span style=\"font-weight: bold\">)</span> = carry\n",
-       "        alpha_n = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.where</span><span style=\"font-weight: bold\">(</span>t &lt; length,\n",
-       "                            obs_mat<span style=\"font-weight: bold\">[</span>:, obs_seq<span style=\"font-weight: bold\">]</span> * <span style=\"font-weight: bold\">(</span>alpha_prev<span style=\"font-weight: bold\">[</span>:, <span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span><span style=\"font-weight: bold\">]</span> * \n",
-       "trans_mat<span style=\"font-weight: bold\">)</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">.sum</span><span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">axis</span>=<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"font-weight: bold\">)</span>,\n",
-       "                            <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.zeros_like</span><span style=\"font-weight: bold\">(</span>alpha_prev<span style=\"font-weight: bold\">))</span>\n",
-       "\n",
-       "        alpha_n, cn = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">normalize</span><span style=\"font-weight: bold\">(</span>alpha_n<span style=\"font-weight: bold\">)</span>\n",
-       "        carry = <span style=\"font-weight: bold\">(</span>alpha_n, <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.log</span><span style=\"font-weight: bold\">(</span>cn<span style=\"font-weight: bold\">)</span> + log_ll_prev<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "        return carry, alpha_n\n",
-       "\n",
-       "    # initial belief state\n",
-       "    alpha_0, c0 = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">normalize</span><span style=\"font-weight: bold\">(</span>init_dist * obs_mat<span style=\"font-weight: bold\">[</span>:, obs_seq<span style=\"font-weight: bold\">[</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"font-weight: bold\">]])</span>\n",
-       "\n",
-       "    # setup scan loop\n",
-       "    init_state = <span style=\"font-weight: bold\">(</span>alpha_0, <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.log</span><span style=\"font-weight: bold\">(</span>c0<span style=\"font-weight: bold\">))</span>\n",
-       "    ts = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.arange</span><span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, seq_len<span style=\"font-weight: bold\">)</span>\n",
-       "    carry, alpha_hist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">lax.scan</span><span style=\"font-weight: bold\">(</span>scan_fn, init_state, ts<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    # post-process\n",
-       "    alpha_hist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.vstack</span><span style=\"font-weight: bold\">()</span>\n",
-       "    <span style=\"font-weight: bold\">(</span>alpha_final, log_ll<span style=\"font-weight: bold\">)</span> = carry\n",
-       "    return log_ll, alpha_hist\n",
-       "\n",
-       "</pre>\n"
-      ],
-      "text/plain": [
-       "@jit\n",
-       "def \u001b[1;35mhmm_forwards_jax\u001b[0m\u001b[1m(\u001b[0mparams, obs_seq, \u001b[33mlength\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m:\n",
-       "    \u001b[32m''\u001b[0m'\n",
-       "    Calculates a belief state\n",
-       "\n",
-       "    Parameters\n",
-       "    ----------\n",
-       "    params : HMMJax\n",
-       "        Hidden Markov Model\n",
-       "\n",
-       "    obs_seq: \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0mseq_len\u001b[1m)\u001b[0m\n",
-       "        History of observable events\n",
-       "\n",
-       "    Returns\n",
-       "    -------\n",
-       "    * float\n",
-       "        The loglikelihood giving \u001b[1;35mlog\u001b[0m\u001b[1m(\u001b[0m\u001b[1;35mp\u001b[0m\u001b[1m(\u001b[0mx|model\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    * \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0mseq_len, n_hidden\u001b[1m)\u001b[0m :\n",
-       "        All alpha values found for each sample\n",
-       "    \u001b[32m''\u001b[0m'\n",
-       "    seq_len = \u001b[1;35mlen\u001b[0m\u001b[1m(\u001b[0mobs_seq\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    if length is \u001b[3;35mNone\u001b[0m:\n",
-       "        length = seq_len\n",
-       "\n",
-       "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-       "\n",
-       "    trans_mat = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0mtrans_mat\u001b[1m)\u001b[0m\n",
-       "    obs_mat = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0mobs_mat\u001b[1m)\u001b[0m\n",
-       "    init_dist = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0minit_dist\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    n_states, n_obs = obs_mat.shape\n",
-       "\n",
-       "    def \u001b[1;35mscan_fn\u001b[0m\u001b[1m(\u001b[0mcarry, t\u001b[1m)\u001b[0m:\n",
-       "        \u001b[1m(\u001b[0malpha_prev, log_ll_prev\u001b[1m)\u001b[0m = carry\n",
-       "        alpha_n = \u001b[1;35mjnp.where\u001b[0m\u001b[1m(\u001b[0mt < length,\n",
-       "                            obs_mat\u001b[1m[\u001b[0m:, obs_seq\u001b[1m]\u001b[0m * \u001b[1m(\u001b[0malpha_prev\u001b[1m[\u001b[0m:, \u001b[3;35mNone\u001b[0m\u001b[1m]\u001b[0m * \n",
-       "trans_mat\u001b[1m)\u001b[0m\u001b[1;35m.sum\u001b[0m\u001b[1m(\u001b[0m\u001b[33maxis\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m,\n",
-       "                            \u001b[1;35mjnp.zeros_like\u001b[0m\u001b[1m(\u001b[0malpha_prev\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
-       "\n",
-       "        alpha_n, cn = \u001b[1;35mnormalize\u001b[0m\u001b[1m(\u001b[0malpha_n\u001b[1m)\u001b[0m\n",
-       "        carry = \u001b[1m(\u001b[0malpha_n, \u001b[1;35mjnp.log\u001b[0m\u001b[1m(\u001b[0mcn\u001b[1m)\u001b[0m + log_ll_prev\u001b[1m)\u001b[0m\n",
-       "\n",
-       "        return carry, alpha_n\n",
-       "\n",
-       "    # initial belief state\n",
-       "    alpha_0, c0 = \u001b[1;35mnormalize\u001b[0m\u001b[1m(\u001b[0minit_dist * obs_mat\u001b[1m[\u001b[0m:, obs_seq\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    # setup scan loop\n",
-       "    init_state = \u001b[1m(\u001b[0malpha_0, \u001b[1;35mjnp.log\u001b[0m\u001b[1m(\u001b[0mc0\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
-       "    ts = \u001b[1;35mjnp.arange\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m, seq_len\u001b[1m)\u001b[0m\n",
-       "    carry, alpha_hist = \u001b[1;35mlax.scan\u001b[0m\u001b[1m(\u001b[0mscan_fn, init_state, ts\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    # post-process\n",
-       "    alpha_hist = \u001b[1;35mjnp.vstack\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n",
-       "    \u001b[1m(\u001b[0malpha_final, log_ll\u001b[1m)\u001b[0m = carry\n",
-       "    return log_ll, alpha_hist\n",
-       "\n"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "import jsl.hmm.hmm_lib as hmm_lib\n",
-    "print_source(hmm_lib.hmm_forwards_jax)\n",
-    "#https://github.com/probml/JSL/blob/main/jsl/hmm/hmm_lib.py#L189\n",
-    "\n"
+    "\n",
+    "\n",
+    "def normalize_np(u, axis=0, eps=1e-15):\n",
+    "    u = np.where(u == 0, 0, np.where(u < eps, eps, u))\n",
+    "    c = u.sum(axis=axis)\n",
+    "    c = np.where(c == 0, 1, c)\n",
+    "    return u / c, c\n",
+    "\n",
+    "def hmm_forwards_np(trans_mat, obs_mat, init_dist, obs_seq):\n",
+    "    n_states, n_obs = obs_mat.shape\n",
+    "    seq_len = len(obs_seq)\n",
+    "\n",
+    "    alpha_hist = np.zeros((seq_len, n_states))\n",
+    "    ll_hist = np.zeros(seq_len)  # loglikelihood history\n",
+    "\n",
+    "    alpha_n = init_dist * obs_mat[:, obs_seq[0]]\n",
+    "    alpha_n, cn = normalize_np(alpha_n)\n",
+    "\n",
+    "    alpha_hist[0] = alpha_n\n",
+    "    log_normalizer = np.log(cn)\n",
+    "\n",
+    "    for t in range(1, seq_len):\n",
+    "        alpha_n = obs_mat[:, obs_seq[t]] * (alpha_n[:, None] * trans_mat).sum(axis=0)\n",
+    "        alpha_n, zn = normalize_np(alpha_n)\n",
+    "\n",
+    "        alpha_hist[t] = alpha_n\n",
+    "        log_normalizer = np.log(zn) + log_normalizer\n",
+    "\n",
+    "    return  log_normalizer, alpha_hist"
+   ]
+  },
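+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make the recursion concrete, here is a minimal usage sketch on a made-up\n",
+    "two-state HMM (all numbers below are illustrative):\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "A = np.array([[0.9, 0.1],\n",
+    "              [0.2, 0.8]])          # transition matrix\n",
+    "B = np.array([[0.7, 0.3],\n",
+    "              [0.1, 0.9]])          # observation (emission) matrix\n",
+    "pi = np.array([0.5, 0.5])           # initial state distribution\n",
+    "obs = np.array([0, 0, 1, 1, 1])     # an observation sequence\n",
+    "\n",
+    "log_Z, alphas = hmm_forwards_np(A, B, pi, obs)\n",
+    "print(log_Z)        # log p(obs_{1:T})\n",
+    "print(alphas[-1])   # filtering distribution p(z_T | obs_{1:T})\n"
+   ]
+  },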
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Numerically stable implementation \n",
+    "\n",
+    "\n",
+    "\n",
+    "In practice it is more numerically stable to compute\n",
+    "the log likelihoods $\\ell_t(j) = \\log p(\\obs_t|\\hidden_t=j)$,\n",
+    "rather than the likelioods $\\lambda_t(j) = p(\\obs_t|\\hidden_t=j)$.\n",
+    "In this case, we can perform the posterior updating in a numerically stable way as follows.\n",
+    "Define $L_t = \\max_j \\ell_t(j)$ and\n",
+    "\\begin{align}\n",
+    "\\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n",
+    "&\\defeq p(\\hidden_t=j|\\obs_{1:t-1}) p(\\obs_t|\\hidden_t=j) e^{-L_t} \\\\\n",
+    " &= p(\\hidden_t=j|\\obs_{1:t-1}) e^{\\ell_t(j) - L_t}\n",
+    "\\end{align}\n",
+    "Then we have\n",
+    "\\begin{align}\n",
+    "p(\\hidden_t=j|\\obs_t,\\obs_{1:t-1})\n",
+    "  &= \\frac{1}{\\tilde{Z}_t} \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1}) \\\\\n",
+    "\\tilde{Z}_t &= \\sum_j \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n",
+    "= p(\\obs_t|\\obs_{1:t-1}) e^{-L_t} \\\\\n",
+    "\\log Z_t &= \\log p(\\obs_t|\\obs_{1:t-1}) = \\log \\tilde{Z}_t + L_t\n",
+    "\\end{align}\n",
+    "\n",
+    "Below we show some JAX code that implements this core operation.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "def _condition_on(probs, ll):\n",
+    "    ll_max = ll.max()\n",
+    "    new_probs = probs * jnp.exp(ll - ll_max)\n",
+    "    norm = new_probs.sum()\n",
+    "    new_probs /= norm\n",
+    "    log_norm = jnp.log(norm) + ll_max\n",
+    "    return new_probs, log_norm"
+   ]
+  },
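+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick numerical illustration (the numbers below are made up), naively\n",
+    "exponentiating very negative log-likelihoods underflows to zero in float32,\n",
+    "but `_condition_on` recovers the posterior and log normalizer stably:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "probs = jnp.array([0.5, 0.5])\n",
+    "ll = jnp.array([-1000.0, -1001.0])   # jnp.exp(ll) underflows to [0, 0]\n",
+    "new_probs, log_norm = _condition_on(probs, ll)\n",
+    "print(new_probs)   # ~[0.73, 0.27], still a valid posterior\n",
+    "print(log_norm)    # ~-1000.38 = log(0.5 e^{-1000} + 0.5 e^{-1001})\n"
+   ]
+  },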
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "With the above function, we can implement a more numerically stable version of the forwards filter,\n",
+    "that works for any likelihood function, as shown below. It takes in the prior predictive distribution,\n",
+    "$\\alpha_{t|t-1}$,\n",
+    "stored in `predicted_probs`, and conditions them on the log-likelihood for each time step $\\ell_t$ to get the\n",
+    "posterior, $\\alpha_t$, stored in `filtered_probs`,\n",
+    "which is then converted to the prediction for the next state, $\\alpha_{t+1|t}$."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def _predict(probs, A):\n",
+    "    return A.T @ probs\n",
+    "\n",
+    "\n",
+    "def hmm_filter(initial_distribution,\n",
+    "               transition_matrix,\n",
+    "               log_likelihoods):\n",
+    "    def _step(carry, t):\n",
+    "        log_normalizer, predicted_probs = carry\n",
+    "\n",
+    "        # Get parameters for time t\n",
+    "        get = lambda x: x[t] if x.ndim == 3 else x\n",
+    "        A = get(transition_matrix)\n",
+    "        ll = log_likelihoods[t]\n",
+    "\n",
+    "        # Condition on emissions at time t, being careful not to overflow\n",
+    "        filtered_probs, log_norm = _condition_on(predicted_probs, ll)\n",
+    "        # Update the log normalizer\n",
+    "        log_normalizer += log_norm\n",
+    "        # Predict the next state\n",
+    "        predicted_probs = _predict(filtered_probs, A)\n",
+    "\n",
+    "        return (log_normalizer, predicted_probs), (filtered_probs, predicted_probs)\n",
+    "\n",
+    "    num_timesteps = len(log_likelihoods)\n",
+    "    carry = (0.0, initial_distribution)\n",
+    "    (log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(\n",
+    "        _step, carry, jnp.arange(num_timesteps))\n",
+    "    return log_normalizer, filtered_probs, predicted_probs"
+   ]
+  },
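+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a sanity check, the sketch below compares the NumPy and JAX implementations\n",
+    "on a randomly generated HMM (the sizes and seed are arbitrary test values):\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rng = np.random.default_rng(0)\n",
+    "n_states, n_obs, seq_len = 3, 4, 10\n",
+    "A = rng.dirichlet(np.ones(n_states), size=n_states)   # row-stochastic\n",
+    "B = rng.dirichlet(np.ones(n_obs), size=n_states)\n",
+    "pi = rng.dirichlet(np.ones(n_states))\n",
+    "obs = rng.integers(n_obs, size=seq_len)\n",
+    "\n",
+    "log_Z_np, alphas_np = hmm_forwards_np(A, B, pi, obs)\n",
+    "\n",
+    "# log p(obs_t | z_t = j), arranged as (seq_len, n_states)\n",
+    "log_liks = jnp.log(jnp.asarray(B[:, obs].T))\n",
+    "log_Z_jax, filtered, _ = hmm_filter(jnp.asarray(pi), jnp.asarray(A), log_liks)\n",
+    "\n",
+    "assert np.allclose(log_Z_np, log_Z_jax, atol=1e-4)\n",
+    "assert np.allclose(alphas_np, filtered, atol=1e-4)\n"
+   ]
+  },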
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "TODO: check equivalence of these two implementations!"
    ]
   }
  ],

+ 1 - 1
chapters/hmm/hmm_index.md

@@ -2,7 +2,7 @@
 # Hidden Markov Models 
 
 This chapter discusses Hidden Markov Models (HMMs), which are state space models
-in which the latent state $z_t \in \{1,\ldots,K\}$ is discrete.
+in which the latent state $\hidden_t \in \{1,\ldots,\nstates\}$ is discrete.
 
 
 ```{tableofcontents}

+ 1 - 1
chapters/scratch.md

@@ -102,7 +102,7 @@ I am a useful note!
 
 ## Math
 
-Here is $\N=10$ and blah. $\floor{42.3}= 42$.
+Here is $\N=10$ and blah. $\floor{42.3}= 42$. Let's try again.
 
 We have $E= mc^2$, and also
 

File diff suppressed because it is too large
+ 552 - 0
chapters/scratchpad.ipynb


+ 36 - 239
chapters/ssm/hmm.ipynb

@@ -2,48 +2,27 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
+    "(sec:hmm-intro)=\n",
+    "# Hidden Markov Models\n",
     "\n",
-    "\n"
+    "In this section, we discuss the\n",
+    "hidden Markov model or HMM,\n",
+    "which is a state space model in which the hidden states\n",
+    "are discrete, so $\\hidden_t \\in \\{1,\\ldots, \\nstates\\}$.\n",
+    "The observations may be discrete,\n",
+    "$\\obs_t \\in \\{1,\\ldots, \\nsymbols\\}$,\n",
+    "or continuous,\n",
+    "$\\obs_t \\in \\real^\\nstates$,\n",
+    "or some combination,\n",
+    "as we illustrate below.\n",
+    "More details can be found in e.g., \n",
+    "{cite}`Rabiner89,Fraser08,Cappe05`.\n",
+    "For an interactive introduction,\n",
+    "see https://nipunbatra.github.io/hmm/."
    ]
   },
   {
@@ -64,21 +43,26 @@
     "import abc\n",
     "from dataclasses import dataclass\n",
     "import functools\n",
+    "from functools import partial\n",
     "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
     "\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
+    "\n",
+    "\n",
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax\n",
     "\n",
     "import inspect\n",
     "import inspect as py_inspect\n",
@@ -94,193 +78,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "(sec:hmm-intro)=\n",
-    "# Hidden Markov Models\n",
-    "\n",
-    "In this section, we discuss the\n",
-    "hidden Markov model or HMM,\n",
-    "which is a state space model in which the hidden states\n",
-    "are discrete, so $\\hmmhid_t \\in \\{1,\\ldots, K\\}$.\n",
-    "The observations may be discrete,\n",
-    "$\\hmmobs_t \\in \\{1,\\ldots, C\\}$,\n",
-    "or continuous,\n",
-    "$\\hmmobs_t \\in \\real^D$,\n",
-    "or some combination,\n",
-    "as we illustrate below.\n",
-    "More details can be found in e.g., \n",
-    "{cite}`Rabiner89,Fraser08,Cappe05`.\n",
-    "For an interactive introduction,\n",
-    "see https://nipunbatra.github.io/hmm/."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "(sec:casino)=\n",
     "### Example: Casino HMM\n",
     "\n",
@@ -291,9 +88,9 @@
     "\n",
     "The transition model is denoted by\n",
     "```{math}\n",
-    "p(z_t=j|z_{t-1}=i) = \\hmmTrans_{ij}\n",
+    "p(\\hidden_t=j|\\hidden_{t-1}=i) = \\hmmTrans_{ij}\n",
     "```\n",
-    "Here the $i$'th row of $\\vA$ corresponds to the outgoing distribution from state $i$.\n",
+    "Here the $i$'th row of $\\hmmTrans$ corresponds to the outgoing distribution from state $i$.\n",
     "This is  a row stochastic matrix,\n",
     "meaning each row sums to one.\n",
     "We can visualize\n",
@@ -314,15 +111,15 @@
     "p(\\obs_t=k|\\hidden_t=j) = \\hmmObs_{jk} \n",
     "```\n",
     "This is represented by the histograms associated with each\n",
-    "state in  {numref}`casino-fig`.\n",
+    "state in  {numref}`fig:casino`.\n",
     "\n",
     "Finally,\n",
     "the initial state distribution is denoted by\n",
     "```{math}\n",
-    "p(z_1=j) = \\hmmInit_j\n",
+    "p(\\hidden_1=j) = \\hmmInit_j\n",
     "```\n",
     "\n",
-    "Collectively we denote all the parameters by $\\vtheta=(\\hmmTrans, \\hmmObs, \\hmmInit)$.\n",
+    "Collectively we denote all the parameters by $\\params=(\\hmmTrans, \\hmmObs, \\hmmInit)$.\n",
     "\n",
     "Now let us implement this model in code."
    ]
@@ -412,7 +209,7 @@
     "\n",
     "seed = 314\n",
     "n_samples = 300\n",
-    "z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)\n",
+    "z_hist, x_hist = hmm.sample(seed=jr.PRNGKey(seed), seq_len=n_samples)\n",
     "\n",
     "z_hist_str = \"\".join((np.array(z_hist) + 1).astype(str))[:60]\n",
     "x_hist_str = \"\".join((np.array(x_hist) + 1).astype(str))[:60]\n",

+ 30 - 244
chapters/ssm/inference.ipynb

@@ -2,52 +2,6 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
    "outputs": [],
@@ -80,178 +34,8 @@
     "from functools import partial\n",
     "from jax.random import PRNGKey, split\n",
     "\n",
-    "import inspect\n",
-    "import inspect as py_inspect\n",
-    "import rich\n",
-    "from rich import inspect as r_inspect\n",
-    "from rich import print as r_print\n",
-    "\n",
-    "def print_source(fname):\n",
-    "    r_print(py_inspect.getsource(fname))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
+    "import jsl\n",
+    "import ssm_jax\n"
    ]
   },
   {
@@ -259,19 +43,9 @@
    "metadata": {},
    "source": [
     "(sec:inference)=\n",
-    "# Inferential goals\n",
+    "# States estimation (inference)\n",
     "\n",
-    "```{figure} /figures/inference-problems-tikz.png\n",
-    ":scale: 30%\n",
-    ":name: fig:dbn-inference\n",
     "\n",
-    "Illustration of the different kinds of inference in an SSM.\n",
-    " The main kinds of inference for state-space models.\n",
-    "    The shaded region is the interval for which we have data.\n",
-    "    The arrow represents the time step at which we want to perform inference.\n",
-    "    $t$ is the current time,  $T$ is the sequence length,\n",
-    "$\\ell$ is the lag and $h$ is the prediction horizon.\n",
-    "```\n",
     "\n",
     "\n",
     "\n",
@@ -284,41 +58,53 @@
     "there are multiple forms of posterior we may be interested in computing,\n",
     "including the following:\n",
     "- the filtering distribution\n",
-    "$p(\\hmmhid_t|\\hmmobs_{1:t})$\n",
+    "$p(\\hidden_t|\\obs_{1:t})$\n",
     "- the smoothing distribution\n",
-    "$p(\\hmmhid_t|\\hmmobs_{1:T})$ (note that this conditions on future data $T>t$)\n",
+    "$p(\\hidden_t|\\obs_{1:T})$ (note that this conditions on future data $T>t$)\n",
     "- the fixed-lag smoothing distribution\n",
-    "$p(\\hmmhid_{t-\\ell}|\\hmmobs_{1:t})$ (note that this\n",
+    "$p(\\hidden_{t-\\ell}|\\obs_{1:t})$ (note that this\n",
     "infers $\\ell$ steps in the past given data up to the present).\n",
     "\n",
     "We may also want to compute the\n",
     "predictive distribution $h$ steps into the future:\n",
-    "```{math}\n",
-    "p(\\hmmobs_{t+h}|\\hmmobs_{1:t})\n",
-    "= \\sum_{\\hmmhid_{t+h}} p(\\hmmobs_{t+h}|\\hmmhid_{t+h}) p(\\hmmhid_{t+h}|\\hmmobs_{1:t})\n",
-    "```\n",
+    "\\begin{align}\n",
+    "p(\\obs_{t+h}|\\obs_{1:t})\n",
+    "= \\sum_{\\hidden_{t+h}} p(\\obs_{t+h}|\\hidden_{t+h}) p(\\hidden_{t+h}|\\obs_{1:t})\n",
+    "\\end{align}\n",
     "where the hidden state predictive distribution is\n",
     "\\begin{align}\n",
-    "p(\\hmmhid_{t+h}|\\hmmobs_{1:t})\n",
-    "&= \\sum_{\\hmmhid_{t:t+h-1}}\n",
-    " p(\\hmmhid_t|\\hmmobs_{1:t}) \n",
-    " p(\\hmmhid_{t+1}|\\hmmhid_{t})\n",
-    " p(\\hmmhid_{t+2}|\\hmmhid_{t+1})\n",
+    "p(\\hidden_{t+h}|\\obs_{1:t})\n",
+    "&= \\sum_{\\hidden_{t:t+h-1}}\n",
+    " p(\\hidden_t|\\obs_{1:t}) \n",
+    " p(\\hidden_{t+1}|\\hidden_{t})\n",
+    " p(\\hidden_{t+2}|\\hidden_{t+1})\n",
     "\\cdots\n",
-    " p(\\hmmhid_{t+h}|\\hmmhid_{t+h-1})\n",
+    " p(\\hidden_{t+h}|\\hidden_{t+h-1})\n",
     "\\end{align}\n",
     "See \n",
     "{numref}`fig:dbn-inference` for a summary of these distributions.\n",
     "\n",
+    "```{figure} /figures/inference-problems-tikz.png\n",
+    ":scale: 30%\n",
+    ":name: fig:dbn-inference\n",
+    "\n",
+    "Illustration of the different kinds of inference in an SSM.\n",
+    " The main kinds of inference for state-space models.\n",
+    "    The shaded region is the interval for which we have data.\n",
+    "    The arrow represents the time step at which we want to perform inference.\n",
+    "    $t$ is the current time,  $T$ is the sequence length,\n",
+    "$\\ell$ is the lag and $h$ is the prediction horizon.\n",
+    "```\n",
+    "\n",
     "In addition  to comuting posterior marginals,\n",
     "we may want to compute the most probable hidden sequence,\n",
     "i.e., the joint MAP estimate\n",
     "```{math}\n",
-    "\\arg \\max_{\\hmmhid_{1:T}} p(\\hmmhid_{1:T}|\\hmmobs_{1:T})\n",
+    "\\arg \\max_{\\hidden_{1:T}} p(\\hidden_{1:T}|\\obs_{1:T})\n",
     "```\n",
     "or sample sequences from the posterior\n",
     "```{math}\n",
-    "\\hmmhid_{1:T} \\sim p(\\hmmhid_{1:T}|\\hmmobs_{1:T})\n",
+    "\\hidden_{1:T} \\sim p(\\hidden_{1:T}|\\obs_{1:T})\n",
     "```\n",
     "\n",
     "Algorithms for all these task are discussed in the following chapters,\n",

+ 55 - 301
chapters/ssm/lds.ipynb

@@ -2,125 +2,35 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Requirement already satisfied: distrax in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (0.0.1)\n",
-      "Collecting distrax\n",
-      "  Downloading distrax-0.1.2-py3-none-any.whl (272 kB)\n",
-      "\u001b[K     |████████████████████████████████| 272 kB 6.9 MB/s eta 0:00:01\n",
-      "\u001b[?25hRequirement already satisfied: jax>=0.1.55 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.2.11)\n",
-      "Requirement already satisfied: absl-py>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.12.0)\n",
-      "Requirement already satisfied: chex>=0.0.7 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.0.8)\n",
-      "Requirement already satisfied: jaxlib>=0.1.67 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.1.70)\n",
-      "Requirement already satisfied: numpy>=1.18.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (1.19.5)\n",
-      "Collecting tensorflow-probability>=0.15.0\n",
-      "  Using cached tensorflow_probability-0.16.0-py2.py3-none-any.whl (6.3 MB)\n",
-      "Requirement already satisfied: six in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from absl-py>=0.9.0->distrax) (1.15.0)\n",
-      "Requirement already satisfied: dm-tree>=0.1.5 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.1.6)\n",
-      "Requirement already satisfied: toolz>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.11.1)\n",
-      "Requirement already satisfied: opt-einsum in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jax>=0.1.55->distrax) (3.3.0)\n",
-      "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.12)\n",
-      "Requirement already satisfied: scipy in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.6.3)\n",
-      "Requirement already satisfied: cloudpickle>=1.3 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (1.6.0)\n",
-      "Requirement already satisfied: decorator in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (4.4.2)\n",
-      "Requirement already satisfied: gast>=0.3.2 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (0.4.0)\n",
-      "Installing collected packages: tensorflow-probability, distrax\n",
-      "  Attempting uninstall: tensorflow-probability\n",
-      "    Found existing installation: tensorflow-probability 0.13.0\n",
-      "    Uninstalling tensorflow-probability-0.13.0:\n",
-      "      Successfully uninstalled tensorflow-probability-0.13.0\n",
-      "  Attempting uninstall: distrax\n",
-      "    Found existing installation: distrax 0.0.1\n",
-      "    Uninstalling distrax-0.0.1:\n",
-      "      Successfully uninstalled distrax-0.0.1\n",
-      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
-      "jsl 0.0.0 requires dataclasses, which is not installed.\u001b[0m\n",
-      "Successfully installed distrax-0.1.2 tensorflow-probability-0.16.0\n",
-      "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 22.0.4 is available.\n",
-      "You should consider upgrading via the '/opt/anaconda3/envs/spyder-dev/bin/python -m pip install --upgrade pip' command.\u001b[0m\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
     "### Import standard libraries\n",
     "\n",
     "import abc\n",
     "from dataclasses import dataclass\n",
     "import functools\n",
+    "from functools import partial\n",
     "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
     "\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
+    "\n",
+    "\n",
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax\n",
     "\n",
     "import inspect\n",
     "import inspect as py_inspect\n",
@@ -136,170 +46,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "(sec:lds-intro)=\n",
     "# Linear Gaussian SSMs\n",
     "\n",
@@ -310,50 +56,50 @@
     "hidden states and inputs (i.e. there are no auto-regressive dependencies\n",
     "between the observables).\n",
     "We can rewrite this model as \n",
-    "a stochastic nonlinear dynamical system (NLDS)\n",
+    "a stochastic $\\keyword{nonlinear dynamical system}$ or $\\keyword{NLDS}$\n",
     "by defining the distribution of the next hidden state \n",
     "as a deterministic function of the past state\n",
-    "plus random process noise $\\vepsilon_t$ \n",
+    "plus random $\\keyword{process noise}$ $\\transNoise_t$ \n",
     "\\begin{align}\n",
-    "\\hmmhid_t &= \\ssmDynFn(\\hmmhid_{t-1}, \\inputs_t, \\vepsilon_t)  \n",
+    "\\hidden_t &= \\dynamicsFn(\\hidden_{t-1}, \\inputs_t, \\transNoise_t)  \n",
     "\\end{align}\n",
-    "where $\\vepsilon_t$ is drawn from the distribution such\n",
+    "where $\\transNoise_t$ is drawn from the distribution such\n",
     "that the induced distribution\n",
-    "on $\\hmmhid_t$ matches $p(\\hmmhid_t|\\hmmhid_{t-1}, \\inputs_t)$.\n",
+    "on $\\hidden_t$ matches $p(\\hidden_t|\\hidden_{t-1}, \\inputs_t)$.\n",
     "Similarly we can rewrite the observation distributions\n",
     "as a deterministic function of the hidden state\n",
-    "plus observation noise $\\veta_t$:\n",
+    "plus $\\keyword{observation noise}$ $\\obsNoise_t$:\n",
     "\\begin{align}\n",
-    "\\hmmobs_t &= \\ssmObsFn(\\hmmhid_{t}, \\inputs_t, \\veta_t)\n",
+    "\\obs_t &= \\measurementFn(\\hidden_{t}, \\inputs_t, \\obsNoise_t)\n",
     "\\end{align}\n",
     "\n",
     "\n",
     "If we assume additive Gaussian noise,\n",
     "the model becomes\n",
     "\\begin{align}\n",
-    "\\hmmhid_t &= \\ssmDynFn(\\hmmhid_{t-1}, \\inputs_t) +  \\vepsilon_t  \\\\\n",
-    "\\hmmobs_t &= \\ssmObsFn(\\hmmhid_{t}, \\inputs_t) + \\veta_t\n",
+    "\\hidden_t &= \\dynamicsFn(\\hidden_{t-1}, \\inputs_t) +  \\transNoise_t  \\\\\n",
+    "\\obs_t &= \\measurementFn(\\hidden_{t}, \\inputs_t) + \\obsNoise_t\n",
     "\\end{align}\n",
-    "where $\\vepsilon_t \\sim \\gauss(\\vzero,\\vQ_t)$\n",
-    "and $\\veta_t \\sim \\gauss(\\vzero,\\vR_t)$.\n",
-    "We will call these Gaussian SSMs.\n",
+    "where $\\transNoise_t \\sim \\gauss(\\vzero,\\transCov_t)$\n",
+    "and $\\obsNoise_t \\sim \\gauss(\\vzero,\\obsCov_t)$.\n",
+    "We will call these $\\keyword{Gaussian SSMs}$.\n",
     "\n",
     "If we additionally assume\n",
-    "the transition function $\\ssmDynFn$\n",
-    "and the observation function $\\ssmObsFn$ are both linear,\n",
+    "the transition function $\\dynamicsFn$\n",
+    "and the observation function $\\measurementFn$ are both linear,\n",
     "then we can rewrite the model as follows:\n",
     "\\begin{align}\n",
-    "p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) &= \\gauss(\\hmmhid_t|\\ldsDyn_t \\hmmhid_{t-1}\n",
-    "+ \\ldsDynIn_t \\inputs_t, \\vQ_t)\n",
+    "p(\\hidden_t|\\hidden_{t-1},\\inputs_t) &= \\gauss(\\hidden_t|\\ldsDyn_t \\hidden_{t-1}\n",
+    "+ \\ldsDynIn_t \\inputs_t, \\transCov_t)\n",
     "\\\\\n",
-    "p(\\hmmobs_t|\\hmmhid_t,\\inputs_t) &= \\gauss(\\hmmobs_t|\\ldsObs_t \\hmmhid_{t}\n",
-    "+ \\ldsObsIn_t \\inputs_t, \\vR_t)\n",
+    "p(\\obs_t|\\hidden_t,\\inputs_t) &= \\gauss(\\obs_t|\\ldsObs_t \\hidden_{t}\n",
+    "+ \\ldsObsIn_t \\inputs_t, \\obsCov_t)\n",
     "\\end{align}\n",
     "This is called a \n",
-    "linear-Gaussian state space model\n",
-    "(LG-SSM),\n",
-    "or a\n",
-    "linear dynamical system (LDS).\n",
+    "$\\keyword{linear-Gaussian state space model}\n",
+    "or $\\keyword{LG-SSM}$;\n",
+    "it is also called \n",
+    "a $\\keyword{linear dynamical system}$ or $\\keyword{LDS}$.\n",
     "We usually assume the parameters are independent of time, in which case\n",
     "the model is said to be time-invariant or homogeneous.\n"
    ]
@@ -372,7 +118,7 @@
     "Consider an object moving in $\\real^2$.\n",
     "Let the state be\n",
     "the position and velocity of the object,\n",
-    "$\\vz_t =\\begin{pmatrix} u_t & \\dot{u}_t & v_t & \\dot{v}_t \\end{pmatrix}$.\n",
+    "$\\hidden_t =\\begin{pmatrix} u_t & \\dot{u}_t & v_t & \\dot{v}_t \\end{pmatrix}$.\n",
     "(We use $u$ and $v$ for the two coordinates,\n",
     "to avoid confusion with the state and observation variables.)\n",
     "If we use Euler discretization,\n",
@@ -390,9 +136,9 @@
     "}_{\\ldsDyn}\n",
     "\\\n",
     "\\underbrace{\\begin{pmatrix} u_{t-1} \\\\ \\dot{u}_{t-1} \\\\ v_{t-1} \\\\ \\dot{v}_{t-1} \\end{pmatrix}}_{\\vz_{t-1}}\n",
-    "+ \\vepsilon_t\n",
+    "+ \\transNoise_t\n",
     "\\end{align}\n",
-    "where $\\vepsilon_t \\sim \\gauss(\\vzero,\\vQ)$ is\n",
+    "where $\\transNoise_t \\sim \\gauss(\\vzero,\\transCov)$ is\n",
     "the process noise.\n",
     "\n",
     "Let us assume\n",
@@ -401,7 +147,7 @@
     "of the state, but not to the location.\n",
     "(This is known as a random accelerations model.)\n",
     "We can approximate the resulting process in discrete time by assuming\n",
-    "$\\vQ = \\diag(0, q, 0, q)$.\n",
+    "$\\transCov = \\diag(0, q, 0, q)$.\n",
     "(See  {cite}`Sarkka13` p60 for a more accurate way\n",
     "to convert the continuous time process to discrete time.)\n",
     "\n",
@@ -411,7 +157,7 @@
     "corrupted by  Gaussian noise.\n",
     "Thus the observation model becomes\n",
     "\\begin{align}\n",
-    "\\underbrace{\\begin{pmatrix}  y_{1,t} \\\\  y_{2,t} \\end{pmatrix}}_{\\vy_t}\n",
+    "\\underbrace{\\begin{pmatrix}  \\obs_{1,t} \\\\  \\obs_{2,t} \\end{pmatrix}}_{\\obs_t}\n",
     "  &=\n",
     "    \\underbrace{\n",
     "    \\begin{pmatrix}\n",
@@ -420,18 +166,18 @@
     "    \\end{pmatrix}\n",
     "    }_{\\ldsObs}\n",
     "    \\\n",
-    "\\underbrace{\\begin{pmatrix} u_t\\\\ \\dot{u}_t \\\\ v_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\vz_t}    \n",
-    " + \\veta_t\n",
+    "\\underbrace{\\begin{pmatrix} u_t\\\\ \\dot{u}_t \\\\ v_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\hidden_t}    \n",
+    " + \\obsNoise_t\n",
     "\\end{align}\n",
-    "where $\\veta_t \\sim \\gauss(\\vzero,\\vR)$ is the \\keywordDef{observation noise}.\n",
+    "where $\\obsNoise_t \\sim \\gauss(\\vzero,\\obsCov)$ is the \\keywordDef{observation noise}.\n",
     "We see that the observation matrix $\\ldsObs$ simply ``extracts'' the\n",
     "relevant parts  of the state vector.\n",
     "\n",
     "Suppose we sample a trajectory and corresponding set\n",
     "of noisy observations from this model,\n",
-    "$(\\vz_{1:T}, \\vy_{1:T}) \\sim p(\\vz,\\vy|\\vtheta)$.\n",
+    "$(\\hidden_{1:T}, \\obs_{1:T}) \\sim p(\\hidden,\\obs|\\params)$.\n",
     "(We use diagonal observation noise,\n",
-    "$\\vR = \\diag(\\sigma_1^2, \\sigma_2^2)$.)\n",
+    "$\\obsCov = \\diag(\\sigma_1^2, \\sigma_2^2)$.)\n",
     "The results are shown below. \n"
    ]
   },
@@ -561,7 +307,15 @@
    "metadata": {},
    "source": [
     "The main task is to infer the hidden states given the noisy\n",
-    "observations, i.e., $p(\\vz|\\vy,\\vtheta)$. We discuss the topic of inference in {ref}`sec:inference`."
+    "observations, i.e., $p(\\hidden_t|\\obs_{1:t},\\params)$\n",
+    "or $p(\\hidden_t|\\obs_{1:T}, \\params)$ in the offline case.\n",
+    "We discuss the topic of inference in {ref}`sec:inference`.\n",
+    "We will usually represent this belief state by a Gaussian distribution,\n",
+    "$p(\\hidden_t|\\obs_{1:t},\\params) = \\gauss(\\hidden_t|\\mean_{t|t}, \\covMat_{t|t})$\n",
+    "and\n",
+    "$p(\\hidden_t|\\obs_{1:T},\\params) = \\gauss(\\hidden_t|\\mean_{t|T}, \\covMat_{t|T})$.\n",
+    "Sometimes we use information form, \n",
+    "$p(\\hidden_t|\\obs_{1:t},\\params) = \\gaussInfo(\\hidden_t|\\mean_{t|t}, \\covMat_{t|t})$"
    ]
   }
  ],

+ 116 - 0
chapters/ssm/learning.ipynb

@@ -0,0 +1,116 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Import standard libraries\n",
+    "\n",
+    "import abc\n",
+    "from dataclasses import dataclass\n",
+    "import functools\n",
+    "from functools import partial\n",
+    "import itertools\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
+    "\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "from jax import lax, vmap, jit, grad\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
+    "\n",
+    "\n",
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "(sec:learning)=\n",
+    "# Parameter estimation (learning)\n",
+    "\n",
+    "\n",
+    "So far, we have assumed that the parameters $\\params$ of the SSM are known.\n",
+    "For example, in the case of an HMM with categorical observations\n",
+    "we have $\\params = (\\hmmInit, \\hmmTrans, \\hmmObs)$,\n",
+    "and in the case of an LDS, we have $\\params = (\\ldsTrans, \\ldsObs, \\ldsTransNoise, \\ldsObsNoise)$.\n",
+    "(We also need to specify the initial state distribution, $(\\ldsInitMean, \\ldsInitCov)$.)\n",
+    "If we adopt a Bayesian perspective, we can view these parameters as random variables that are\n",
+    "shared across all time steps, and across all sequences.\n",
+    "This is shown in {numref}`fig:hmm-plates`, where we adopt $\\keyword{plate notation}$\n",
+    "to represent repetitive structure.\n",
+    "\n",
+    "```{figure} /figures/hmmDgmPlatesY.png\n",
+    ":scale: 100%\n",
+    ":name: fig:hmm-plates\n",
+    "\n",
+    "Illustration of an HMM using plate notation, where we show the parameter\n",
+    "nodes which are shared across all the sequences.\n",
+    "```\n",
+    "\n",
+    "!!!\n",
+    "Suppose we observe $N$ sequences $\\data = \\{\\obs_{n,1:T_n}: n=1:N\\}$.\n",
+    "Then the goal of $\\keyword{parameter estimation}$, also called $\\keyword{model learning}$\n",
+    "or $\\keyword{model fitting}$, is to approximate the posterior\n",
+    "\\begin{align}\n",
+    "p(\\params|\\data) \\propto p(\\params) \\prod_{n=1}^N p(\\obs_{n,1:T_n} | \\params)\n",
+    "\\end{align}\n",
+    "where $p(\\obs_{n,1:T_n} | \\params)$ is the marginal likelihood of sequence $n$:\n",
+    "\\begin{align}\n",
+    "p(\\obs_{1:T} | \\params) = \\int  p(\\hidden_{1:T}, \\obs_{1:T} | \\params) d\\hidden_{1:T}\n",
+    "\\end{align}\n",
+    "\n",
+    "Since computing the full posterior is computationally difficult, we often settle for computing\n",
+    "a point estimate such as the MAP (maximum a posterior) estimate\n",
+    "\\begin{align}\n",
+    "\\params_{\\map} = \\arg \\max_{\\params} \\log p(\\params) + \\sum_{n=1}^N \\log p(\\obs_{n,1:T_n} | \\params)\n",
+    "\\end{align}\n",
+    "If we ignore the prior term, we get the maximum likelihood estimate or MLE:\n",
+    "\\begin{align}\n",
+    "\\params_{\\mle} = \\arg \\max_{\\params}  \\sum_{n=1}^N \\log p(\\obs_{n,1:T_n} | \\params)\n",
+    "\\end{align}\n",
+    "In practice, the MAP estimate often works better than the MLE, since the prior can regularize\n",
+    "the estimate to ensure the model is numerically stable and does not overfit the training set.\n",
+    "\n",
+    "We will discuss a variety of algorithms for parameter estimation in later chapters.\n",
+    "\n"
+   ]
+  },
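+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a preview, one generic approach is to compute the MLE (or MAP estimate) by\n",
+    "gradient ascent on the log marginal likelihood. The sketch below assumes a\n",
+    "user-supplied function `log_marginal_likelihood(params, obs_seq)` (a\n",
+    "placeholder, not a library API) that is differentiable with respect to\n",
+    "`params`, e.g. one built from the filtering algorithms discussed earlier:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def fit_mle(params, sequences, log_marginal_likelihood,\n",
+    "            num_steps=100, learning_rate=1e-2):\n",
+    "    \"\"\"Maximize sum_n log p(obs_n | params) with optax (a minimal sketch).\"\"\"\n",
+    "    def loss_fn(params):\n",
+    "        # negative log likelihood, summed over sequences\n",
+    "        return -sum(log_marginal_likelihood(params, obs) for obs in sequences)\n",
+    "\n",
+    "    optimizer = optax.adam(learning_rate)\n",
+    "    opt_state = optimizer.init(params)\n",
+    "\n",
+    "    @jit\n",
+    "    def step(params, opt_state):\n",
+    "        loss, grads = jax.value_and_grad(loss_fn)(params)\n",
+    "        updates, opt_state = optimizer.update(grads, opt_state, params)\n",
+    "        params = optax.apply_updates(params, updates)\n",
+    "        return params, opt_state, loss\n",
+    "\n",
+    "    for _ in range(num_steps):\n",
+    "        params, opt_state, loss = step(params, opt_state)\n",
+    "    return params\n"
+   ]
+  }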
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "6407c60499271029b671b4ff687c4ed4626355c45fd34c44476827f4be42c4d7"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.9.2 ('spyder-dev')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

+ 0 - 164
chapters/ssm/nlds.ipynb

@@ -136,170 +136,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "(sec:nlds-intro)=\n",
     "# Nonlinear Gaussian SSMs\n",
     "\n",

+ 21 - 186
chapters/ssm/ssm_intro.ipynb

@@ -4,175 +4,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": []
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "(sec:ssm-intro)=\n",
     "# What are State Space Models?\n",
     "\n",
@@ -197,7 +28,7 @@
     "unlike a standard Markov model.\n",
     "\n",
     "```{figure} /figures/SSM-AR-inputs.png\n",
-    ":height: 300px\n",
+    ":height: 150px\n",
     ":name: fig:ssm-ar\n",
     "\n",
     "Illustration of an SSM as a graphical model.\n",
@@ -208,14 +39,14 @@
     "as the following joint distribution:\n",
     "```{math}\n",
     ":label: eq:SSM-ar\n",
-    "p(\\hmmobs_{1:T},\\hmmhid_{1:T}|\\inputs_{1:T})\n",
-    " = \\left[ p(\\hmmhid_1|\\inputs_1) \\prod_{t=2}^{T}\n",
-    " p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) \\right]\n",
-    " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t, \\inputs_t, \\hmmobs_{t-1}) \\right]\n",
+    "p(\\obs_{1:T},\\hidden_{1:T}|\\inputs_{1:T})\n",
+    " = \\left[ p(\\hidden_1|\\inputs_1) \\prod_{t=2}^{T}\n",
+    " p(\\hidden_t|\\hidden_{t-1},\\inputs_t) \\right]\n",
+    " \\left[ \\prod_{t=1}^T p(\\obs_t|\\hidden_t, \\inputs_t, \\obs_{t-1}) \\right]\n",
     "```\n",
-    "where $p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t)$ is the\n",
+    "where $p(\\hidden_t|\\hidden_{t-1},\\inputs_t)$ is the\n",
     "transition model,\n",
-    "$p(\\hmmobs_t|\\hmmhid_t, \\inputs_t, \\hmmobs_{t-1})$ is the\n",
+    "$p(\\obs_t|\\hidden_t, \\inputs_t, \\obs_{t-1})$ is the\n",
     "observation model,\n",
     "and $\\inputs_{t}$ is an optional input or action.\n",
     "See {numref}`fig:ssm-ar` \n",
@@ -228,31 +59,35 @@
     "In this case the joint simplifies to \n",
     "```{math}\n",
     ":label: eq:SSM-input\n",
-    "p(\\hmmobs_{1:T},\\hmmhid_{1:T}|\\inputs_{1:T})\n",
-    " = \\left[ p(\\hmmhid_1|\\inputs_1) \\prod_{t=2}^{T}\n",
-    " p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) \\right]\n",
-    " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t, \\inputs_t) \\right]\n",
+    "p(\\obs_{1:T},\\hidden_{1:T}|\\inputs_{1:T})\n",
+    " = \\left[ p(\\hidden_1|\\inputs_1) \\prod_{t=2}^{T}\n",
+    " p(\\hidden_t|\\hidden_{t-1},\\inputs_t) \\right]\n",
+    " \\left[ \\prod_{t=1}^T p(\\obs_t|\\hidden_t, \\inputs_t) \\right]\n",
     "```\n",
     "Sometimes there are no external inputs, so the model further\n",
     "simplifies to the following unconditional generative model: \n",
     "```{math}\n",
     ":label: eq:SSM-no-input\n",
-    "p(\\hmmobs_{1:T},\\hmmhid_{1:T})\n",
-    " = \\left[ p(\\hmmhid_1) \\prod_{t=2}^{T}\n",
-    " p(\\hmmhid_t|\\hmmhid_{t-1}) \\right]\n",
-    " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t) \\right]\n",
+    "p(\\obs_{1:T},\\hidden_{1:T})\n",
+    " = \\left[ p(\\hidden_1) \\prod_{t=2}^{T}\n",
+    " p(\\hidden_t|\\hidden_{t-1}) \\right]\n",
+    " \\left[ \\prod_{t=1}^T p(\\obs_t|\\hidden_t) \\right]\n",
     "```\n",
     "See {numref}`ssm-simplified` \n",
     "for an illustration of the corresponding graphical model.\n",
     "\n",
     "\n",
     "```{figure} /figures/SSM-simplified.png\n",
-    ":scale: 100%\n",
+    ":height: 150px\n",
     ":name: ssm-simplified\n",
     "\n",
     "Illustration of a simplified SSM.\n",
     "```\n",
-    "\n"
+    "\n",
+    "SSMs are widely used in many areas of science, engineering, finance, economics, etc.\n",
+    "The main applications are state estimation (i.e., inferring the underlying hidden state of the system given the observation),\n",
+    "forecasting (i.e., predicting future states and observations), and control (i.e., inferring the sequence of inputs that will\n",
+    "give rise to a desired target state). We will discuss these applications in alter chapters."
    ]
   },
   {

+ 12 - 2
references.bib

@@ -1,6 +1,4 @@
 
----
----
 
 %@string{aij = "Artificial Intelligence J."}
 @string{aij = "AIJ"}
@@ -554,3 +552,15 @@ publisher = "Cambridge University Press",
  year = 2005,
  publisher = "Springer"
 }
+
+
+@article{Devijver85,
+  title={Baum's forward-backward algorithm revisited},
+  author={Devijver, Pierre A},
+  journal={Pattern Recognition Letters},
+  volume={3},
+  number={6},
+  pages={369--373},
+  year={1985},
+  publisher={Elsevier}
+}