|
@@ -0,0 +1,612 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "(sec:hmm-ex)=\n",
|
|
|
+ "# Hidden Markov Models\n",
|
|
|
+ "\n",
|
|
|
+ "In this section, we introduce Hidden Markov Models (HMMs)."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Boilerplate"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 1,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Collecting jax[cpu]\n",
|
|
|
+ " Downloading jax-0.3.5.tar.gz (946 kB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 946 kB 2.7 MB/s eta 0:00:01\n",
|
|
|
+ "\u001b[?25hCollecting absl-py\n",
|
|
|
+ " Downloading absl_py-1.0.0-py3-none-any.whl (126 kB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 126 kB 47.7 MB/s eta 0:00:01\n",
|
|
|
+ "\u001b[?25hCollecting numpy>=1.19\n",
|
|
|
+ " Downloading numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl (17.6 MB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 17.6 MB 47.5 MB/s eta 0:00:01\n",
|
|
|
+ "\u001b[?25hCollecting opt_einsum\n",
|
|
|
+ " Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)\n",
|
|
|
+ "Collecting scipy>=1.2.1\n",
|
|
|
+ " Downloading scipy-1.8.0-cp38-cp38-macosx_12_0_universal2.macosx_10_9_x86_64.whl (55.3 MB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 55.3 MB 73.1 MB/s eta 0:00:01\n",
|
|
|
+ "\u001b[?25hCollecting typing_extensions\n",
|
|
|
+ " Using cached typing_extensions-4.1.1-py3-none-any.whl (26 kB)\n",
|
|
|
+ "Collecting jaxlib==0.3.5\n",
|
|
|
+ " Downloading jaxlib-0.3.5-cp38-none-macosx_10_9_x86_64.whl (70.5 MB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 70.5 MB 723 kB/s eta 0:00:01\n",
|
|
|
+ "\u001b[?25hCollecting flatbuffers<3.0,>=1.12\n",
|
|
|
+ " Using cached flatbuffers-2.0-py2.py3-none-any.whl (26 kB)\n",
|
|
|
+ "Requirement already satisfied: six in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from absl-py->jax[cpu]) (1.16.0)\n",
|
|
|
+ "Building wheels for collected packages: jax\n",
|
|
|
+ " Building wheel for jax (setup.py) ... \u001b[?25ldone\n",
|
|
|
+ "\u001b[?25h Created wheel for jax: filename=jax-0.3.5-py3-none-any.whl size=1095861 sha256=6886baa70817bbac3b5797b3720dcb81e49097f61cee7d2e1255823ea32ccad8\n",
|
|
|
+ " Stored in directory: /Users/kpmurphy/Library/Caches/pip/wheels/05/30/aa/908988293721511b4b29e0aadf9b5d133d0f14f6c0a188e764\n",
|
|
|
+ "Successfully built jax\n",
|
|
|
+ "Installing collected packages: numpy, typing-extensions, scipy, opt-einsum, flatbuffers, absl-py, jaxlib, jax\n",
|
|
|
+ "Successfully installed absl-py-1.0.0 flatbuffers-2.0 jax-0.3.5 jaxlib-0.3.5 numpy-1.22.3 opt-einsum-3.3.0 scipy-1.8.0 typing-extensions-4.1.1\n",
|
|
|
+ "Note: you may need to restart the kernel to use updated packages.\n",
|
|
|
+ "Collecting git+https://github.com/probml/jsl\n",
|
|
|
+ " Cloning https://github.com/probml/jsl to /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-req-build-i8seqdiw\n",
|
|
|
+ " Running command git clone -q https://github.com/probml/jsl /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-req-build-i8seqdiw\n",
|
|
|
+ "Collecting chex\n",
|
|
|
+ " Downloading chex-0.1.2-py3-none-any.whl (72 kB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 72 kB 1.3 MB/s eta 0:00:011\n",
|
|
|
+ "\u001b[?25hCollecting dataclasses\n",
|
|
|
+ " Using cached dataclasses-0.6-py3-none-any.whl (14 kB)\n",
|
|
|
+ "Requirement already satisfied: jaxlib in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jsl==0.0.0) (0.3.5)\n",
|
|
|
+ "Requirement already satisfied: jax in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jsl==0.0.0) (0.3.5)\n",
|
|
|
+ "Collecting matplotlib\n",
|
|
|
+ " Downloading matplotlib-3.5.1-cp38-cp38-macosx_10_9_x86_64.whl (7.3 MB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 7.3 MB 3.9 MB/s eta 0:00:01\n",
|
|
|
+ "\u001b[?25hCollecting tensorflow_probability\n",
|
|
|
+ " Using cached tensorflow_probability-0.16.0-py2.py3-none-any.whl (6.3 MB)\n",
|
|
|
+ "Collecting dm-tree>=0.1.5\n",
|
|
|
+ " Using cached dm_tree-0.1.6-cp38-cp38-macosx_10_14_x86_64.whl (95 kB)\n",
|
|
|
+ "Requirement already satisfied: absl-py>=0.9.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from chex->jsl==0.0.0) (1.0.0)\n",
|
|
|
+ "Requirement already satisfied: numpy>=1.18.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from chex->jsl==0.0.0) (1.22.3)\n",
|
|
|
+ "Collecting toolz>=0.9.0\n",
|
|
|
+ " Downloading toolz-0.11.2-py3-none-any.whl (55 kB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 55 kB 11.3 MB/s eta 0:00:01\n",
|
|
|
+ "\u001b[?25hRequirement already satisfied: six in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from absl-py>=0.9.0->chex->jsl==0.0.0) (1.16.0)\n",
|
|
|
+ "Requirement already satisfied: typing-extensions in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (4.1.1)\n",
|
|
|
+ "Requirement already satisfied: scipy>=1.2.1 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (1.8.0)\n",
|
|
|
+ "Requirement already satisfied: opt-einsum in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (3.3.0)\n",
|
|
|
+ "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jaxlib->jsl==0.0.0) (2.0)\n",
|
|
|
+ "Collecting cycler>=0.10\n",
|
|
|
+ " Downloading cycler-0.11.0-py3-none-any.whl (6.4 kB)\n",
|
|
|
+ "Collecting kiwisolver>=1.0.1\n",
|
|
|
+ " Downloading kiwisolver-1.4.2-cp38-cp38-macosx_10_9_x86_64.whl (65 kB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 65 kB 8.5 MB/s eta 0:00:01\n",
|
|
|
+ "\u001b[?25hCollecting fonttools>=4.22.0\n",
|
|
|
+ " Downloading fonttools-4.32.0-py3-none-any.whl (900 kB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 900 kB 35.8 MB/s eta 0:00:01\n",
|
|
|
+ "\u001b[?25hRequirement already satisfied: pyparsing>=2.2.1 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (3.0.7)\n",
|
|
|
+ "Requirement already satisfied: packaging>=20.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (21.3)\n",
|
|
|
+ "Requirement already satisfied: python-dateutil>=2.7 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (2.8.2)\n",
|
|
|
+ "Collecting pillow>=6.2.0\n",
|
|
|
+ " Downloading Pillow-9.1.0-cp38-cp38-macosx_10_9_x86_64.whl (3.1 MB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 3.1 MB 76.6 MB/s eta 0:00:01\n",
|
|
|
+ "\u001b[?25hCollecting gast>=0.3.2\n",
|
|
|
+ " Downloading gast-0.5.3-py3-none-any.whl (19 kB)\n",
|
|
|
+ "Requirement already satisfied: decorator in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from tensorflow_probability->jsl==0.0.0) (5.1.1)\n",
|
|
|
+ "Collecting cloudpickle>=1.3\n",
|
|
|
+ " Downloading cloudpickle-2.0.0-py3-none-any.whl (25 kB)\n",
|
|
|
+ "Building wheels for collected packages: jsl\n",
|
|
|
+ " Building wheel for jsl (setup.py) ... \u001b[?25ldone\n",
|
|
|
+ "\u001b[?25h Created wheel for jsl: filename=jsl-0.0.0-py3-none-any.whl size=77852 sha256=e7365293dc97e2b3e72bf42cc19db7d7e355abec312fc4d87961fa2044fa06f0\n",
|
|
|
+ " Stored in directory: /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-ephem-wheel-cache-63vxzlng/wheels/ed/8b/bf/0105dc839fecf1fc8db14f7267a6ce5ee876324b58565b359f\n",
|
|
|
+ "Successfully built jsl\n",
|
|
|
+ "Installing collected packages: toolz, pillow, kiwisolver, gast, fonttools, dm-tree, cycler, cloudpickle, tensorflow-probability, matplotlib, dataclasses, chex, jsl\n",
|
|
|
+ "Successfully installed chex-0.1.2 cloudpickle-2.0.0 cycler-0.11.0 dataclasses-0.6 dm-tree-0.1.6 fonttools-4.32.0 gast-0.5.3 jsl-0.0.0 kiwisolver-1.4.2 matplotlib-3.5.1 pillow-9.1.0 tensorflow-probability-0.16.0 toolz-0.11.2\n",
|
|
|
+ "Note: you may need to restart the kernel to use updated packages.\n",
|
|
|
+ "Collecting rich\n",
|
|
|
+ " Downloading rich-12.2.0-py3-none-any.whl (229 kB)\n",
|
|
|
+ "\u001b[K |████████████████████████████████| 229 kB 2.9 MB/s eta 0:00:01\n",
|
|
|
+ "\u001b[?25hRequirement already satisfied: typing-extensions<5.0,>=4.0.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from rich) (4.1.1)\n",
|
|
|
+ "Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from rich) (2.11.2)\n",
|
|
|
+ "Collecting commonmark<0.10.0,>=0.9.0\n",
|
|
|
+ " Using cached commonmark-0.9.1-py2.py3-none-any.whl (51 kB)\n",
|
|
|
+ "Installing collected packages: commonmark, rich\n",
|
|
|
+ "Successfully installed commonmark-0.9.1 rich-12.2.0\n",
|
|
|
+ "Note: you may need to restart the kernel to use updated packages.\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# Install necessary libraries\n",
|
|
|
+ "\n",
|
|
|
+ "try:\n",
|
|
|
+ " import jax\n",
|
|
|
+ "except:\n",
|
|
|
+ " # For cuda version, see https://github.com/google/jax#installation\n",
|
|
|
+ " %pip install --upgrade \"jax[cpu]\" \n",
|
|
|
+ " import jax\n",
|
|
|
+ "\n",
|
|
|
+ "try:\n",
|
|
|
+ " import jsl\n",
|
|
|
+ "except:\n",
|
|
|
+ " %pip install git+https://github.com/probml/jsl\n",
|
|
|
+ " import jsl\n",
|
|
|
+ "\n",
|
|
|
+ "try:\n",
|
|
|
+ " import rich\n",
|
|
|
+ "except:\n",
|
|
|
+ " %pip install rich\n",
|
|
|
+ " import rich\n",
|
|
|
+ "\n",
|
|
|
+ "\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 2,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Import standard libraries\n",
|
|
|
+ "\n",
|
|
|
+ "import abc\n",
|
|
|
+ "from dataclasses import dataclass\n",
|
|
|
+ "import functools\n",
|
|
|
+ "import itertools\n",
|
|
|
+ "\n",
|
|
|
+ "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
|
|
|
+ "\n",
|
|
|
+ "import matplotlib.pyplot as plt\n",
|
|
|
+ "import numpy as np\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "import jax\n",
|
|
|
+ "import jax.numpy as jnp\n",
|
|
|
+ "from jax import lax, vmap, jit, grad\n",
|
|
|
+ "from jax.scipy.special import logit\n",
|
|
|
+ "from jax.nn import softmax\n",
|
|
|
+ "from functools import partial\n",
|
|
|
+ "from jax.random import PRNGKey, split\n",
|
|
|
+ "\n",
|
|
|
+ "import inspect\n",
|
|
|
+ "import inspect as py_inspect\n",
|
|
|
+ "from rich import inspect as r_inspect\n",
|
|
|
+ "from rich import print as r_print\n",
|
|
|
+ "\n",
|
|
|
+ "def print_source(fname):\n",
|
|
|
+ " r_print(py_inspect.getsource(fname))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Utility code"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 3,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def normalize(u, axis=0, eps=1e-15):\n",
|
|
|
+ " '''\n",
|
|
|
+ " Normalizes the values within the axis in a way that they sum up to 1.\n",
|
|
|
+ " Parameters\n",
|
|
|
+ " ----------\n",
|
|
|
+ " u : array\n",
|
|
|
+ " axis : int\n",
|
|
|
+ " eps : float\n",
|
|
|
+ " Threshold for the alpha values\n",
|
|
|
+ " Returns\n",
|
|
|
+ " -------\n",
|
|
|
+ " * array\n",
|
|
|
+ " Normalized version of the given matrix\n",
|
|
|
+ " * array(seq_len, n_hidden) :\n",
|
|
|
+ " The values of the normalizer\n",
|
|
|
+ " '''\n",
|
|
|
+ " u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))\n",
|
|
|
+ " c = u.sum(axis=axis)\n",
|
|
|
+ " c = jnp.where(c == 0, 1, c)\n",
|
|
|
+ " return u / c, c"
|
|
|
+ ]
|
|
|
+ },
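+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check (a small example we add here, not part of the original text), the cell below shows what `normalize` returns: the array rescaled along the chosen axis, together with the normalization constants."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Usage example for normalize: with axis=0 each column is rescaled to sum\n",
+    "# to 1, and the per-column sums are returned as the normalizers.\n",
+    "u = jnp.array([[1.0, 3.0],\n",
+    "               [1.0, 1.0]])\n",
+    "u_norm, c = normalize(u, axis=0)\n",
+    "print(u_norm)  # [[0.5, 0.75], [0.5, 0.25]]\n",
+    "print(c)       # [2., 4.]"
+   ]
+  },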
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "(sec:casino-ex)=\n",
|
|
|
+ "## Example: Casino HMM\n",
|
|
|
+ "\n",
|
|
|
+ "We first create the \"Ocassionally dishonest casino\" model from {cite}`Durbin98`.\n",
|
|
|
+ "\n",
|
|
|
+ "```{figure} /figures/casino.png\n",
|
|
|
+ ":scale: 50%\n",
|
|
|
+ ":name: casino-fig\n",
|
|
|
+ "\n",
|
|
|
+ "Illustration of the casino HMM.\n",
|
|
|
+ "```\n",
|
|
|
+ "\n",
|
|
|
+ "There are 2 hidden states, each of which emit 6 possible observations."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 5,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# state transition matrix\n",
|
|
|
+ "A = np.array([\n",
|
|
|
+ " [0.95, 0.05],\n",
|
|
|
+ " [0.10, 0.90]\n",
|
|
|
+ "])\n",
|
|
|
+ "\n",
|
|
|
+ "# observation matrix\n",
|
|
|
+ "B = np.array([\n",
|
|
|
+ " [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die\n",
|
|
|
+ " [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die\n",
|
|
|
+ "])\n",
|
|
|
+ "\n",
|
|
|
+ "pi, _ = normalize(np.array([1, 1]))\n",
|
|
|
+ "pi = np.array(pi)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "(nstates, nobs) = np.shape(B)\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Let's make a little data structure to store all the parameters.\n",
|
|
|
+ "We use NamedTuple rather than dataclass, since we assume these are immutable.\n",
|
|
|
+ "(Also, standard python dataclass does not work well with JAX, which requires parameters to be\n",
|
|
|
+ "pytrees, as discussed in https://github.com/google/jax/issues/2371)."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 74,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "HMM(trans_mat=array([[0.95, 0.05],\n",
|
|
|
+ " [0.1 , 0.9 ]]), obs_mat=array([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,\n",
|
|
|
+ " 0.16666667],\n",
|
|
|
+ " [0.1 , 0.1 , 0.1 , 0.1 , 0.1 ,\n",
|
|
|
+ " 0.5 ]]), init_dist=array([0.5, 0.5], dtype=float32))\n",
|
|
|
+ "<class 'numpy.ndarray'>\n",
|
|
|
+ "HMM(trans_mat=DeviceArray([[0.95, 0.05],\n",
|
|
|
+ " [0.1 , 0.9 ]], dtype=float32), obs_mat=DeviceArray([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,\n",
|
|
|
+ " 0.16666667],\n",
|
|
|
+ " [0.1 , 0.1 , 0.1 , 0.1 , 0.1 ,\n",
|
|
|
+ " 0.5 ]], dtype=float32), init_dist=DeviceArray([0.5, 0.5], dtype=float32))\n",
|
|
|
+ "<class 'jaxlib.xla_extension.DeviceArray'>\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "Array = Union[np.array, jnp.array]\n",
|
|
|
+ "\n",
|
|
|
+ "class HMM(NamedTuple):\n",
|
|
|
+ " trans_mat: Array # A : (n_states, n_states)\n",
|
|
|
+ " obs_mat: Array # B : (n_states, n_obs)\n",
|
|
|
+ " init_dist: Array # pi : (n_states)\n",
|
|
|
+ "\n",
|
|
|
+ "params_np = HMM(A, B, pi)\n",
|
|
|
+ "print(params_np)\n",
|
|
|
+ "print(type(params_np.trans_mat))\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "params = jax.tree_map(lambda x: jnp.array(x), params_np)\n",
|
|
|
+ "print(params)\n",
|
|
|
+ "print(type(params.trans_mat))"
|
|
|
+ ]
|
|
|
+ },
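+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Because `HMM` is a NamedTuple of arrays, `params` is a pytree, so it can be passed directly through JAX transformations such as `jit`. The cell below is a minimal sketch we add to illustrate this (the function `log_joint_step` is just for illustration)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal sketch: a NamedTuple of arrays is a pytree, so it can be passed\n",
+    "# straight into a jit-compiled function; JAX traces each array leaf.\n",
+    "@jit\n",
+    "def log_joint_step(hmm, z_prev, z, y):\n",
+    "    # log p(z | z_prev) + log p(y | z) for a single time step\n",
+    "    return jnp.log(hmm.trans_mat[z_prev, z]) + jnp.log(hmm.obs_mat[z, y])\n",
+    "\n",
+    "print(jax.tree_util.tree_leaves(params))  # the three parameter arrays\n",
+    "print(log_joint_step(params, 0, 1, 5))    # log(0.05) + log(0.5)"
+   ]
+  },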
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Sampling from the joint\n",
|
|
|
+ "\n",
|
|
|
+ "Let's write code to sample from this model. \n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### Numpy version\n",
|
|
|
+ "\n",
|
|
|
+ "First we code it in numpy using a for loop."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 30,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def hmm_sample_np(params, seq_len, random_state=0):\n",
|
|
|
+ " np.random.seed(random_state)\n",
|
|
|
+ " trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
|
|
|
+ " n_states, n_obs = obs_mat.shape\n",
|
|
|
+ " state_seq = np.zeros(seq_len, dtype=int)\n",
|
|
|
+ " obs_seq = np.zeros(seq_len, dtype=int)\n",
|
|
|
+ " for t in range(seq_len):\n",
|
|
|
+ " if t==0:\n",
|
|
|
+ " zt = np.random.choice(n_states, p=init_dist)\n",
|
|
|
+ " else:\n",
|
|
|
+ " zt = np.random.choice(n_states, p=trans_mat[zt])\n",
|
|
|
+ " yt = np.random.choice(n_obs, p=obs_mat[zt])\n",
|
|
|
+ " state_seq[t] = zt\n",
|
|
|
+ " obs_seq[t] = yt\n",
|
|
|
+ "\n",
|
|
|
+ " return state_seq, obs_seq"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 75,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0\n",
|
|
|
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
|
|
|
+ " 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
|
|
|
+ "[4 1 0 2 3 4 5 4 3 1 5 4 5 0 5 2 5 3 5 4 5 5 4 2 1 4 1 0 0 4 2 2 3 3 3 0 4\n",
|
|
|
+ " 0 2 4 3 2 5 5 3 5 3 1 3 3 3 2 3 5 5 0 4 4 5 0 0 1 3 5 1 5 0 1 2 4 0 0 0 4\n",
|
|
|
+ " 0 5 1 4 3 5 4 5 0 2 3 5 2 4 1 2 1 0 4 3 5 0 4 5 1 5]\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "seq_len = 100\n",
|
|
|
+ "state_seq, obs_seq = hmm_sample_np(params_np, seq_len, random_state=1)\n",
|
|
|
+ "print(state_seq)\n",
|
|
|
+ "print(obs_seq)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### JAX version\n",
|
|
|
+ "\n",
|
|
|
+ "Now let's write a JAX version using jax.lax.scan (for the inter-dependent states) and vmap (for the observations).\n",
|
|
|
+ "This is harder to read than the numpy version, but faster."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 91,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "#@partial(jit, static_argnums=(1,))\n",
|
|
|
+ "def markov_chain_sample(rng_key, init_dist, trans_mat, seq_len):\n",
|
|
|
+ " n_states = len(init_dist)\n",
|
|
|
+ "\n",
|
|
|
+ " def draw_state(prev_state, key):\n",
|
|
|
+ " state = jax.random.choice(key, n_states, p=trans_mat[prev_state])\n",
|
|
|
+ " return state, state\n",
|
|
|
+ "\n",
|
|
|
+ " rng_key, rng_state = jax.random.split(rng_key, 2)\n",
|
|
|
+ " keys = jax.random.split(rng_state, seq_len - 1)\n",
|
|
|
+ " initial_state = jax.random.choice(rng_key, n_states, p=init_dist)\n",
|
|
|
+ " final_state, states = jax.lax.scan(draw_state, initial_state, keys)\n",
|
|
|
+ " state_seq = jnp.append(jnp.array([initial_state]), states)\n",
|
|
|
+ "\n",
|
|
|
+ " return state_seq"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 90,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "#@partial(jit, static_argnums=(1,))\n",
|
|
|
+ "def hmm_sample(rng_key, params, seq_len):\n",
|
|
|
+ "\n",
|
|
|
+ " trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
|
|
|
+ " n_states, n_obs = obs_mat.shape\n",
|
|
|
+ " rng_key, rng_obs = jax.random.split(rng_key, 2)\n",
|
|
|
+ " state_seq = markov_chain_sample(rng_key, init_dist, trans_mat, seq_len)\n",
|
|
|
+ "\n",
|
|
|
+ " def draw_obs(z, key):\n",
|
|
|
+ " obs = jax.random.choice(key, n_obs, p=obs_mat[z])\n",
|
|
|
+ " return obs\n",
|
|
|
+ "\n",
|
|
|
+ " keys = jax.random.split(rng_obs, seq_len)\n",
|
|
|
+ " obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)\n",
|
|
|
+ " \n",
|
|
|
+ " return state_seq, obs_seq"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 70,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "#@partial(jit, static_argnums=(1,))\n",
|
|
|
+ "def hmm_sample2(rng_key, params, seq_len):\n",
|
|
|
+ "\n",
|
|
|
+ " trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
|
|
|
+ " n_states, n_obs = obs_mat.shape\n",
|
|
|
+ "\n",
|
|
|
+ " def draw_state(prev_state, key):\n",
|
|
|
+ " state = jax.random.choice(key, n_states, p=trans_mat[prev_state])\n",
|
|
|
+ " return state, state\n",
|
|
|
+ "\n",
|
|
|
+ " rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3)\n",
|
|
|
+ " keys = jax.random.split(rng_state, seq_len - 1)\n",
|
|
|
+ " initial_state = jax.random.choice(rng_key, n_states, p=init_dist)\n",
|
|
|
+ " final_state, states = jax.lax.scan(draw_state, initial_state, keys)\n",
|
|
|
+ " state_seq = jnp.append(jnp.array([initial_state]), states)\n",
|
|
|
+ "\n",
|
|
|
+ " def draw_obs(z, key):\n",
|
|
|
+ " obs = jax.random.choice(key, n_obs, p=obs_mat[z])\n",
|
|
|
+ " return obs\n",
|
|
|
+ "\n",
|
|
|
+ " keys = jax.random.split(rng_obs, seq_len)\n",
|
|
|
+ " obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)\n",
|
|
|
+ "\n",
|
|
|
+ " return state_seq, obs_seq"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 93,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
|
|
|
+ " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1\n",
|
|
|
+ " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
|
|
|
+ "[5 5 2 2 0 0 0 1 3 3 2 2 5 1 5 1 0 2 2 4 2 5 1 5 5 0 0 4 2 4 3 2 3 4 1 0 5\n",
|
|
|
+ " 2 2 2 1 4 3 2 2 2 4 1 0 3 5 2 5 1 4 2 5 2 5 0 5 4 4 4 2 2 0 4 5 2 2 0 1 5\n",
|
|
|
+ " 1 3 4 5 1 5 0 5 1 5 1 2 4 5 3 4 5 4 0 4 0 2 4 5 3 3]\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "\n",
|
|
|
+ "key = PRNGKey(2)\n",
|
|
|
+ "seq_len = 100\n",
|
|
|
+ "\n",
|
|
|
+ "state_seq, obs_seq = hmm_sample(key, params, seq_len)\n",
|
|
|
+ "print(state_seq)\n",
|
|
|
+ "print(obs_seq)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### Check correctness by computing empirical pairwise statistics\n",
|
|
|
+ "\n",
|
|
|
+ "We will compute the number of i->j transitions, and check that it is close to the true \n",
|
|
|
+ "A[i,j] transition probabilites."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 107,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[0 0 1 1 1 1 0 0 1 1 1 0 1 0 0 1 1 1 1 0 1 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 1\n",
|
|
|
+ " 1 0 0 0 0 1 0 1 0 0 0 0 1 0 0 1 1 0 1 1 0 1 1 0 1 1 1 0 0 1 1 0 1 0 0 1 0\n",
|
|
|
+ " 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 1 1 1 0 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0\n",
|
|
|
+ " 0 0 1 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 1 1 0 1 1 0 0 0 0 0 0 1 0\n",
|
|
|
+ " 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 1 1 1 0 1 1 0 0 0 0 0 1 0 0 0 0\n",
|
|
|
+ " 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 1 1 1 0 0 0 1 1 0 0 0\n",
|
|
|
+ " 0 0 0 1 1 1 0 0 0 0 1 0 0 1 1 1 0 1 1 1 1 1 0 1 1 0 0 0 1 1 0 1 0 0 1 0 0\n",
|
|
|
+ " 0 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 1 0 0 1 1 1 1\n",
|
|
|
+ " 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 1 0 1 0 0 0 1\n",
|
|
|
+ " 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0\n",
|
|
|
+ " 0 0 0 0 0 0 0 1 0 0 1 1 1 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 1 0\n",
|
|
|
+ " 1 0 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0\n",
|
|
|
+ " 0 0 0 0 1 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 1 1 0 0 1 1 0 0 1\n",
|
|
|
+ " 1 0 0 0 0 0 0 0 1 0 0 1 1 0 0 0 0 1 1]\n",
|
|
|
+ "[[244. 93.]\n",
|
|
|
+ " [ 92. 70.]]\n",
|
|
|
+ "[[0.7240356 0.27596438]\n",
|
|
|
+ " [0.56790125 0.43209878]]\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "import collections\n",
|
|
|
+ "def compute_counts(state_seq, nstates):\n",
|
|
|
+ " wseq = np.array(state_seq)\n",
|
|
|
+ " word_pairs = [pair for pair in zip(wseq[:-1], wseq[1:])]\n",
|
|
|
+ " counter_pairs = collections.Counter(word_pairs)\n",
|
|
|
+ " counts = np.zeros((nstates, nstates))\n",
|
|
|
+ " for (k,v) in counter_pairs.items():\n",
|
|
|
+ " counts[k[0], k[1]] = v\n",
|
|
|
+ " return counts\n",
|
|
|
+ "\n",
|
|
|
+ "def normalize_counts(counts):\n",
|
|
|
+ " ncounts = vmap(lambda v: normalize(v)[0], in_axes=0)(counts)\n",
|
|
|
+ " return ncounts\n",
|
|
|
+ "\n",
|
|
|
+ "init_dist = jnp.array([1.0, 0.0])\n",
|
|
|
+ "trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]])\n",
|
|
|
+ "rng_key = jax.random.PRNGKey(0)\n",
|
|
|
+ "seq_len = 500\n",
|
|
|
+ "state_seq = markov_chain_sample(rng_key, init_dist, trans_mat, seq_len)\n",
|
|
|
+ "print(state_seq)\n",
|
|
|
+ "\n",
|
|
|
+ "counts = compute_counts(state_seq, nstates=2)\n",
|
|
|
+ "print(counts)\n",
|
|
|
+ "\n",
|
|
|
+ "trans_mat_empirical = normalize_counts(counts)\n",
|
|
|
+ "print(trans_mat_empirical)\n",
|
|
|
+ "\n",
|
|
|
+ "assert jnp.allclose(trans_mat, trans_mat_empirical, atol=1e-1)\n"
|
|
|
+ ]
|
|
|
+ },
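+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The same kind of check can be applied to the observation model. The cell below is a sketch we add here (not part of the original text): it draws a longer sequence with `hmm_sample` and compares the empirical per-state observation frequencies to the rows of the emission matrix $B$."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Extra sanity check (our addition): the empirical observation frequencies\n",
+    "# per hidden state should be close to the rows of the emission matrix B.\n",
+    "rng_key = jax.random.PRNGKey(42)\n",
+    "state_seq, obs_seq = hmm_sample(rng_key, params, 5000)\n",
+    "\n",
+    "obs_counts = np.zeros((nstates, nobs))\n",
+    "for z, y in zip(np.array(state_seq), np.array(obs_seq)):\n",
+    "    obs_counts[z, y] += 1\n",
+    "\n",
+    "obs_mat_empirical = normalize_counts(obs_counts)\n",
+    "print(obs_mat_empirical)\n",
+    "assert jnp.allclose(params.obs_mat, obs_mat_empirical, atol=1e-1)"
+   ]
+  },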
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": []
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.8.8"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 4
|
|
|
+}
|