
add macros and more content

Kevin P Murphy 3 years ago
parent
commit
d323ba98a2

+ 1 - 1
_toc.yml

@@ -4,7 +4,6 @@
 format: jb-book
 root: root
 chapters:
-- file: chapters/scratch
 
 - file: chapters/ssm/ssm_index
   sections:
@@ -13,6 +12,7 @@ chapters:
   - file: chapters/ssm/lds
   - file: chapters/ssm/nlds
   - file: chapters/ssm/inference
+  - file: chapters/ssm/learning
 
 - file: chapters/hmm/hmm_index
   sections:

+ 0 - 59
chapters/blank.ipynb

@@ -1,59 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "(chap:my-chap)=\n",
-    "# Chapter title\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "42\n"
-     ]
-    }
-   ],
-   "source": [
-    "print(42)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": []
-  }
- ],
- "metadata": {
-  "interpreter": {
-   "hash": "6407c60499271029b671b4ff687c4ed4626355c45fd34c44476827f4be42c4d7"
-  },
-  "kernelspec": {
-   "display_name": "Python 3.9.2 ('spyder-dev')",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.9.2"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}

+ 0 - 612
chapters/hmm/hmm.ipynb

@@ -1,612 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "(sec:hmm-ex)=\n",
-    "# Hidden Markov Models\n",
-    "\n",
-    "In this section, we introduce Hidden Markov Models (HMMs)."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Boilerplate"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Collecting jax[cpu]\n",
-      "  Downloading jax-0.3.5.tar.gz (946 kB)\n",
-      "\u001b[K     |████████████████████████████████| 946 kB 2.7 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting absl-py\n",
-      "  Downloading absl_py-1.0.0-py3-none-any.whl (126 kB)\n",
-      "\u001b[K     |████████████████████████████████| 126 kB 47.7 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting numpy>=1.19\n",
-      "  Downloading numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl (17.6 MB)\n",
-      "\u001b[K     |████████████████████████████████| 17.6 MB 47.5 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting opt_einsum\n",
-      "  Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)\n",
-      "Collecting scipy>=1.2.1\n",
-      "  Downloading scipy-1.8.0-cp38-cp38-macosx_12_0_universal2.macosx_10_9_x86_64.whl (55.3 MB)\n",
-      "\u001b[K     |████████████████████████████████| 55.3 MB 73.1 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting typing_extensions\n",
-      "  Using cached typing_extensions-4.1.1-py3-none-any.whl (26 kB)\n",
-      "Collecting jaxlib==0.3.5\n",
-      "  Downloading jaxlib-0.3.5-cp38-none-macosx_10_9_x86_64.whl (70.5 MB)\n",
-      "\u001b[K     |████████████████████████████████| 70.5 MB 723 kB/s  eta 0:00:01\n",
-      "\u001b[?25hCollecting flatbuffers<3.0,>=1.12\n",
-      "  Using cached flatbuffers-2.0-py2.py3-none-any.whl (26 kB)\n",
-      "Requirement already satisfied: six in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from absl-py->jax[cpu]) (1.16.0)\n",
-      "Building wheels for collected packages: jax\n",
-      "  Building wheel for jax (setup.py) ... \u001b[?25ldone\n",
-      "\u001b[?25h  Created wheel for jax: filename=jax-0.3.5-py3-none-any.whl size=1095861 sha256=6886baa70817bbac3b5797b3720dcb81e49097f61cee7d2e1255823ea32ccad8\n",
-      "  Stored in directory: /Users/kpmurphy/Library/Caches/pip/wheels/05/30/aa/908988293721511b4b29e0aadf9b5d133d0f14f6c0a188e764\n",
-      "Successfully built jax\n",
-      "Installing collected packages: numpy, typing-extensions, scipy, opt-einsum, flatbuffers, absl-py, jaxlib, jax\n",
-      "Successfully installed absl-py-1.0.0 flatbuffers-2.0 jax-0.3.5 jaxlib-0.3.5 numpy-1.22.3 opt-einsum-3.3.0 scipy-1.8.0 typing-extensions-4.1.1\n",
-      "Note: you may need to restart the kernel to use updated packages.\n",
-      "Collecting git+https://github.com/probml/jsl\n",
-      "  Cloning https://github.com/probml/jsl to /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-req-build-i8seqdiw\n",
-      "  Running command git clone -q https://github.com/probml/jsl /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-req-build-i8seqdiw\n",
-      "Collecting chex\n",
-      "  Downloading chex-0.1.2-py3-none-any.whl (72 kB)\n",
-      "\u001b[K     |████████████████████████████████| 72 kB 1.3 MB/s eta 0:00:011\n",
-      "\u001b[?25hCollecting dataclasses\n",
-      "  Using cached dataclasses-0.6-py3-none-any.whl (14 kB)\n",
-      "Requirement already satisfied: jaxlib in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jsl==0.0.0) (0.3.5)\n",
-      "Requirement already satisfied: jax in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jsl==0.0.0) (0.3.5)\n",
-      "Collecting matplotlib\n",
-      "  Downloading matplotlib-3.5.1-cp38-cp38-macosx_10_9_x86_64.whl (7.3 MB)\n",
-      "\u001b[K     |████████████████████████████████| 7.3 MB 3.9 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting tensorflow_probability\n",
-      "  Using cached tensorflow_probability-0.16.0-py2.py3-none-any.whl (6.3 MB)\n",
-      "Collecting dm-tree>=0.1.5\n",
-      "  Using cached dm_tree-0.1.6-cp38-cp38-macosx_10_14_x86_64.whl (95 kB)\n",
-      "Requirement already satisfied: absl-py>=0.9.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from chex->jsl==0.0.0) (1.0.0)\n",
-      "Requirement already satisfied: numpy>=1.18.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from chex->jsl==0.0.0) (1.22.3)\n",
-      "Collecting toolz>=0.9.0\n",
-      "  Downloading toolz-0.11.2-py3-none-any.whl (55 kB)\n",
-      "\u001b[K     |████████████████████████████████| 55 kB 11.3 MB/s eta 0:00:01\n",
-      "\u001b[?25hRequirement already satisfied: six in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from absl-py>=0.9.0->chex->jsl==0.0.0) (1.16.0)\n",
-      "Requirement already satisfied: typing-extensions in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (4.1.1)\n",
-      "Requirement already satisfied: scipy>=1.2.1 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (1.8.0)\n",
-      "Requirement already satisfied: opt-einsum in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (3.3.0)\n",
-      "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jaxlib->jsl==0.0.0) (2.0)\n",
-      "Collecting cycler>=0.10\n",
-      "  Downloading cycler-0.11.0-py3-none-any.whl (6.4 kB)\n",
-      "Collecting kiwisolver>=1.0.1\n",
-      "  Downloading kiwisolver-1.4.2-cp38-cp38-macosx_10_9_x86_64.whl (65 kB)\n",
-      "\u001b[K     |████████████████████████████████| 65 kB 8.5 MB/s  eta 0:00:01\n",
-      "\u001b[?25hCollecting fonttools>=4.22.0\n",
-      "  Downloading fonttools-4.32.0-py3-none-any.whl (900 kB)\n",
-      "\u001b[K     |████████████████████████████████| 900 kB 35.8 MB/s eta 0:00:01\n",
-      "\u001b[?25hRequirement already satisfied: pyparsing>=2.2.1 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (3.0.7)\n",
-      "Requirement already satisfied: packaging>=20.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (21.3)\n",
-      "Requirement already satisfied: python-dateutil>=2.7 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (2.8.2)\n",
-      "Collecting pillow>=6.2.0\n",
-      "  Downloading Pillow-9.1.0-cp38-cp38-macosx_10_9_x86_64.whl (3.1 MB)\n",
-      "\u001b[K     |████████████████████████████████| 3.1 MB 76.6 MB/s eta 0:00:01\n",
-      "\u001b[?25hCollecting gast>=0.3.2\n",
-      "  Downloading gast-0.5.3-py3-none-any.whl (19 kB)\n",
-      "Requirement already satisfied: decorator in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from tensorflow_probability->jsl==0.0.0) (5.1.1)\n",
-      "Collecting cloudpickle>=1.3\n",
-      "  Downloading cloudpickle-2.0.0-py3-none-any.whl (25 kB)\n",
-      "Building wheels for collected packages: jsl\n",
-      "  Building wheel for jsl (setup.py) ... \u001b[?25ldone\n",
-      "\u001b[?25h  Created wheel for jsl: filename=jsl-0.0.0-py3-none-any.whl size=77852 sha256=e7365293dc97e2b3e72bf42cc19db7d7e355abec312fc4d87961fa2044fa06f0\n",
-      "  Stored in directory: /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-ephem-wheel-cache-63vxzlng/wheels/ed/8b/bf/0105dc839fecf1fc8db14f7267a6ce5ee876324b58565b359f\n",
-      "Successfully built jsl\n",
-      "Installing collected packages: toolz, pillow, kiwisolver, gast, fonttools, dm-tree, cycler, cloudpickle, tensorflow-probability, matplotlib, dataclasses, chex, jsl\n",
-      "Successfully installed chex-0.1.2 cloudpickle-2.0.0 cycler-0.11.0 dataclasses-0.6 dm-tree-0.1.6 fonttools-4.32.0 gast-0.5.3 jsl-0.0.0 kiwisolver-1.4.2 matplotlib-3.5.1 pillow-9.1.0 tensorflow-probability-0.16.0 toolz-0.11.2\n",
-      "Note: you may need to restart the kernel to use updated packages.\n",
-      "Collecting rich\n",
-      "  Downloading rich-12.2.0-py3-none-any.whl (229 kB)\n",
-      "\u001b[K     |████████████████████████████████| 229 kB 2.9 MB/s eta 0:00:01\n",
-      "\u001b[?25hRequirement already satisfied: typing-extensions<5.0,>=4.0.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from rich) (4.1.1)\n",
-      "Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from rich) (2.11.2)\n",
-      "Collecting commonmark<0.10.0,>=0.9.0\n",
-      "  Using cached commonmark-0.9.1-py2.py3-none-any.whl (51 kB)\n",
-      "Installing collected packages: commonmark, rich\n",
-      "Successfully installed commonmark-0.9.1 rich-12.2.0\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "# Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Import standard libraries\n",
-    "\n",
-    "import abc\n",
-    "from dataclasses import dataclass\n",
-    "import functools\n",
-    "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
-    "import matplotlib.pyplot as plt\n",
-    "import numpy as np\n",
-    "\n",
-    "\n",
-    "import jax\n",
-    "import jax.numpy as jnp\n",
-    "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
-    "\n",
-    "import inspect\n",
-    "import inspect as py_inspect\n",
-    "from rich import inspect as r_inspect\n",
-    "from rich import print as r_print\n",
-    "\n",
-    "def print_source(fname):\n",
-    "    r_print(py_inspect.getsource(fname))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Utility code"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "\n",
-    "\n",
-    "def normalize(u, axis=0, eps=1e-15):\n",
-    "    '''\n",
-    "    Normalizes the values within the axis in a way that they sum up to 1.\n",
-    "    Parameters\n",
-    "    ----------\n",
-    "    u : array\n",
-    "    axis : int\n",
-    "    eps : float\n",
-    "        Threshold for the alpha values\n",
-    "    Returns\n",
-    "    -------\n",
-    "    * array\n",
-    "        Normalized version of the given matrix\n",
-    "    * array(seq_len, n_hidden) :\n",
-    "        The values of the normalizer\n",
-    "    '''\n",
-    "    u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))\n",
-    "    c = u.sum(axis=axis)\n",
-    "    c = jnp.where(c == 0, 1, c)\n",
-    "    return u / c, c"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "(sec:casino-ex)=\n",
-    "## Example: Casino HMM\n",
-    "\n",
-    "We first create the \"Occasionally dishonest casino\" model from {cite}`Durbin98`.\n",
-    "\n",
-    "```{figure} /figures/casino.png\n",
-    ":scale: 50%\n",
-    ":name: casino-fig\n",
-    "\n",
-    "Illustration of the casino HMM.\n",
-    "```\n",
-    "\n",
-    "There are 2 hidden states, each of which emits 6 possible observations."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
-     ]
-    }
-   ],
-   "source": [
-    "# state transition matrix\n",
-    "A = np.array([\n",
-    "    [0.95, 0.05],\n",
-    "    [0.10, 0.90]\n",
-    "])\n",
-    "\n",
-    "# observation matrix\n",
-    "B = np.array([\n",
-    "    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die\n",
-    "    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die\n",
-    "])\n",
-    "\n",
-    "pi, _ = normalize(np.array([1, 1]))\n",
-    "pi = np.array(pi)\n",
-    "\n",
-    "\n",
-    "(nstates, nobs) = np.shape(B)\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Let's make a little data structure to store all the parameters.\n",
-    "We use NamedTuple rather than dataclass, since we assume these are immutable.\n",
-    "(Also, standard python dataclass does not work well with JAX, which requires parameters to be\n",
-    "pytrees, as discussed in https://github.com/google/jax/issues/2371)."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 74,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "HMM(trans_mat=array([[0.95, 0.05],\n",
-      "       [0.1 , 0.9 ]]), obs_mat=array([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,\n",
-      "        0.16666667],\n",
-      "       [0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,\n",
-      "        0.5       ]]), init_dist=array([0.5, 0.5], dtype=float32))\n",
-      "<class 'numpy.ndarray'>\n",
-      "HMM(trans_mat=DeviceArray([[0.95, 0.05],\n",
-      "             [0.1 , 0.9 ]], dtype=float32), obs_mat=DeviceArray([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,\n",
-      "              0.16666667],\n",
-      "             [0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,\n",
-      "              0.5       ]], dtype=float32), init_dist=DeviceArray([0.5, 0.5], dtype=float32))\n",
-      "<class 'jaxlib.xla_extension.DeviceArray'>\n"
-     ]
-    }
-   ],
-   "source": [
-    "Array = Union[np.array, jnp.array]\n",
-    "\n",
-    "class HMM(NamedTuple):\n",
-    "    trans_mat: Array  # A : (n_states, n_states)\n",
-    "    obs_mat: Array  # B : (n_states, n_obs)\n",
-    "    init_dist: Array  # pi : (n_states)\n",
-    "\n",
-    "params_np = HMM(A, B, pi)\n",
-    "print(params_np)\n",
-    "print(type(params_np.trans_mat))\n",
-    "\n",
-    "\n",
-    "params = jax.tree_map(lambda x: jnp.array(x), params_np)\n",
-    "print(params)\n",
-    "print(type(params.trans_mat))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Sampling from the joint\n",
-    "\n",
-    "Let's write code to sample from this model. \n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Numpy version\n",
-    "\n",
-    "First we code it in numpy using a for loop."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 30,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def hmm_sample_np(params, seq_len, random_state=0):\n",
-    "    np.random.seed(random_state)\n",
-    "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-    "    n_states, n_obs = obs_mat.shape\n",
-    "    state_seq = np.zeros(seq_len, dtype=int)\n",
-    "    obs_seq = np.zeros(seq_len, dtype=int)\n",
-    "    for t in range(seq_len):\n",
-    "        if t==0:\n",
-    "            zt = np.random.choice(n_states, p=init_dist)\n",
-    "        else:\n",
-    "            zt = np.random.choice(n_states, p=trans_mat[zt])\n",
-    "        yt = np.random.choice(n_obs, p=obs_mat[zt])\n",
-    "        state_seq[t] = zt\n",
-    "        obs_seq[t] = yt\n",
-    "\n",
-    "    return state_seq, obs_seq"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 75,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0\n",
-      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
-      " 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
-      "[4 1 0 2 3 4 5 4 3 1 5 4 5 0 5 2 5 3 5 4 5 5 4 2 1 4 1 0 0 4 2 2 3 3 3 0 4\n",
-      " 0 2 4 3 2 5 5 3 5 3 1 3 3 3 2 3 5 5 0 4 4 5 0 0 1 3 5 1 5 0 1 2 4 0 0 0 4\n",
-      " 0 5 1 4 3 5 4 5 0 2 3 5 2 4 1 2 1 0 4 3 5 0 4 5 1 5]\n"
-     ]
-    }
-   ],
-   "source": [
-    "seq_len = 100\n",
-    "state_seq, obs_seq = hmm_sample_np(params_np, seq_len, random_state=1)\n",
-    "print(state_seq)\n",
-    "print(obs_seq)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### JAX version\n",
-    "\n",
-    "Now let's write a JAX version using jax.lax.scan (for the inter-dependent states) and vmap (for the observations).\n",
-    "This is harder to read than the numpy version, but faster."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 91,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#@partial(jit, static_argnums=(1,))\n",
-    "def markov_chain_sample(rng_key, init_dist, trans_mat, seq_len):\n",
-    "    n_states = len(init_dist)\n",
-    "\n",
-    "    def draw_state(prev_state, key):\n",
-    "        state = jax.random.choice(key, n_states, p=trans_mat[prev_state])\n",
-    "        return state, state\n",
-    "\n",
-    "    rng_key, rng_state = jax.random.split(rng_key, 2)\n",
-    "    keys = jax.random.split(rng_state, seq_len - 1)\n",
-    "    initial_state = jax.random.choice(rng_key, n_states, p=init_dist)\n",
-    "    final_state, states = jax.lax.scan(draw_state, initial_state, keys)\n",
-    "    state_seq = jnp.append(jnp.array([initial_state]), states)\n",
-    "\n",
-    "    return state_seq"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 90,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#@partial(jit, static_argnums=(1,))\n",
-    "def hmm_sample(rng_key, params, seq_len):\n",
-    "\n",
-    "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-    "    n_states, n_obs = obs_mat.shape\n",
-    "    rng_key, rng_obs = jax.random.split(rng_key, 2)\n",
-    "    state_seq = markov_chain_sample(rng_key, init_dist, trans_mat, seq_len)\n",
-    "\n",
-    "    def draw_obs(z, key):\n",
-    "        obs = jax.random.choice(key, n_obs, p=obs_mat[z])\n",
-    "        return obs\n",
-    "\n",
-    "    keys = jax.random.split(rng_obs, seq_len)\n",
-    "    obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)\n",
-    "    \n",
-    "    return state_seq, obs_seq"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 70,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#@partial(jit, static_argnums=(1,))\n",
-    "def hmm_sample2(rng_key, params, seq_len):\n",
-    "\n",
-    "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-    "    n_states, n_obs = obs_mat.shape\n",
-    "\n",
-    "    def draw_state(prev_state, key):\n",
-    "        state = jax.random.choice(key, n_states, p=trans_mat[prev_state])\n",
-    "        return state, state\n",
-    "\n",
-    "    rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3)\n",
-    "    keys = jax.random.split(rng_state, seq_len - 1)\n",
-    "    initial_state = jax.random.choice(rng_key, n_states, p=init_dist)\n",
-    "    final_state, states = jax.lax.scan(draw_state, initial_state, keys)\n",
-    "    state_seq = jnp.append(jnp.array([initial_state]), states)\n",
-    "\n",
-    "    def draw_obs(z, key):\n",
-    "        obs = jax.random.choice(key, n_obs, p=obs_mat[z])\n",
-    "        return obs\n",
-    "\n",
-    "    keys = jax.random.split(rng_obs, seq_len)\n",
-    "    obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)\n",
-    "\n",
-    "    return state_seq, obs_seq"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 93,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
-      " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1\n",
-      " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
-      "[5 5 2 2 0 0 0 1 3 3 2 2 5 1 5 1 0 2 2 4 2 5 1 5 5 0 0 4 2 4 3 2 3 4 1 0 5\n",
-      " 2 2 2 1 4 3 2 2 2 4 1 0 3 5 2 5 1 4 2 5 2 5 0 5 4 4 4 2 2 0 4 5 2 2 0 1 5\n",
-      " 1 3 4 5 1 5 0 5 1 5 1 2 4 5 3 4 5 4 0 4 0 2 4 5 3 3]\n"
-     ]
-    }
-   ],
-   "source": [
-    "\n",
-    "key = PRNGKey(2)\n",
-    "seq_len = 100\n",
-    "\n",
-    "state_seq, obs_seq = hmm_sample(key, params, seq_len)\n",
-    "print(state_seq)\n",
-    "print(obs_seq)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Check correctness by computing empirical pairwise statistics\n",
-    "\n",
-    "We will compute the number of i->j transitions, and check that it is close to the true \n",
-    "A[i,j] transition probabilities."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 107,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[0 0 1 1 1 1 0 0 1 1 1 0 1 0 0 1 1 1 1 0 1 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 1\n",
-      " 1 0 0 0 0 1 0 1 0 0 0 0 1 0 0 1 1 0 1 1 0 1 1 0 1 1 1 0 0 1 1 0 1 0 0 1 0\n",
-      " 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 1 1 1 0 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0\n",
-      " 0 0 1 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 1 1 0 1 1 0 0 0 0 0 0 1 0\n",
-      " 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 1 1 1 0 1 1 0 0 0 0 0 1 0 0 0 0\n",
-      " 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 1 1 1 0 0 0 1 1 0 0 0\n",
-      " 0 0 0 1 1 1 0 0 0 0 1 0 0 1 1 1 0 1 1 1 1 1 0 1 1 0 0 0 1 1 0 1 0 0 1 0 0\n",
-      " 0 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 1 0 0 1 1 1 1\n",
-      " 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 1 0 1 0 0 0 1\n",
-      " 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0\n",
-      " 0 0 0 0 0 0 0 1 0 0 1 1 1 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 1 0\n",
-      " 1 0 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0\n",
-      " 0 0 0 0 1 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 1 1 0 0 1 1 0 0 1\n",
-      " 1 0 0 0 0 0 0 0 1 0 0 1 1 0 0 0 0 1 1]\n",
-      "[[244.  93.]\n",
-      " [ 92.  70.]]\n",
-      "[[0.7240356  0.27596438]\n",
-      " [0.56790125 0.43209878]]\n"
-     ]
-    }
-   ],
-   "source": [
-    "import collections\n",
-    "def compute_counts(state_seq, nstates):\n",
-    "    wseq = np.array(state_seq)\n",
-    "    word_pairs = [pair for pair in zip(wseq[:-1], wseq[1:])]\n",
-    "    counter_pairs = collections.Counter(word_pairs)\n",
-    "    counts = np.zeros((nstates, nstates))\n",
-    "    for (k,v) in counter_pairs.items():\n",
-    "        counts[k[0], k[1]] = v\n",
-    "    return counts\n",
-    "\n",
-    "def normalize_counts(counts):\n",
-    "    ncounts = vmap(lambda v: normalize(v)[0], in_axes=0)(counts)\n",
-    "    return ncounts\n",
-    "\n",
-    "init_dist = jnp.array([1.0, 0.0])\n",
-    "trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]])\n",
-    "rng_key = jax.random.PRNGKey(0)\n",
-    "seq_len = 500\n",
-    "state_seq = markov_chain_sample(rng_key, init_dist, trans_mat, seq_len)\n",
-    "print(state_seq)\n",
-    "\n",
-    "counts = compute_counts(state_seq, nstates=2)\n",
-    "print(counts)\n",
-    "\n",
-    "trans_mat_empirical = normalize_counts(counts)\n",
-    "print(trans_mat_empirical)\n",
-    "\n",
-    "assert jnp.allclose(trans_mat, trans_mat_empirical, atol=1e-1)\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.8"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
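
Note on the removed `chapters/hmm/hmm.ipynb`: its cells sampled from the casino HMM and then checked the sampler by comparing empirical i->j transition frequencies against the true transition matrix. The block below is a minimal standard-library sketch of that same idea, not the notebook's JAX code. The parameter values are copied from the removed cells; the names `sample_hmm` and `empirical_transitions`, the sequence length, and the tolerance are assumptions made for illustration.

```python
import random
from collections import Counter

# Casino HMM parameters, copied from the removed notebook cells.
TRANS = [[0.95, 0.05], [0.10, 0.90]]      # A[i][j] = p(z_t = j | z_{t-1} = i)
OBS = [[1 / 6] * 6, [0.1] * 5 + [0.5]]    # B[i][k] = p(y_t = k | z_t = i)
INIT = [0.5, 0.5]                         # pi[i]   = p(z_1 = i)


def sample_hmm(rng, seq_len):
    """Ancestral sampling: draw each state given the previous one, then its observation."""
    states, obs = [], []
    for t in range(seq_len):
        # The first state comes from the initial distribution,
        # later states from the row of A indexed by the previous state.
        probs = INIT if t == 0 else TRANS[states[-1]]
        z = rng.choices(range(len(INIT)), weights=probs)[0]
        y = rng.choices(range(len(OBS[z])), weights=OBS[z])[0]
        states.append(z)
        obs.append(y)
    return states, obs


def empirical_transitions(states, n_states):
    """Count i -> j transitions and normalize each row to sum to 1."""
    counts = Counter(zip(states[:-1], states[1:]))
    rows = []
    for i in range(n_states):
        total = sum(counts[(i, j)] for j in range(n_states))
        # A state that was never left keeps an all-zero row instead of dividing by zero.
        rows.append([counts[(i, j)] / total if total else 0.0 for j in range(n_states)])
    return rows


rng = random.Random(0)  # hypothetical seed, chosen only for reproducibility
states, obs = sample_hmm(rng, seq_len=50_000)
A_hat = empirical_transitions(states, n_states=2)

# With a long sequence the estimate should be close to the true matrix.
for i in range(2):
    for j in range(2):
        assert abs(A_hat[i][j] - TRANS[i][j]) < 0.02
```

The loop form mirrors the notebook's NumPy sampler. The JAX sampler in the same file expressed the state recursion with `jax.lax.scan` and drew the observations in parallel with `vmap`, which is possible because each observation depends only on its own state.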

+ 253 - 405
chapters/hmm/hmm_filter.ipynb

@@ -4,171 +4,8 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand{\\defeq}{\\triangleq}\n",
-    "\\newcommand{\\trans}{{\\mkern-1.5mu\\mathsf{T}}}\n",
-    "\\newcommand{\\transpose}[1]{{#1}^{\\trans}}\n",
-    "\n",
-    "\\newcommand{\\inv}[1]{{#1}^{-1}}\n",
-    "\\DeclareMathOperator{\\dotstar}{\\odot}\n",
-    "\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
+    "(sec:forwards)=\n",
+    "# HMM filtering (forwards algorithm)"
    ]
    ]
   },
   },
   {
   {
@@ -177,141 +14,85 @@
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "#try:\n",
-    "#    import ssm_jax\n",
-    "##except:\n",
-    "#    %pip install git+https://github.com/probml/ssm-jax\n",
-    "#    import ssm_jax\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
     "### Import standard libraries\n",
     "### Import standard libraries\n",
     "\n",
     "\n",
     "import abc\n",
     "import abc\n",
     "from dataclasses import dataclass\n",
     "from dataclasses import dataclass\n",
     "import functools\n",
     "import functools\n",
+    "from functools import partial\n",
     "import itertools\n",
     "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
     "import numpy as np\n",
-    "\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
     "\n",
     "\n",
     "import jax\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "import jax.numpy as jnp\n",
     "from jax import lax, vmap, jit, grad\n",
     "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
     "\n",
     "\n",
-    "import inspect\n",
-    "import inspect as py_inspect\n",
-    "import rich\n",
-    "from rich import inspect as r_inspect\n",
-    "from rich import print as r_print\n",
     "\n",
     "\n",
-    "def print_source(fname):\n",
-    "    r_print(py_inspect.getsource(fname))"
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax"
    ]
    ]
   },
   },
   {
   {
    "cell_type": "markdown",
    "cell_type": "markdown",
    "metadata": {},
    "metadata": {},
    "source": [
    "source": [
-    "(sec:forwards)=\n",
-    "# HMM filtering (forwards algorithm)\n",
+    "\n",
+    "## Introduction\n",
     "\n",
     "\n",
     "\n",
     "\n",
-    "The  **Bayes filter** is an algorithm for recursively computing\n",
+    "The  $\\keyword{Bayes filter}$ is an algorithm for recursively computing\n",
     "the belief state\n",
     "the belief state\n",
     "$p(\\hidden_t|\\obs_{1:t})$ given\n",
     "$p(\\hidden_t|\\obs_{1:t})$ given\n",
     "the prior belief from the previous step,\n",
     "the prior belief from the previous step,\n",
     "$p(\\hidden_{t-1}|\\obs_{1:t-1})$,\n",
     "$p(\\hidden_{t-1}|\\obs_{1:t-1})$,\n",
     "the new observation $\\obs_t$,\n",
     "the new observation $\\obs_t$,\n",
     "and the model.\n",
     "and the model.\n",
-    "This can be done using **sequential Bayesian updating**.\n",
+    "This can be done using $\\keyword{sequential Bayesian updating}$.\n",
     "For a dynamical model, this reduces to the\n",
     "For a dynamical model, this reduces to the\n",
-    "**predict-update** cycle described below.\n",
+    "$\\keyword{predict-update}$ cycle described below.\n",
     "\n",
     "\n",
-    "\n",
-    "The **prediction step** is just the **Chapman-Kolmogorov equation**:\n",
-    "```{math}\n",
+    "The $\\keyword{prediction step}$ is just the $\\keyword{Chapman-Kolmogorov equation}$:\n",
+    "\\begin{align}\n",
     "p(\\hidden_t|\\obs_{1:t-1})\n",
     "p(\\hidden_t|\\obs_{1:t-1})\n",
     "= \\int p(\\hidden_t|\\hidden_{t-1}) p(\\hidden_{t-1}|\\obs_{1:t-1}) d\\hidden_{t-1}\n",
     "= \\int p(\\hidden_t|\\hidden_{t-1}) p(\\hidden_{t-1}|\\obs_{1:t-1}) d\\hidden_{t-1}\n",
-    "```\n",
+    "\\end{align}\n",
     "The prediction step computes\n",
     "The prediction step computes\n",
-    "the one-step-ahead predictive distribution\n",
-    "for the latent state, which updates\n",
-    "the posterior from the previous time step into the prior\n",
+    "the $\\keyword{one-step-ahead predictive distribution}$\n",
+    "for the latent state, which converts\n",
+    "the posterior from the previous time step to become the prior\n",
     "for the current step.\n",
     "for the current step.\n",
     "\n",
     "\n",
     "\n",
     "\n",
-    "The **update step**\n",
+    "The $\\keyword{update step}$\n",
     "is just Bayes rule:\n",
     "is just Bayes rule:\n",
-    "```{math}\n",
+    "\\begin{align}\n",
     "p(\\hidden_t|\\obs_{1:t}) = \\frac{1}{Z_t}\n",
     "p(\\hidden_t|\\obs_{1:t}) = \\frac{1}{Z_t}\n",
     "p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1})\n",
     "p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1})\n",
-    "```\n",
+    "\\end{align}\n",
     "where the normalization constant is\n",
     "where the normalization constant is\n",
-    "```{math}\n",
+    "\\begin{align}\n",
     "Z_t = \\int p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1}) d\\hidden_{t}\n",
     "Z_t = \\int p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1}) d\\hidden_{t}\n",
     "= p(\\obs_t|\\obs_{1:t-1})\n",
     "= p(\\obs_t|\\obs_{1:t-1})\n",
-    "```\n",
+    "\\end{align}\n",
     "\n",
     "\n",
+    "Note that we can derive the log marginal likelihood from these normalization constants\n",
+    "as follows:\n",
+    "```{math}\n",
+    ":label: eqn:logZ\n",
     "\n",
     "\n",
+    "\\log p(\\obs_{1:T})\n",
+    "= \\sum_{t=1}^{T} \\log p(\\obs_t|\\obs_{1:t-1})\n",
+    "= \\sum_{t=1}^{T} \\log Z_t\n",
+    "```\n",
     "\n"
     "\n"
    ]
    ]
   },
   },
@@ -323,10 +104,11 @@
     "When the latent states $\\hidden_t$ are discrete, as in HMM,\n",
     "When the latent states $\\hidden_t$ are discrete, as in HMM,\n",
     "the above integrals become sums.\n",
     "the above integrals become sums.\n",
     "In particular, suppose we define\n",
     "In particular, suppose we define\n",
-    "the belief state as $\\alpha_t(j) \\defeq p(\\hidden_t=j|\\obs_{1:t})$,\n",
-    "the local evidence as $\\lambda_t(j) \\defeq p(\\obs_t|\\hidden_t=j)$,\n",
-    "and the transition matrix\n",
-    "$A(i,j)  = p(\\hidden_t=j|\\hidden_{t-1}=i)$.\n",
+    "the $\\keyword{belief state}$ as $\\alpha_t(j) \\defeq p(\\hidden_t=j|\\obs_{1:t})$,\n",
+    "the  $\\keyword{local evidence}$ (or $\\keyword{local likelihood}$)\n",
+    "as $\\lambda_t(j) \\defeq p(\\obs_t|\\hidden_t=j)$,\n",
+    "and the transition matrix as\n",
+    "$\\hmmTrans(i,j)  = p(\\hidden_t=j|\\hidden_{t-1}=i)$.\n",
     "Then the predict step becomes\n",
     "Then the predict step becomes\n",
     "```{math}\n",
     "```{math}\n",
     ":label: eqn:predictiveHMM\n",
     ":label: eqn:predictiveHMM\n",
@@ -338,7 +120,7 @@
     ":label: eqn:fwdsEqn\n",
     ":label: eqn:fwdsEqn\n",
     "\\alpha_t(j)\n",
     "\\alpha_t(j)\n",
     "= \\frac{1}{Z_t} \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
     "= \\frac{1}{Z_t} \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
-    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) A(i,j)  \\right]\n",
+    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) \\hmmTrans(i,j)  \\right]\n",
     "```\n",
     "```\n",
     "where\n",
     "where\n",
     "the  normalization constant for each time step is given by\n",
     "the  normalization constant for each time step is given by\n",
@@ -350,20 +132,33 @@
     "&=  \\sum_{j=1}^K \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
     "&=  \\sum_{j=1}^K \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
     "\\end{align}\n",
     "\\end{align}\n",
     "```\n",
     "```\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
     "\n",
     "\n",
     "Since all the quantities are finite length vectors and matrices,\n",
     "Since all the quantities are finite length vectors and matrices,\n",
-    "we can write the update equation\n",
-    "in matrix-vector notation as follows:\n",
+    "we can implement the whole procedure using matrix-vector multiplication:\n",
     "```{math}\n",
     "```{math}\n",
+    ":label: eqn:fwdsAlgoMatrixForm\n",
     "\\valpha_t =\\text{normalize}\\left(\n",
     "\\valpha_t =\\text{normalize}\\left(\n",
-    "\\vlambda_t \\dotstar  (\\vA^{\\trans} \\valpha_{t-1}) \\right)\n",
-    "\\label{eqn:fwdsAlgoMatrixForm}\n",
+    "\\vlambda_t \\dotstar  (\\hmmTrans^{\\trans} \\valpha_{t-1}) \\right)\n",
     "```\n",
     "```\n",
     "where $\\dotstar$ represents\n",
     "where $\\dotstar$ represents\n",
     "elementwise vector multiplication,\n",
     "elementwise vector multiplication,\n",
-    "and the $\\text{normalize}$ function just ensures its argument sums to one.\n",
+    "and the $\\text{normalize}$ function just ensures its argument sums to one.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Example\n",
     "\n",
     "\n",
-    "In {ref}(sec:casino-inference)\n",
+    "In {ref}`sec:casino-inference`\n",
     "we illustrate\n",
     "we illustrate\n",
     "filtering for the casino HMM,\n",
     "filtering for the casino HMM,\n",
     "applied to a random sequence $\\obs_{1:T}$ of length $T=300$.\n",
     "applied to a random sequence $\\obs_{1:T}$ of length $T=300$.\n",
@@ -371,156 +166,209 @@
     "based on the evidence seen so far.\n",
     "based on the evidence seen so far.\n",
     "The gray bars indicate time intervals during which the generative\n",
     "The gray bars indicate time intervals during which the generative\n",
     "process actually switched to the loaded dice.\n",
     "process actually switched to the loaded dice.\n",
-    "We see that the probability generally increases in the right places.\n"
+    "We see that the probability generally increases in the right places."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Normalization constants\n",
+    "\n",
+    "In most publications on HMMs,\n",
+    "such as {cite}`Rabiner89`,\n",
+    "the forwards message is defined\n",
+    "as the following unnormalized joint probability:\n",
+    "```{math}\n",
+    "\\alpha'_t(j) = p(\\hidden_t=j,\\obs_{1:t}) \n",
+    "= \\lambda_t(j) \\left[\\sum_i \\alpha'_{t-1}(i) \\hmmTrans(i,j)  \\right]\n",
+    "```\n",
+    "In this book we define the forwards message as the normalized\n",
+    "conditional probability\n",
+    "```{math}\n",
+    "\\alpha_t(j) = p(\\hidden_t=j|\\obs_{1:t}) \n",
+    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) \\hmmTrans(i,j)  \\right]\n",
+    "```\n",
+    "where $Z_t = p(\\obs_t|\\obs_{1:t-1})$.\n",
+    "\n",
+    "The \"traditional\" unnormalized form has several problems.\n",
+    "First, it rapidly suffers from numerical underflow,\n",
+    "since the probability of\n",
+    "the joint event that $(\\hidden_t=j,\\obs_{1:t})$\n",
+    "is vanishingly small. \n",
+    "To see why, suppose the observations are independent of the states.\n",
+    "In this case, the unnormalized joint has the form\n",
+    "\\begin{align}\n",
+    "p(\\hidden_t=j,\\obs_{1:t}) = p(\\hidden_t=j)\\prod_{i=1}^t p(\\obs_i)\n",
+    "\\end{align}\n",
+    "which becomes exponentially small with $t$, because we multiply\n",
+    "many probabilities which are less than one.\n",
+    "Second, the unnormalized probability is less interpretable,\n",
+    "since it is a joint distribution over states and observations,\n",
+    "rather than a conditional probability of states given observations.\n",
+    "Third, the unnormalized joint form is harder to approximate\n",
+    "than the normalized form.\n",
+    "Of course,\n",
+    "the two definitions only differ by a\n",
+    "multiplicative constant\n",
+    "{cite}`Devijver85`,\n",
+    "so the algorithmic difference is just\n",
+    "one line of code (namely the presence or absence of a call to the `normalize` function).\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n"
    ]
    ]
   },
   },
   {
   {
    "cell_type": "markdown",
    "cell_type": "markdown",
    "metadata": {},
    "metadata": {},
    "source": [
    "source": [
-    "Here is a JAX implementation of the forwards algorithm."
+    "## Naive implementation\n",
+    "\n",
+    "Below we give a simple NumPy implementation of the forwards algorithm.\n",
+    "We assume the HMM uses categorical observations, for simplicity.\n",
+    "\n"
    ]
    ]
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">@jit\n",
-       "def <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">hmm_forwards_jax</span><span style=\"font-weight: bold\">(</span>params, obs_seq, <span style=\"color: #808000; text-decoration-color: #808000\">length</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span><span style=\"font-weight: bold\">)</span>:\n",
-       "    <span style=\"color: #008000; text-decoration-color: #008000\">''</span>'\n",
-       "    Calculates a belief state\n",
-       "\n",
-       "    Parameters\n",
-       "    ----------\n",
-       "    params : HMMJax\n",
-       "        Hidden Markov Model\n",
-       "\n",
-       "    obs_seq: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">array</span><span style=\"font-weight: bold\">(</span>seq_len<span style=\"font-weight: bold\">)</span>\n",
-       "        History of observable events\n",
-       "\n",
-       "    Returns\n",
-       "    -------\n",
-       "    * float\n",
-       "        The loglikelihood giving <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">log</span><span style=\"font-weight: bold\">(</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">p</span><span style=\"font-weight: bold\">(</span>x|model<span style=\"font-weight: bold\">))</span>\n",
-       "\n",
-       "    * <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">array</span><span style=\"font-weight: bold\">(</span>seq_len, n_hidden<span style=\"font-weight: bold\">)</span> :\n",
-       "        All alpha values found for each sample\n",
-       "    <span style=\"color: #008000; text-decoration-color: #008000\">''</span>'\n",
-       "    seq_len = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">len</span><span style=\"font-weight: bold\">(</span>obs_seq<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    if length is <span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>:\n",
-       "        length = seq_len\n",
-       "\n",
-       "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-       "\n",
-       "    trans_mat = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>trans_mat<span style=\"font-weight: bold\">)</span>\n",
-       "    obs_mat = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>obs_mat<span style=\"font-weight: bold\">)</span>\n",
-       "    init_dist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>init_dist<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    n_states, n_obs = obs_mat.shape\n",
-       "\n",
-       "    def <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">scan_fn</span><span style=\"font-weight: bold\">(</span>carry, t<span style=\"font-weight: bold\">)</span>:\n",
-       "        <span style=\"font-weight: bold\">(</span>alpha_prev, log_ll_prev<span style=\"font-weight: bold\">)</span> = carry\n",
-       "        alpha_n = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.where</span><span style=\"font-weight: bold\">(</span>t &lt; length,\n",
-       "                            obs_mat<span style=\"font-weight: bold\">[</span>:, obs_seq<span style=\"font-weight: bold\">]</span> * <span style=\"font-weight: bold\">(</span>alpha_prev<span style=\"font-weight: bold\">[</span>:, <span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span><span style=\"font-weight: bold\">]</span> * \n",
-       "trans_mat<span style=\"font-weight: bold\">)</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">.sum</span><span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">axis</span>=<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"font-weight: bold\">)</span>,\n",
-       "                            <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.zeros_like</span><span style=\"font-weight: bold\">(</span>alpha_prev<span style=\"font-weight: bold\">))</span>\n",
-       "\n",
-       "        alpha_n, cn = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">normalize</span><span style=\"font-weight: bold\">(</span>alpha_n<span style=\"font-weight: bold\">)</span>\n",
-       "        carry = <span style=\"font-weight: bold\">(</span>alpha_n, <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.log</span><span style=\"font-weight: bold\">(</span>cn<span style=\"font-weight: bold\">)</span> + log_ll_prev<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "        return carry, alpha_n\n",
-       "\n",
-       "    # initial belief state\n",
-       "    alpha_0, c0 = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">normalize</span><span style=\"font-weight: bold\">(</span>init_dist * obs_mat<span style=\"font-weight: bold\">[</span>:, obs_seq<span style=\"font-weight: bold\">[</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"font-weight: bold\">]])</span>\n",
-       "\n",
-       "    # setup scan loop\n",
-       "    init_state = <span style=\"font-weight: bold\">(</span>alpha_0, <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.log</span><span style=\"font-weight: bold\">(</span>c0<span style=\"font-weight: bold\">))</span>\n",
-       "    ts = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.arange</span><span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, seq_len<span style=\"font-weight: bold\">)</span>\n",
-       "    carry, alpha_hist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">lax.scan</span><span style=\"font-weight: bold\">(</span>scan_fn, init_state, ts<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    # post-process\n",
-       "    alpha_hist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.vstack</span><span style=\"font-weight: bold\">()</span>\n",
-       "    <span style=\"font-weight: bold\">(</span>alpha_final, log_ll<span style=\"font-weight: bold\">)</span> = carry\n",
-       "    return log_ll, alpha_hist\n",
-       "\n",
-       "</pre>\n"
-      ],
-      "text/plain": [
-       "@jit\n",
-       "def \u001b[1;35mhmm_forwards_jax\u001b[0m\u001b[1m(\u001b[0mparams, obs_seq, \u001b[33mlength\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m:\n",
-       "    \u001b[32m''\u001b[0m'\n",
-       "    Calculates a belief state\n",
-       "\n",
-       "    Parameters\n",
-       "    ----------\n",
-       "    params : HMMJax\n",
-       "        Hidden Markov Model\n",
-       "\n",
-       "    obs_seq: \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0mseq_len\u001b[1m)\u001b[0m\n",
-       "        History of observable events\n",
-       "\n",
-       "    Returns\n",
-       "    -------\n",
-       "    * float\n",
-       "        The loglikelihood giving \u001b[1;35mlog\u001b[0m\u001b[1m(\u001b[0m\u001b[1;35mp\u001b[0m\u001b[1m(\u001b[0mx|model\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    * \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0mseq_len, n_hidden\u001b[1m)\u001b[0m :\n",
-       "        All alpha values found for each sample\n",
-       "    \u001b[32m''\u001b[0m'\n",
-       "    seq_len = \u001b[1;35mlen\u001b[0m\u001b[1m(\u001b[0mobs_seq\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    if length is \u001b[3;35mNone\u001b[0m:\n",
-       "        length = seq_len\n",
-       "\n",
-       "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-       "\n",
-       "    trans_mat = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0mtrans_mat\u001b[1m)\u001b[0m\n",
-       "    obs_mat = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0mobs_mat\u001b[1m)\u001b[0m\n",
-       "    init_dist = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0minit_dist\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    n_states, n_obs = obs_mat.shape\n",
-       "\n",
-       "    def \u001b[1;35mscan_fn\u001b[0m\u001b[1m(\u001b[0mcarry, t\u001b[1m)\u001b[0m:\n",
-       "        \u001b[1m(\u001b[0malpha_prev, log_ll_prev\u001b[1m)\u001b[0m = carry\n",
-       "        alpha_n = \u001b[1;35mjnp.where\u001b[0m\u001b[1m(\u001b[0mt < length,\n",
-       "                            obs_mat\u001b[1m[\u001b[0m:, obs_seq\u001b[1m]\u001b[0m * \u001b[1m(\u001b[0malpha_prev\u001b[1m[\u001b[0m:, \u001b[3;35mNone\u001b[0m\u001b[1m]\u001b[0m * \n",
-       "trans_mat\u001b[1m)\u001b[0m\u001b[1;35m.sum\u001b[0m\u001b[1m(\u001b[0m\u001b[33maxis\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m,\n",
-       "                            \u001b[1;35mjnp.zeros_like\u001b[0m\u001b[1m(\u001b[0malpha_prev\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
-       "\n",
-       "        alpha_n, cn = \u001b[1;35mnormalize\u001b[0m\u001b[1m(\u001b[0malpha_n\u001b[1m)\u001b[0m\n",
-       "        carry = \u001b[1m(\u001b[0malpha_n, \u001b[1;35mjnp.log\u001b[0m\u001b[1m(\u001b[0mcn\u001b[1m)\u001b[0m + log_ll_prev\u001b[1m)\u001b[0m\n",
-       "\n",
-       "        return carry, alpha_n\n",
-       "\n",
-       "    # initial belief state\n",
-       "    alpha_0, c0 = \u001b[1;35mnormalize\u001b[0m\u001b[1m(\u001b[0minit_dist * obs_mat\u001b[1m[\u001b[0m:, obs_seq\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    # setup scan loop\n",
-       "    init_state = \u001b[1m(\u001b[0malpha_0, \u001b[1;35mjnp.log\u001b[0m\u001b[1m(\u001b[0mc0\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
-       "    ts = \u001b[1;35mjnp.arange\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m, seq_len\u001b[1m)\u001b[0m\n",
-       "    carry, alpha_hist = \u001b[1;35mlax.scan\u001b[0m\u001b[1m(\u001b[0mscan_fn, init_state, ts\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    # post-process\n",
-       "    alpha_hist = \u001b[1;35mjnp.vstack\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n",
-       "    \u001b[1m(\u001b[0malpha_final, log_ll\u001b[1m)\u001b[0m = carry\n",
-       "    return log_ll, alpha_hist\n",
-       "\n"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
    "source": [
-    "import jsl.hmm.hmm_lib as hmm_lib\n",
-    "print_source(hmm_lib.hmm_forwards_jax)\n",
-    "#https://github.com/probml/JSL/blob/main/jsl/hmm/hmm_lib.py#L189\n",
-    "\n"
+    "\n",
+    "\n",
+    "def normalize_np(u, axis=0, eps=1e-15):\n",
+    "    u = np.where(u == 0, 0, np.where(u < eps, eps, u))\n",
+    "    c = u.sum(axis=axis)\n",
+    "    c = np.where(c == 0, 1, c)\n",
+    "    return u / c, c\n",
+    "\n",
+    "def hmm_forwards_np(trans_mat, obs_mat, init_dist, obs_seq):\n",
+    "    n_states, n_obs = obs_mat.shape\n",
+    "    seq_len = len(obs_seq)\n",
+    "\n",
+    "    alpha_hist = np.zeros((seq_len, n_states))\n",
+    "    ll_hist = np.zeros(seq_len)  # loglikelihood history\n",
+    "\n",
+    "    alpha_n = init_dist * obs_mat[:, obs_seq[0]]\n",
+    "    alpha_n, cn = normalize_np(alpha_n)\n",
+    "\n",
+    "    alpha_hist[0] = alpha_n\n",
+    "    log_normalizer = np.log(cn)\n",
+    "\n",
+    "    for t in range(1, seq_len):\n",
+    "        alpha_n = obs_mat[:, obs_seq[t]] * (alpha_n[:, None] * trans_mat).sum(axis=0)\n",
+    "        alpha_n, zn = normalize_np(alpha_n)\n",
+    "\n",
+    "        alpha_hist[t] = alpha_n\n",
+    "        log_normalizer = np.log(zn) + log_normalizer\n",
+    "\n",
+    "    return  log_normalizer, alpha_hist"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Numerically stable implementation \n",
+    "\n",
+    "\n",
+    "\n",
+    "In practice it is more numerically stable to compute\n",
+    "the log likelihoods $\\ell_t(j) = \\log p(\\obs_t|\\hidden_t=j)$,\n",
+    "rather than the likelihoods $\\lambda_t(j) = p(\\obs_t|\\hidden_t=j)$.\n",
+    "In this case, we can perform the posterior updating in a numerically stable way as follows.\n",
+    "Define $L_t = \\max_j \\ell_t(j)$ and\n",
+    "\\begin{align}\n",
+    "\\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n",
+    "&\\defeq p(\\hidden_t=j|\\obs_{1:t-1}) p(\\obs_t|\\hidden_t=j) e^{-L_t} \\\\\n",
+    " &= p(\\hidden_t=j|\\obs_{1:t-1}) e^{\\ell_t(j) - L_t}\n",
+    "\\end{align}\n",
+    "Then we have\n",
+    "\\begin{align}\n",
+    "p(\\hidden_t=j|\\obs_t,\\obs_{1:t-1})\n",
+    "  &= \\frac{1}{\\tilde{Z}_t} \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1}) \\\\\n",
+    "\\tilde{Z}_t &= \\sum_j \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n",
+    "= p(\\obs_t|\\obs_{1:t-1}) e^{-L_t} \\\\\n",
+    "\\log Z_t &= \\log p(\\obs_t|\\obs_{1:t-1}) = \\log \\tilde{Z}_t + L_t\n",
+    "\\end{align}\n",
+    "\n",
+    "Below we show some JAX code that implements this core operation.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "def _condition_on(probs, ll):\n",
+    "    ll_max = ll.max()\n",
+    "    new_probs = probs * jnp.exp(ll - ll_max)\n",
+    "    norm = new_probs.sum()\n",
+    "    new_probs /= norm\n",
+    "    log_norm = jnp.log(norm) + ll_max\n",
+    "    return new_probs, log_norm"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "With the above function, we can implement a more numerically stable version of the forwards filter\n",
+    "that works for any likelihood function, as shown below. It takes in the prior predictive distribution,\n",
+    "$\\alpha_{t|t-1}$,\n",
+    "stored in `predicted_probs`, and conditions it on the log-likelihood $\\ell_t$ at each time step to get the\n",
+    "posterior, $\\alpha_t$, stored in `filtered_probs`,\n",
+    "which is then converted to the prediction for the next state, $\\alpha_{t+1|t}$."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def _predict(probs, A):\n",
+    "    return A.T @ probs\n",
+    "\n",
+    "\n",
+    "def hmm_filter(initial_distribution,\n",
+    "               transition_matrix,\n",
+    "               log_likelihoods):\n",
+    "    def _step(carry, t):\n",
+    "        log_normalizer, predicted_probs = carry\n",
+    "\n",
+    "        # Get parameters for time t\n",
+    "        get = lambda x: x[t] if x.ndim == 3 else x\n",
+    "        A = get(transition_matrix)\n",
+    "        ll = log_likelihoods[t]\n",
+    "\n",
+    "        # Condition on emissions at time t, being careful not to overflow\n",
+    "        filtered_probs, log_norm = _condition_on(predicted_probs, ll)\n",
+    "        # Update the log normalizer\n",
+    "        log_normalizer += log_norm\n",
+    "        # Predict the next state\n",
+    "        predicted_probs = _predict(filtered_probs, A)\n",
+    "\n",
+    "        return (log_normalizer, predicted_probs), (filtered_probs, predicted_probs)\n",
+    "\n",
+    "    num_timesteps = len(log_likelihoods)\n",
+    "    carry = (0.0, initial_distribution)\n",
+    "    (log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(\n",
+    "        _step, carry, jnp.arange(num_timesteps))\n",
+    "    return log_normalizer, filtered_probs, predicted_probs"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "TODO: check equivalence of these two implementations!"
    ]
    ]
   }
   }
  ],
  ],
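
Note on the TODO in the last markdown cell of this notebook: one way to check that the probability-space filter and the log-space filter agree is to run both on the same toy HMM and compare their outputs. The sketch below is not the notebook's code. It is a NumPy-only restatement of the two recursions, so it runs without JAX. The function names `forwards_naive` and `forwards_stable`, the 2-state, 3-symbol parameters, and the random observation sequence are all hypothetical and chosen only for illustration. It also omits the `eps` clipping that `normalize_np` applies, which should not matter when all probabilities are well away from zero.

```python
import numpy as np


def forwards_naive(trans_mat, obs_mat, init_dist, obs_seq):
    """Probability-space filter, following the recursion in hmm_forwards_np."""
    alpha = init_dist * obs_mat[:, obs_seq[0]]
    z = alpha.sum()
    alpha = alpha / z
    log_z = np.log(z)
    alphas = [alpha]
    for y in obs_seq[1:]:
        # Predict with A^T alpha, then weight by the local evidence.
        alpha = obs_mat[:, y] * (trans_mat.T @ alpha)
        z = alpha.sum()
        alpha = alpha / z
        log_z += np.log(z)
        alphas.append(alpha)
    return log_z, np.stack(alphas)


def forwards_stable(init_dist, trans_mat, log_liks):
    """Log-space conditioning, following the recursion in hmm_filter."""
    predicted = init_dist
    log_z = 0.0
    filtered = []
    for ll in log_liks:
        # Subtract the max log-likelihood before exponentiating.
        ll_max = ll.max()
        unnorm = predicted * np.exp(ll - ll_max)
        norm = unnorm.sum()
        post = unnorm / norm
        log_z += np.log(norm) + ll_max
        filtered.append(post)
        # Predict the next state.
        predicted = trans_mat.T @ post
    return log_z, np.stack(filtered)


# Hypothetical toy model: 2 hidden states, 3 observation symbols.
trans_mat = np.array([[0.95, 0.05],
                      [0.10, 0.90]])
obs_mat = np.array([[1 / 3, 1 / 3, 1 / 3],
                    [0.1, 0.1, 0.8]])
init_dist = np.array([0.5, 0.5])

rng = np.random.default_rng(0)
obs_seq = rng.integers(0, 3, size=50)

# Per-step log-likelihoods, shape (T, K).
log_liks = np.log(obs_mat[:, obs_seq].T)

log_z_naive, alpha_naive = forwards_naive(trans_mat, obs_mat, init_dist, obs_seq)
log_z_stable, alpha_stable = forwards_stable(init_dist, trans_mat, log_liks)

print("log marginal likelihoods agree:", np.allclose(log_z_naive, log_z_stable))
print("filtered marginals agree:", np.allclose(alpha_naive, alpha_stable))
```

The same comparison could be made against the notebook's own `hmm_forwards_np` and `hmm_filter` by passing `jnp.log(obs_mat[:, obs_seq].T)` as `log_likelihoods`, assuming categorical observations as in the naive implementation.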

+ 1 - 1
chapters/hmm/hmm_index.md

@@ -2,7 +2,7 @@
 # Hidden Markov Models 
 
 This chapter discusses Hidden Markov Models (HMMs), which are state space models
-in which the latent state $z_t \in \{1,\ldots,K\}$ is discrete.
+in which the latent state $\hidden_t \in \{1,\ldots,\nstates\}$ is discrete.
 
 
 ```{tableofcontents}

+ 1 - 1
chapters/scratch.md

@@ -102,7 +102,7 @@ I am a useful note!
 
 ## Math
 
-Here is $\N=10$ and blah. $\floor{42.3}= 42$.
+Here is $\N=10$ and blah. $\floor{42.3}= 42$. Let's try again.
 
 We have $E= mc^2$, and also
 

File diff suppressed because it is too large
+ 552 - 0
chapters/scratchpad.ipynb


+ 36 - 239
chapters/ssm/hmm.ipynb

@@ -2,48 +2,27 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
+    "(sec:hmm-intro)=\n",
+    "# Hidden Markov Models\n",
     "\n",
-    "\n"
+    "In this section, we discuss the\n",
+    "hidden Markov model or HMM,\n",
+    "which is a state space model in which the hidden states\n",
+    "are discrete, so $\\hidden_t \\in \\{1,\\ldots, \\nstates\\}$.\n",
+    "The observations may be discrete,\n",
+    "$\\obs_t \\in \\{1,\\ldots, \\nsymbols\\}$,\n",
+    "or continuous,\n",
+    "$\\obs_t \\in \\real^\\nstates$,\n",
+    "or some combination,\n",
+    "as we illustrate below.\n",
+    "More details can be found in e.g., \n",
+    "{cite}`Rabiner89,Fraser08,Cappe05`.\n",
+    "For an interactive introduction,\n",
+    "see https://nipunbatra.github.io/hmm/."
    ]
   },
   {
@@ -64,21 +43,26 @@
     "import abc\n",
     "from dataclasses import dataclass\n",
     "import functools\n",
+    "from functools import partial\n",
     "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
     "\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
+    "\n",
+    "\n",
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax\n",
     "\n",
     "import inspect\n",
     "import inspect as py_inspect\n",
@@ -94,193 +78,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "(sec:hmm-intro)=\n",
-    "# Hidden Markov Models\n",
-    "\n",
-    "In this section, we discuss the\n",
-    "hidden Markov model or HMM,\n",
-    "which is a state space model in which the hidden states\n",
-    "are discrete, so $\\hmmhid_t \\in \\{1,\\ldots, K\\}$.\n",
-    "The observations may be discrete,\n",
-    "$\\hmmobs_t \\in \\{1,\\ldots, C\\}$,\n",
-    "or continuous,\n",
-    "$\\hmmobs_t \\in \\real^D$,\n",
-    "or some combination,\n",
-    "as we illustrate below.\n",
-    "More details can be found in e.g., \n",
-    "{cite}`Rabiner89,Fraser08,Cappe05`.\n",
-    "For an interactive introduction,\n",
-    "see https://nipunbatra.github.io/hmm/."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "(sec:casino)=\n",
     "### Example: Casino HMM\n",
     "\n",
@@ -291,9 +88,9 @@
     "\n",
     "The transition model is denoted by\n",
     "```{math}\n",
-    "p(z_t=j|z_{t-1}=i) = \\hmmTrans_{ij}\n",
+    "p(\\hidden_t=j|\\hidden_{t-1}=i) = \\hmmTrans_{ij}\n",
     "```\n",
-    "Here the $i$'th row of $\\vA$ corresponds to the outgoing distribution from state $i$.\n",
+    "Here the $i$'th row of $\\hmmTrans$ corresponds to the outgoing distribution from state $i$.\n",
     "This is  a row stochastic matrix,\n",
     "meaning each row sums to one.\n",
     "We can visualize\n",
@@ -314,15 +111,15 @@
     "p(\\obs_t=k|\\hidden_t=j) = \\hmmObs_{jk} \n",
     "```\n",
     "This is represented by the histograms associated with each\n",
-    "state in  {numref}`casino-fig`.\n",
+    "state in  {numref}`fig:casino`.\n",
     "\n",
     "Finally,\n",
     "the initial state distribution is denoted by\n",
     "```{math}\n",
-    "p(z_1=j) = \\hmmInit_j\n",
+    "p(\\hidden_1=j) = \\hmmInit_j\n",
     "```\n",
     "\n",
-    "Collectively we denote all the parameters by $\\vtheta=(\\hmmTrans, \\hmmObs, \\hmmInit)$.\n",
+    "Collectively we denote all the parameters by $\\params=(\\hmmTrans, \\hmmObs, \\hmmInit)$.\n",
     "\n",
     "Now let us implement this model in code."
    ]
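
Note on the model definition in this hunk: the three parameter blocks (transition matrix, observation matrix, initial distribution) are enough to draw samples by ancestral sampling. The following is a rough sketch in plain Python, not the notebook's own implementation (which is elided from this diff). The function names, the 0-based state and symbol indices, and all parameter values are assumptions chosen for illustration.

```python
import random


def is_row_stochastic(M, tol=1e-9):
    """True if every entry is non-negative and every row sums to one."""
    return all(min(row) >= 0 and abs(sum(row) - 1.0) <= tol for row in M)


def sample_hmm(pi, A, B, seq_len, seed=0):
    """Ancestral sampling from an HMM with discrete observations.

    Draw the first state from pi, then alternate between emitting a symbol
    from row B[z] and moving to the next state using row A[z].
    States and symbols are 0-based here, unlike the 1..K convention in the text.
    """
    rng = random.Random(seed)
    states = range(len(pi))
    symbols = range(len(B[0]))
    z = rng.choices(states, weights=pi)[0]
    z_hist, x_hist = [], []
    for _ in range(seq_len):
        z_hist.append(z)
        x_hist.append(rng.choices(symbols, weights=B[z])[0])
        z = rng.choices(states, weights=A[z])[0]
    return z_hist, x_hist


# Made-up two-state, six-symbol model (a fair die and a die biased toward six).
pi = [0.5, 0.5]
A = [[0.95, 0.05], [0.10, 0.90]]
B = [[1 / 6] * 6, [0.1] * 5 + [0.5]]

assert is_row_stochastic(A) and is_row_stochastic(B)
z_hist, x_hist = sample_hmm(pi, A, B, seq_len=20, seed=314)
print("".join(str(z + 1) for z in z_hist))
print("".join(str(x + 1) for x in x_hist))
```

Checking that `A` and `B` are row stochastic before sampling catches the common mistake of storing the transposed matrix.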
@@ -412,7 +209,7 @@
     "\n",
     "seed = 314\n",
     "n_samples = 300\n",
-    "z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)\n",
+    "z_hist, x_hist = hmm.sample(seed=jr.PRNGKey(seed), seq_len=n_samples)\n",
     "\n",
     "z_hist_str = \"\".join((np.array(z_hist) + 1).astype(str))[:60]\n",
     "x_hist_str = \"\".join((np.array(x_hist) + 1).astype(str))[:60]\n",

+ 30 - 244
chapters/ssm/inference.ipynb

@@ -2,52 +2,6 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
    "outputs": [],
@@ -80,178 +34,8 @@
     "from functools import partial\n",
     "from jax.random import PRNGKey, split\n",
     "\n",
-    "import inspect\n",
-    "import inspect as py_inspect\n",
-    "import rich\n",
-    "from rich import inspect as r_inspect\n",
-    "from rich import print as r_print\n",
-    "\n",
-    "def print_source(fname):\n",
-    "    r_print(py_inspect.getsource(fname))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
+    "import jsl\n",
+    "import ssm_jax\n"
    ]
   },
   {
@@ -259,19 +43,9 @@
    "metadata": {},
    "source": [
     "(sec:inference)=\n",
-    "# Inferential goals\n",
+    "# States estimation (inference)\n",
     "\n",
-    "```{figure} /figures/inference-problems-tikz.png\n",
-    ":scale: 30%\n",
-    ":name: fig:dbn-inference\n",
     "\n",
-    "Illustration of the different kinds of inference in an SSM.\n",
-    " The main kinds of inference for state-space models.\n",
-    "    The shaded region is the interval for which we have data.\n",
-    "    The arrow represents the time step at which we want to perform inference.\n",
-    "    $t$ is the current time,  $T$ is the sequence length,\n",
-    "$\\ell$ is the lag and $h$ is the prediction horizon.\n",
-    "```\n",
     "\n",
     "\n",
     "\n",
@@ -284,41 +58,53 @@
     "there are multiple forms of posterior we may be interested in computing,\n",
     "including the following:\n",
     "- the filtering distribution\n",
-    "$p(\\hmmhid_t|\\hmmobs_{1:t})$\n",
+    "$p(\\hidden_t|\\obs_{1:t})$\n",
     "- the smoothing distribution\n",
-    "$p(\\hmmhid_t|\\hmmobs_{1:T})$ (note that this conditions on future data $T>t$)\n",
+    "$p(\\hidden_t|\\obs_{1:T})$ (note that this conditions on future data $T>t$)\n",
     "- the fixed-lag smoothing distribution\n",
-    "$p(\\hmmhid_{t-\\ell}|\\hmmobs_{1:t})$ (note that this\n",
+    "$p(\\hidden_{t-\\ell}|\\obs_{1:t})$ (note that this\n",
     "infers $\\ell$ steps in the past given data up to the present).\n",
     "\n",
     "We may also want to compute the\n",
     "predictive distribution $h$ steps into the future:\n",
-    "```{math}\n",
-    "p(\\hmmobs_{t+h}|\\hmmobs_{1:t})\n",
-    "= \\sum_{\\hmmhid_{t+h}} p(\\hmmobs_{t+h}|\\hmmhid_{t+h}) p(\\hmmhid_{t+h}|\\hmmobs_{1:t})\n",
-    "```\n",
+    "\\begin{align}\n",
+    "p(\\obs_{t+h}|\\obs_{1:t})\n",
+    "= \\sum_{\\hidden_{t+h}} p(\\obs_{t+h}|\\hidden_{t+h}) p(\\hidden_{t+h}|\\obs_{1:t})\n",
+    "\\end{align}\n",
     "where the hidden state predictive distribution is\n",
     "\\begin{align}\n",
-    "p(\\hmmhid_{t+h}|\\hmmobs_{1:t})\n",
-    "&= \\sum_{\\hmmhid_{t:t+h-1}}\n",
-    " p(\\hmmhid_t|\\hmmobs_{1:t}) \n",
-    " p(\\hmmhid_{t+1}|\\hmmhid_{t})\n",
-    " p(\\hmmhid_{t+2}|\\hmmhid_{t+1})\n",
+    "p(\\hidden_{t+h}|\\obs_{1:t})\n",
+    "&= \\sum_{\\hidden_{t:t+h-1}}\n",
+    " p(\\hidden_t|\\obs_{1:t}) \n",
+    " p(\\hidden_{t+1}|\\hidden_{t})\n",
+    " p(\\hidden_{t+2}|\\hidden_{t+1})\n",
     "\\cdots\n",
-    " p(\\hmmhid_{t+h}|\\hmmhid_{t+h-1})\n",
+    " p(\\hidden_{t+h}|\\hidden_{t+h-1})\n",
     "\\end{align}\n",
     "See \n",
     "{numref}`fig:dbn-inference` for a summary of these distributions.\n",
     "\n",
+    "```{figure} /figures/inference-problems-tikz.png\n",
+    ":scale: 30%\n",
+    ":name: fig:dbn-inference\n",
+    "\n",
+    "Illustration of the different kinds of inference in an SSM.\n",
+    " The main kinds of inference for state-space models.\n",
+    "    The shaded region is the interval for which we have data.\n",
+    "    The arrow represents the time step at which we want to perform inference.\n",
+    "    $t$ is the current time,  $T$ is the sequence length,\n",
+    "$\\ell$ is the lag and $h$ is the prediction horizon.\n",
+    "```\n",
+    "\n",
     "In addition to computing posterior marginals,\n",
     "we may want to compute the most probable hidden sequence,\n",
     "i.e., the joint MAP estimate\n",
     "```{math}\n",
-    "\\arg \\max_{\\hmmhid_{1:T}} p(\\hmmhid_{1:T}|\\hmmobs_{1:T})\n",
+    "\\arg \\max_{\\hidden_{1:T}} p(\\hidden_{1:T}|\\obs_{1:T})\n",
     "```\n",
     "or sample sequences from the posterior\n",
     "```{math}\n",
-    "\\hmmhid_{1:T} \\sim p(\\hmmhid_{1:T}|\\hmmobs_{1:T})\n",
+    "\\hidden_{1:T} \\sim p(\\hidden_{1:T}|\\obs_{1:T})\n",
     "```\n",
     "\n",
     "Algorithms for all these tasks are discussed in the following chapters,\n",

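Note on the predictive distributions in this hunk: for a discrete-state model, the sum over intermediate hidden states collapses to $h$ repeated one-step predictions, followed by one multiplication with the observation matrix. The sketch below illustrates this in plain Python under the assumption of a time-invariant transition matrix `A` and a discrete observation matrix `B`; the function names and all numbers are hypothetical and do not come from the notebook.

```python
def predict_hidden(filtered, A, h):
    """Push the filtering distribution p(z_t | y_{1:t}) forward h steps.

    Each step computes new[j] = sum_i old[i] * A[i][j], i.e. A.T @ old.
    """
    K = len(filtered)
    probs = list(filtered)
    for _ in range(h):
        probs = [sum(probs[i] * A[i][j] for i in range(K)) for j in range(K)]
    return probs


def predict_obs(filtered, A, B, h):
    """p(y_{t+h} = k | y_{1:t}) = sum_j B[j][k] * p(z_{t+h} = j | y_{1:t})."""
    hidden = predict_hidden(filtered, A, h)
    num_symbols = len(B[0])
    return [
        sum(hidden[j] * B[j][k] for j in range(len(hidden)))
        for k in range(num_symbols)
    ]


# Made-up two-state, six-symbol model (a fair die and a die biased toward six).
A = [[0.95, 0.05], [0.10, 0.90]]
B = [[1 / 6] * 6, [0.1] * 5 + [0.5]]
filtered = [1.0, 0.0]  # assume we are certain the current state is the first one

for h in (1, 2, 10):
    print(h, predict_hidden(filtered, A, h), predict_obs(filtered, A, B, h)[5])
```

As $h$ grows, the hidden-state prediction drifts toward the stationary distribution of `A`, so long-horizon forecasts depend less and less on the current filtered belief.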
+ 55 - 301
chapters/ssm/lds.ipynb

@@ -2,125 +2,35 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Requirement already satisfied: distrax in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (0.0.1)\n",
-      "Collecting distrax\n",
-      "  Downloading distrax-0.1.2-py3-none-any.whl (272 kB)\n",
-      "\u001b[K     |████████████████████████████████| 272 kB 6.9 MB/s eta 0:00:01\n",
-      "\u001b[?25hRequirement already satisfied: jax>=0.1.55 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.2.11)\n",
-      "Requirement already satisfied: absl-py>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.12.0)\n",
-      "Requirement already satisfied: chex>=0.0.7 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.0.8)\n",
-      "Requirement already satisfied: jaxlib>=0.1.67 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.1.70)\n",
-      "Requirement already satisfied: numpy>=1.18.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (1.19.5)\n",
-      "Collecting tensorflow-probability>=0.15.0\n",
-      "  Using cached tensorflow_probability-0.16.0-py2.py3-none-any.whl (6.3 MB)\n",
-      "Requirement already satisfied: six in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from absl-py>=0.9.0->distrax) (1.15.0)\n",
-      "Requirement already satisfied: dm-tree>=0.1.5 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.1.6)\n",
-      "Requirement already satisfied: toolz>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.11.1)\n",
-      "Requirement already satisfied: opt-einsum in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jax>=0.1.55->distrax) (3.3.0)\n",
-      "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.12)\n",
-      "Requirement already satisfied: scipy in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.6.3)\n",
-      "Requirement already satisfied: cloudpickle>=1.3 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (1.6.0)\n",
-      "Requirement already satisfied: decorator in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (4.4.2)\n",
-      "Requirement already satisfied: gast>=0.3.2 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (0.4.0)\n",
-      "Installing collected packages: tensorflow-probability, distrax\n",
-      "  Attempting uninstall: tensorflow-probability\n",
-      "    Found existing installation: tensorflow-probability 0.13.0\n",
-      "    Uninstalling tensorflow-probability-0.13.0:\n",
-      "      Successfully uninstalled tensorflow-probability-0.13.0\n",
-      "  Attempting uninstall: distrax\n",
-      "    Found existing installation: distrax 0.0.1\n",
-      "    Uninstalling distrax-0.0.1:\n",
-      "      Successfully uninstalled distrax-0.0.1\n",
-      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
-      "jsl 0.0.0 requires dataclasses, which is not installed.\u001b[0m\n",
-      "Successfully installed distrax-0.1.2 tensorflow-probability-0.16.0\n",
-      "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 22.0.4 is available.\n",
-      "You should consider upgrading via the '/opt/anaconda3/envs/spyder-dev/bin/python -m pip install --upgrade pip' command.\u001b[0m\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
     "### Import standard libraries\n",
     "\n",
     "import abc\n",
     "from dataclasses import dataclass\n",
     "import functools\n",
+    "from functools import partial\n",
     "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
     "\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
+    "\n",
+    "\n",
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax\n",
     "\n",
     "import inspect\n",
     "import inspect as py_inspect\n",
@@ -136,170 +46,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "(sec:lds-intro)=\n",
     "# Linear Gaussian SSMs\n",
     "\n",
@@ -310,50 +56,50 @@
     "hidden states and inputs (i.e. there are no auto-regressive dependencies\n",
     "between the observables).\n",
     "We can rewrite this model as \n",
-    "a stochastic nonlinear dynamical system (NLDS)\n",
+    "a stochastic $\\keyword{nonlinear dynamical system}$ or $\\keyword{NLDS}$\n",
     "by defining the distribution of the next hidden state \n",
     "as a deterministic function of the past state\n",
-    "plus random process noise $\\vepsilon_t$ \n",
+    "plus random $\\keyword{process noise}$ $\\transNoise_t$ \n",
     "\\begin{align}\n",
-    "\\hmmhid_t &= \\ssmDynFn(\\hmmhid_{t-1}, \\inputs_t, \\vepsilon_t)  \n",
+    "\\hidden_t &= \\dynamicsFn(\\hidden_{t-1}, \\inputs_t, \\transNoise_t)  \n",
     "\\end{align}\n",
-    "where $\\vepsilon_t$ is drawn from the distribution such\n",
+    "where $\\transNoise_t$ is drawn from the distribution such\n",
     "that the induced distribution\n",
-    "on $\\hmmhid_t$ matches $p(\\hmmhid_t|\\hmmhid_{t-1}, \\inputs_t)$.\n",
+    "on $\\hidden_t$ matches $p(\\hidden_t|\\hidden_{t-1}, \\inputs_t)$.\n",
     "Similarly we can rewrite the observation distributions\n",
     "as a deterministic function of the hidden state\n",
-    "plus observation noise $\\veta_t$:\n",
+    "plus $\\keyword{observation noise}$ $\\obsNoise_t$:\n",
     "\\begin{align}\n",
-    "\\hmmobs_t &= \\ssmObsFn(\\hmmhid_{t}, \\inputs_t, \\veta_t)\n",
+    "\\obs_t &= \\measurementFn(\\hidden_{t}, \\inputs_t, \\obsNoise_t)\n",
     "\\end{align}\n",
     "\n",
     "\n",
     "If we assume additive Gaussian noise,\n",
     "the model becomes\n",
     "\\begin{align}\n",
-    "\\hmmhid_t &= \\ssmDynFn(\\hmmhid_{t-1}, \\inputs_t) +  \\vepsilon_t  \\\\\n",
-    "\\hmmobs_t &= \\ssmObsFn(\\hmmhid_{t}, \\inputs_t) + \\veta_t\n",
+    "\\hidden_t &= \\dynamicsFn(\\hidden_{t-1}, \\inputs_t) +  \\transNoise_t  \\\\\n",
+    "\\obs_t &= \\measurementFn(\\hidden_{t}, \\inputs_t) + \\obsNoise_t\n",
     "\\end{align}\n",
-    "where $\\vepsilon_t \\sim \\gauss(\\vzero,\\vQ_t)$\n",
-    "and $\\veta_t \\sim \\gauss(\\vzero,\\vR_t)$.\n",
-    "We will call these Gaussian SSMs.\n",
+    "where $\\transNoise_t \\sim \\gauss(\\vzero,\\transCov_t)$\n",
+    "and $\\obsNoise_t \\sim \\gauss(\\vzero,\\obsCov_t)$.\n",
+    "We will call these $\\keyword{Gaussian SSMs}$.\n",
     "\n",
     "If we additionally assume\n",
-    "the transition function $\\ssmDynFn$\n",
-    "and the observation function $\\ssmObsFn$ are both linear,\n",
+    "the transition function $\\dynamicsFn$\n",
+    "and the observation function $\\measurementFn$ are both linear,\n",
     "then we can rewrite the model as follows:\n",
     "\\begin{align}\n",
-    "p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) &= \\gauss(\\hmmhid_t|\\ldsDyn_t \\hmmhid_{t-1}\n",
-    "+ \\ldsDynIn_t \\inputs_t, \\vQ_t)\n",
+    "p(\\hidden_t|\\hidden_{t-1},\\inputs_t) &= \\gauss(\\hidden_t|\\ldsDyn_t \\hidden_{t-1}\n",
+    "+ \\ldsDynIn_t \\inputs_t, \\transCov_t)\n",
     "\\\\\n",
-    "p(\\hmmobs_t|\\hmmhid_t,\\inputs_t) &= \\gauss(\\hmmobs_t|\\ldsObs_t \\hmmhid_{t}\n",
-    "+ \\ldsObsIn_t \\inputs_t, \\vR_t)\n",
+    "p(\\obs_t|\\hidden_t,\\inputs_t) &= \\gauss(\\obs_t|\\ldsObs_t \\hidden_{t}\n",
+    "+ \\ldsObsIn_t \\inputs_t, \\obsCov_t)\n",
     "\\end{align}\n",
     "This is called a \n",
-    "linear-Gaussian state space model\n",
-    "(LG-SSM),\n",
-    "or a\n",
-    "linear dynamical system (LDS).\n",
+    "$\\keyword{linear-Gaussian state space model}$\n",
+    "or $\\keyword{LG-SSM}$;\n",
+    "it is also called \n",
+    "a $\\keyword{linear dynamical system}$ or $\\keyword{LDS}$.\n",
     "We usually assume the parameters are independent of time, in which case\n",
     "the model is said to be time-invariant or homogeneous.\n"
    ]
@@ -372,7 +118,7 @@
     "Consider an object moving in $\\real^2$.\n",
     "Let the state be\n",
     "the position and velocity of the object,\n",
-    "$\\vz_t =\\begin{pmatrix} u_t & \\dot{u}_t & v_t & \\dot{v}_t \\end{pmatrix}$.\n",
+    "$\\hidden_t =\\begin{pmatrix} u_t & \\dot{u}_t & v_t & \\dot{v}_t \\end{pmatrix}$.\n",
     "(We use $u$ and $v$ for the two coordinates,\n",
     "to avoid confusion with the state and observation variables.)\n",
     "If we use Euler discretization,\n",
@@ -390,9 +136,9 @@
     "}_{\\ldsDyn}\n",
     "\\\n",
     "\\underbrace{\\begin{pmatrix} u_{t-1} \\\\ \\dot{u}_{t-1} \\\\ v_{t-1} \\\\ \\dot{v}_{t-1} \\end{pmatrix}}_{\\vz_{t-1}}\n",
-    "+ \\vepsilon_t\n",
+    "+ \\transNoise_t\n",
     "\\end{align}\n",
-    "where $\\vepsilon_t \\sim \\gauss(\\vzero,\\vQ)$ is\n",
+    "where $\\transNoise_t \\sim \\gauss(\\vzero,\\transCov)$ is\n",
     "the process noise.\n",
     "\n",
     "Let us assume\n",
@@ -401,7 +147,7 @@
     "of the state, but not to the location.\n",
     "(This is known as a random accelerations model.)\n",
     "We can approximate the resulting process in discrete time by assuming\n",
-    "$\\vQ = \\diag(0, q, 0, q)$.\n",
+    "$\\transCov = \\diag(0, q, 0, q)$.\n",
     "(See  {cite}`Sarkka13` p60 for a more accurate way\n",
     "to convert the continuous time process to discrete time.)\n",
     "\n",
@@ -411,7 +157,7 @@
     "corrupted by  Gaussian noise.\n",
     "Thus the observation model becomes\n",
     "\\begin{align}\n",
-    "\\underbrace{\\begin{pmatrix}  y_{1,t} \\\\  y_{2,t} \\end{pmatrix}}_{\\vy_t}\n",
+    "\\underbrace{\\begin{pmatrix}  \\obs_{1,t} \\\\  \\obs_{2,t} \\end{pmatrix}}_{\\obs_t}\n",
     "  &=\n",
     "    \\underbrace{\n",
     "    \\begin{pmatrix}\n",
@@ -420,18 +166,18 @@
     "    \\end{pmatrix}\n",
     "    }_{\\ldsObs}\n",
     "    \\\n",
-    "\\underbrace{\\begin{pmatrix} u_t\\\\ \\dot{u}_t \\\\ v_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\vz_t}    \n",
-    " + \\veta_t\n",
+    "\\underbrace{\\begin{pmatrix} u_t\\\\ \\dot{u}_t \\\\ v_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\hidden_t}    \n",
+    " + \\obsNoise_t\n",
     "\\end{align}\n",
-    "where $\\veta_t \\sim \\gauss(\\vzero,\\vR)$ is the \\keywordDef{observation noise}.\n",
+    "where $\\obsNoise_t \\sim \\gauss(\\vzero,\\obsCov)$ is the $\\keyword{observation noise}$.\n",
     "We see that the observation matrix $\\ldsObs$ simply ``extracts'' the\n",
     "relevant parts  of the state vector.\n",
     "\n",
     "Suppose we sample a trajectory and corresponding set\n",
     "of noisy observations from this model,\n",
-    "$(\\vz_{1:T}, \\vy_{1:T}) \\sim p(\\vz,\\vy|\\vtheta)$.\n",
+    "$(\\hidden_{1:T}, \\obs_{1:T}) \\sim p(\\hidden,\\obs|\\params)$.\n",
     "(We use diagonal observation noise,\n",
-    "$\\vR = \\diag(\\sigma_1^2, \\sigma_2^2)$.)\n",
+    "$\\obsCov = \\diag(\\sigma_1^2, \\sigma_2^2)$.)\n",
     "The results are shown below. \n"
    ]
   },
@@ -561,7 +307,15 @@
    "metadata": {},
    "source": [
     "The main task is to infer the hidden states given the noisy\n",
-    "observations, i.e., $p(\\vz|\\vy,\\vtheta)$. We discuss the topic of inference in {ref}`sec:inference`."
+    "observations, i.e., $p(\\hidden_t|\\obs_{1:t},\\params)$\n",
+    "or $p(\\hidden_t|\\obs_{1:T}, \\params)$ in the offline case.\n",
+    "We discuss the topic of inference in {ref}`sec:inference`.\n",
+    "We will usually represent this belief state by a Gaussian distribution,\n",
+    "$p(\\hidden_t|\\obs_{1:t},\\params) = \\gauss(\\hidden_t|\\mean_{t|t}, \\covMat_{t|t})$\n",
+    "and\n",
+    "$p(\\hidden_t|\\obs_{1:T},\\params) = \\gauss(\\hidden_t|\\mean_{t|T}, \\covMat_{t|T})$.\n",
+    "Sometimes we use information form, \n",
+    "$p(\\hidden_t|\\obs_{1:t},\\params) = \\gaussInfo(\\hidden_t|\\covMat_{t|t}^{-1}\\mean_{t|t}, \\covMat_{t|t}^{-1})$."
    ]
   }
  ],
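
Reviewer note: the tracking example edited in this file can be simulated in a few lines, which may help when checking the renamed symbols. The sketch below is plain Python rather than JAX so that it stays self-contained. It assumes the usual Euler-discretized dynamics matrix with step size `dt` and an observation matrix that picks out the two position coordinates; neither matrix is visible in this diff, so both are assumptions. The function name `sample_tracking_lgssm`, the initial state, and the default noise levels are also hypothetical.

```python
import random


def sample_tracking_lgssm(num_steps, dt=1.0, q=0.01, sigma1=0.5, sigma2=0.5, seed=0):
    """Sample (z_{1:T}, y_{1:T}) from a 2d random-accelerations tracking model.

    Hypothetical helper for illustration; not part of jsl or ssm_jax.
    """
    rng = random.Random(seed)
    # State z = (u, u_dot, v, v_dot). The initial state is an arbitrary choice.
    u, u_dot, v, v_dot = 0.0, 1.0, 0.0, 0.5
    states, observations = [], []
    for _ in range(num_steps):
        # Dynamics z_t = A z_{t-1} + eps_t with the assumed Euler matrix A:
        # each position advances by dt times the previous velocity.
        u, v = u + dt * u_dot, v + dt * v_dot
        # Q = diag(0, q, 0, q): process noise enters the velocities only.
        u_dot += rng.gauss(0.0, q ** 0.5)
        v_dot += rng.gauss(0.0, q ** 0.5)
        # Observation y_t = C z_t + eta_t, where C extracts (u, v) and
        # R = diag(sigma1^2, sigma2^2).
        y = (u + rng.gauss(0.0, sigma1), v + rng.gauss(0.0, sigma2))
        states.append((u, u_dot, v, v_dot))
        observations.append(y)
    return states, observations


states, observations = sample_tracking_lgssm(num_steps=20)
```

Setting `q`, `sigma1`, and `sigma2` to zero gives straight-line motion observed without error, which is a quick sanity check on the dynamics.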

+ 116 - 0
chapters/ssm/learning.ipynb

@@ -0,0 +1,116 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Import standard libraries\n",
+    "\n",
+    "import abc\n",
+    "from dataclasses import dataclass\n",
+    "import functools\n",
+    "from functools import partial\n",
+    "import itertools\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
+    "\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "from jax import lax, vmap, jit, grad\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
+    "\n",
+    "\n",
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "(sec:learning)=\n",
+    "# Parameter estimation (learning)\n",
+    "\n",
+    "\n",
+    "So far, we have assumed that the parameters $\\params$ of the SSM are known.\n",
+    "For example, in the case of an HMM with categorical observations\n",
+    "we have $\\params = (\\hmmInit, \\hmmTrans, \\hmmObs)$,\n",
+    "and in the case of an LDS, we have $\\params = (\\ldsTrans, \\ldsObs, \\ldsTransNoise, \\ldsObsNoise)$.\n",
+    "(We also need to specify the initial state distribution, $(\\ldsInitMean, \\ldsInitCov)$.)\n",
+    "If we adopt a Bayesian perspective, we can view these parameters as random variables that are\n",
+    "shared across all time steps, and across all sequences.\n",
+    "This is shown in {numref}`fig:hmm-plates`, where we adopt $\\keyword{plate notation}$\n",
+    "to represent repetitive structure.\n",
+    "\n",
+    "```{figure} /figures/hmmDgmPlatesY.png\n",
+    ":scale: 100%\n",
+    ":name: fig:hmm-plates\n",
+    "\n",
+    "Illustration of an HMM using plate notation, where we show the parameter\n",
+    "nodes which are shared across all the sequences.\n",
+    "```\n",
+    "\n",
+    "Suppose we observe $N$ sequences $\\data = \\{\\obs_{n,1:T_n}: n=1:N\\}$.\n",
+    "Then the goal of $\\keyword{parameter estimation}$, also called $\\keyword{model learning}$\n",
+    "or $\\keyword{model fitting}$, is to approximate the posterior\n",
+    "\\begin{align}\n",
+    "p(\\params|\\data) \\propto p(\\params) \\prod_{n=1}^N p(\\obs_{n,1:T_n} | \\params)\n",
+    "\\end{align}\n",
+    "where $p(\\obs_{n,1:T_n} | \\params)$ is the marginal likelihood of sequence $n$:\n",
+    "\\begin{align}\n",
+    "p(\\obs_{n,1:T_n} | \\params) = \\int  p(\\hidden_{n,1:T_n}, \\obs_{n,1:T_n} | \\params) d\\hidden_{n,1:T_n}\n",
+    "\\end{align}\n",
+    "\n",
+    "Since computing the full posterior is computationally difficult, we often settle for computing\n",
+    "a point estimate such as the MAP (maximum a posteriori) estimate\n",
+    "\\begin{align}\n",
+    "\\params_{\\map} = \\arg \\max_{\\params} \\log p(\\params) + \\sum_{n=1}^N \\log p(\\obs_{n,1:T_n} | \\params)\n",
+    "\\end{align}\n",
+    "If we ignore the prior term, we get the maximum likelihood estimate or MLE:\n",
+    "\\begin{align}\n",
+    "\\params_{\\mle} = \\arg \\max_{\\params}  \\sum_{n=1}^N \\log p(\\obs_{n,1:T_n} | \\params)\n",
+    "\\end{align}\n",
+    "In practice, the MAP estimate often works better than the MLE, since the prior regularizes\n",
+    "the estimate, which helps keep the model numerically stable and reduces overfitting to the training set.\n",
+    "\n",
+    "We will discuss a variety of algorithms for parameter estimation in later chapters.\n",
+    "\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "6407c60499271029b671b4ff687c4ed4626355c45fd34c44476827f4be42c4d7"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.9.2 ('spyder-dev')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
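
Reviewer note: to make the MAP and MLE objectives in `learning.ipynb` concrete, here is a minimal sketch for a scalar LG-SSM. It assumes $z_t = a z_{t-1} + \epsilon_t$ and $y_t = z_t + \eta_t$ with known noise variances `q` and `r`, and $z_1 \sim \mathcal{N}(m_0, p_0)$. Under these assumptions the marginal likelihood integral can be computed exactly with a Kalman filter. A grid search over `a` stands in for the estimation algorithms that the chapter defers to later sections. All function names are hypothetical and nothing here comes from `jsl` or `ssm_jax`.

```python
import math


def lgssm_log_marginal_likelihood(ys, a, q, r, m0=0.0, p0=1.0):
    """Return log p(y_{1:T} | params) for a scalar LG-SSM via the Kalman filter."""
    log_lik = 0.0
    mean, var = m0, p0  # predictive distribution of z_1
    for t, y in enumerate(ys):
        if t > 0:
            # Predict step: push the filtered belief through the dynamics.
            mean, var = a * mean, a * a * var + q
        # One-step-ahead predictive density of y_t is N(mean, var + r).
        s = var + r
        resid = y - mean
        log_lik += -0.5 * (math.log(2.0 * math.pi * s) + resid * resid / s)
        # Update step: condition the belief on y_t.
        gain = var / s
        mean, var = mean + gain * resid, (1.0 - gain) * var
    return log_lik


def map_objective(sequences, a, q, r, log_prior=None):
    """log p(a) + sum_n log p(y_n | a); reduces to the MLE objective without a prior."""
    total = sum(lgssm_log_marginal_likelihood(ys, a, q, r) for ys in sequences)
    if log_prior is not None:
        total += log_prior(a)
    return total


def grid_search(sequences, grid, q, r, log_prior=None):
    """Return the grid value of `a` that maximizes the (MAP or MLE) objective."""
    return max(grid, key=lambda a: map_objective(sequences, a, q, r, log_prior))


sequences = [[1.0, -0.5], [0.2, 0.1, 0.4]]  # toy data, two sequences
grid = [i / 10.0 for i in range(-9, 10)]
a_mle = grid_search(sequences, grid, q=1.0, r=1.0)
a_map = grid_search(sequences, grid, q=1.0, r=1.0,
                    log_prior=lambda a: -0.5 * (a / 0.5) ** 2)
```

The per-sequence log-likelihoods are summed, matching the product over $n$ in the posterior, and the prior enters only as one extra additive term.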

+ 0 - 164
chapters/ssm/nlds.ipynb

@@ -136,170 +136,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "(sec:nlds-intro)=\n",
     "# Nonlinear Gaussian SSMs\n",
     "\n",

+ 21 - 186
chapters/ssm/ssm_intro.ipynb

@@ -4,175 +4,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": []
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-    "source": [
     "(sec:ssm-intro)=\n",
     "# What are State Space Models?\n",
     "\n",
@@ -197,7 +28,7 @@
     "unlike a standard Markov model.\n",
     "\n",
     "```{figure} /figures/SSM-AR-inputs.png\n",
-    ":height: 300px\n",
+    ":height: 150px\n",
     ":name: fig:ssm-ar\n",
     "\n",
     "Illustration of an SSM as a graphical model.\n",
@@ -208,14 +39,14 @@
     "as the following joint distribution:\n",
     "```{math}\n",
     ":label: eq:SSM-ar\n",
-    "p(\\hmmobs_{1:T},\\hmmhid_{1:T}|\\inputs_{1:T})\n",
-    " = \\left[ p(\\hmmhid_1|\\inputs_1) \\prod_{t=2}^{T}\n",
-    " p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) \\right]\n",
-    " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t, \\inputs_t, \\hmmobs_{t-1}) \\right]\n",
+    "p(\\obs_{1:T},\\hidden_{1:T}|\\inputs_{1:T})\n",
+    " = \\left[ p(\\hidden_1|\\inputs_1) \\prod_{t=2}^{T}\n",
+    " p(\\hidden_t|\\hidden_{t-1},\\inputs_t) \\right]\n",
+    " \\left[ \\prod_{t=1}^T p(\\obs_t|\\hidden_t, \\inputs_t, \\obs_{t-1}) \\right]\n",
     "```\n",
-    "where $p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t)$ is the\n",
+    "where $p(\\hidden_t|\\hidden_{t-1},\\inputs_t)$ is the\n",
     "transition model,\n",
-    "$p(\\hmmobs_t|\\hmmhid_t, \\inputs_t, \\hmmobs_{t-1})$ is the\n",
+    "$p(\\obs_t|\\hidden_t, \\inputs_t, \\obs_{t-1})$ is the\n",
     "observation model,\n",
     "and $\\inputs_{t}$ is an optional input or action.\n",
     "See {numref}`fig:ssm-ar` \n",
@@ -228,31 +59,35 @@
     "In this case the joint simplifies to \n",
     "```{math}\n",
     ":label: eq:SSM-input\n",
-    "p(\\hmmobs_{1:T},\\hmmhid_{1:T}|\\inputs_{1:T})\n",
-    " = \\left[ p(\\hmmhid_1|\\inputs_1) \\prod_{t=2}^{T}\n",
-    " p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) \\right]\n",
-    " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t, \\inputs_t) \\right]\n",
+    "p(\\obs_{1:T},\\hidden_{1:T}|\\inputs_{1:T})\n",
+    " = \\left[ p(\\hidden_1|\\inputs_1) \\prod_{t=2}^{T}\n",
+    " p(\\hidden_t|\\hidden_{t-1},\\inputs_t) \\right]\n",
+    " \\left[ \\prod_{t=1}^T p(\\obs_t|\\hidden_t, \\inputs_t) \\right]\n",
     "```\n",
     "Sometimes there are no external inputs, so the model further\n",
     "simplifies to the following unconditional generative model: \n",
     "```{math}\n",
     ":label: eq:SSM-no-input\n",
-    "p(\\hmmobs_{1:T},\\hmmhid_{1:T})\n",
-    " = \\left[ p(\\hmmhid_1) \\prod_{t=2}^{T}\n",
-    " p(\\hmmhid_t|\\hmmhid_{t-1}) \\right]\n",
-    " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t) \\right]\n",
+    "p(\\obs_{1:T},\\hidden_{1:T})\n",
+    " = \\left[ p(\\hidden_1) \\prod_{t=2}^{T}\n",
+    " p(\\hidden_t|\\hidden_{t-1}) \\right]\n",
+    " \\left[ \\prod_{t=1}^T p(\\obs_t|\\hidden_t) \\right]\n",
     "```\n",
     "See {numref}`ssm-simplified` \n",
     "for an illustration of the corresponding graphical model.\n",
     "\n",
     "\n",
     "```{figure} /figures/SSM-simplified.png\n",
-    ":scale: 100%\n",
+    ":height: 150px\n",
     ":name: ssm-simplified\n",
     "\n",
     "Illustration of a simplified SSM.\n",
     "```\n",
-    "\n"
+    "\n",
+    "SSMs are widely used in many areas of science, engineering, finance, economics, etc.\n",
+    "The main applications are state estimation (i.e., inferring the underlying hidden state of the system given the observations),\n",
+    "forecasting (i.e., predicting future states and observations), and control (i.e., inferring the sequence of inputs that will\n",
+    "give rise to a desired target state). We will discuss these applications in later chapters."
    ]
   },
   {
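
Reviewer note on the `ssm_intro.ipynb` hunks above: the factorization in `eq:SSM-no-input` can be checked with a small sketch. The code below is not part of this commit. It assumes a scalar linear-Gaussian model, and the names `sample_ssm`, `log_joint`, and the parameters `a`, `c`, `q`, `r` are hypothetical, chosen only for illustration. It samples a trajectory ancestrally and then evaluates the log joint as the sum of the initial, transition, and observation terms.

```python
import math
import random


def gauss_logpdf(x, mean, var):
    """Log density of a scalar Gaussian N(x | mean, var)."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def sample_ssm(T, a=0.9, c=1.0, q=0.1, r=0.5, seed=0):
    """Ancestral sampling from an assumed scalar model:
    z_1 ~ N(0, 1), z_t ~ N(a * z_{t-1}, q), y_t ~ N(c * z_t, r)."""
    rng = random.Random(seed)
    zs, ys = [], []
    for t in range(T):
        mean = 0.0 if t == 0 else a * zs[-1]
        var = 1.0 if t == 0 else q
        z = rng.gauss(mean, math.sqrt(var))  # hidden state
        zs.append(z)
        ys.append(rng.gauss(c * z, math.sqrt(r)))  # observation
    return zs, ys


def log_joint(zs, ys, a=0.9, c=1.0, q=0.1, r=0.5):
    """log p(y_{1:T}, z_{1:T}) following the no-input factorization."""
    lp = gauss_logpdf(zs[0], 0.0, 1.0)  # initial term p(z_1)
    for t in range(1, len(zs)):  # transition terms p(z_t | z_{t-1})
        lp += gauss_logpdf(zs[t], a * zs[t - 1], q)
    for z, y in zip(zs, ys):  # observation terms p(y_t | z_t)
        lp += gauss_logpdf(y, c * z, r)
    return lp


zs, ys = sample_ssm(T=5)
print(log_joint(zs, ys))
```

Adding inputs as in `eq:SSM-input` would only change the means of the transition and observation terms; the three-part structure of the sum stays the same.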

+ 12 - 2
references.bib

@@ -1,6 +1,4 @@
 
 
----
----
 
 
 %@string{aij = "Artificial Intelligence J."}
 @string{aij = "AIJ"}
@@ -554,3 +552,15 @@ publisher = "Cambridge University Press",
 year = 2005,
 publisher = "Springer"
 }
+
+
+@article{Devijver85,
+  title={Baum's forward-backward algorithm revisited},
+  author={Devijver, Pierre A},
+  journal={Pattern Recognition Letters},
+  volume={3},
+  number={6},
+  pages={369--373},
+  year={1985},
+  publisher={Elsevier}
+}