{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:hmm-ex)=\n", "# Hidden Markov Models\n", "\n", "In this section, we introduce Hidden Markov Models (HMMs)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Boilerplate" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting jax[cpu]\n", " Downloading jax-0.3.5.tar.gz (946 kB)\n", "\u001b[K |████████████████████████████████| 946 kB 2.7 MB/s eta 0:00:01\n", "\u001b[?25hCollecting absl-py\n", " Downloading absl_py-1.0.0-py3-none-any.whl (126 kB)\n", "\u001b[K |████████████████████████████████| 126 kB 47.7 MB/s eta 0:00:01\n", "\u001b[?25hCollecting numpy>=1.19\n", " Downloading numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl (17.6 MB)\n", "\u001b[K |████████████████████████████████| 17.6 MB 47.5 MB/s eta 0:00:01\n", "\u001b[?25hCollecting opt_einsum\n", " Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)\n", "Collecting scipy>=1.2.1\n", " Downloading scipy-1.8.0-cp38-cp38-macosx_12_0_universal2.macosx_10_9_x86_64.whl (55.3 MB)\n", "\u001b[K |████████████████████████████████| 55.3 MB 73.1 MB/s eta 0:00:01\n", "\u001b[?25hCollecting typing_extensions\n", " Using cached typing_extensions-4.1.1-py3-none-any.whl (26 kB)\n", "Collecting jaxlib==0.3.5\n", " Downloading jaxlib-0.3.5-cp38-none-macosx_10_9_x86_64.whl (70.5 MB)\n", "\u001b[K |████████████████████████████████| 70.5 MB 723 kB/s eta 0:00:01\n", "\u001b[?25hCollecting flatbuffers<3.0,>=1.12\n", " Using cached flatbuffers-2.0-py2.py3-none-any.whl (26 kB)\n", "Requirement already satisfied: six in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from absl-py->jax[cpu]) (1.16.0)\n", "Building wheels for collected packages: jax\n", " Building wheel for jax (setup.py) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for jax: filename=jax-0.3.5-py3-none-any.whl size=1095861 sha256=6886baa70817bbac3b5797b3720dcb81e49097f61cee7d2e1255823ea32ccad8\n", " Stored in directory: /Users/kpmurphy/Library/Caches/pip/wheels/05/30/aa/908988293721511b4b29e0aadf9b5d133d0f14f6c0a188e764\n", "Successfully built jax\n", "Installing collected packages: numpy, typing-extensions, scipy, opt-einsum, flatbuffers, absl-py, jaxlib, jax\n", "Successfully installed absl-py-1.0.0 flatbuffers-2.0 jax-0.3.5 jaxlib-0.3.5 numpy-1.22.3 opt-einsum-3.3.0 scipy-1.8.0 typing-extensions-4.1.1\n", "Note: you may need to restart the kernel to use updated packages.\n", "Collecting git+https://github.com/probml/jsl\n", " Cloning https://github.com/probml/jsl to /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-req-build-i8seqdiw\n", " Running command git clone -q https://github.com/probml/jsl /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-req-build-i8seqdiw\n", "Collecting chex\n", " Downloading chex-0.1.2-py3-none-any.whl (72 kB)\n", "\u001b[K |████████████████████████████████| 72 kB 1.3 MB/s eta 0:00:011\n", "\u001b[?25hCollecting dataclasses\n", " Using cached dataclasses-0.6-py3-none-any.whl (14 kB)\n", "Requirement already satisfied: jaxlib in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jsl==0.0.0) (0.3.5)\n", "Requirement already satisfied: jax in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jsl==0.0.0) (0.3.5)\n", "Collecting matplotlib\n", " Downloading matplotlib-3.5.1-cp38-cp38-macosx_10_9_x86_64.whl (7.3 MB)\n", "\u001b[K |████████████████████████████████| 7.3 MB 3.9 MB/s eta 0:00:01\n", "\u001b[?25hCollecting tensorflow_probability\n", " Using cached tensorflow_probability-0.16.0-py2.py3-none-any.whl (6.3 MB)\n", "Collecting dm-tree>=0.1.5\n", " Using cached dm_tree-0.1.6-cp38-cp38-macosx_10_14_x86_64.whl (95 kB)\n", "Requirement already satisfied: absl-py>=0.9.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from chex->jsl==0.0.0) (1.0.0)\n", "Requirement already satisfied: numpy>=1.18.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from chex->jsl==0.0.0) (1.22.3)\n", "Collecting toolz>=0.9.0\n", " Downloading toolz-0.11.2-py3-none-any.whl (55 kB)\n", "\u001b[K |████████████████████████████████| 55 kB 11.3 MB/s eta 0:00:01\n", "\u001b[?25hRequirement already satisfied: six in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from absl-py>=0.9.0->chex->jsl==0.0.0) (1.16.0)\n", "Requirement already satisfied: typing-extensions in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (4.1.1)\n", "Requirement already satisfied: scipy>=1.2.1 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (1.8.0)\n", "Requirement already satisfied: opt-einsum in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jax->jsl==0.0.0) (3.3.0)\n", "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from jaxlib->jsl==0.0.0) (2.0)\n", "Collecting cycler>=0.10\n", " Downloading cycler-0.11.0-py3-none-any.whl (6.4 kB)\n", "Collecting kiwisolver>=1.0.1\n", " Downloading kiwisolver-1.4.2-cp38-cp38-macosx_10_9_x86_64.whl (65 kB)\n", "\u001b[K |████████████████████████████████| 65 kB 8.5 MB/s eta 0:00:01\n", "\u001b[?25hCollecting fonttools>=4.22.0\n", " Downloading fonttools-4.32.0-py3-none-any.whl (900 kB)\n", "\u001b[K |████████████████████████████████| 900 kB 35.8 MB/s eta 0:00:01\n", "\u001b[?25hRequirement already satisfied: pyparsing>=2.2.1 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (3.0.7)\n", "Requirement already satisfied: packaging>=20.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (21.3)\n", "Requirement already satisfied: python-dateutil>=2.7 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from matplotlib->jsl==0.0.0) (2.8.2)\n", "Collecting pillow>=6.2.0\n", " Downloading Pillow-9.1.0-cp38-cp38-macosx_10_9_x86_64.whl (3.1 MB)\n", "\u001b[K |████████████████████████████████| 3.1 MB 76.6 MB/s eta 0:00:01\n", "\u001b[?25hCollecting gast>=0.3.2\n", " Downloading gast-0.5.3-py3-none-any.whl (19 kB)\n", "Requirement already satisfied: decorator in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from tensorflow_probability->jsl==0.0.0) (5.1.1)\n", "Collecting cloudpickle>=1.3\n", " Downloading cloudpickle-2.0.0-py3-none-any.whl (25 kB)\n", "Building wheels for collected packages: jsl\n", " Building wheel for jsl (setup.py) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for jsl: filename=jsl-0.0.0-py3-none-any.whl size=77852 sha256=e7365293dc97e2b3e72bf42cc19db7d7e355abec312fc4d87961fa2044fa06f0\n", " Stored in directory: /private/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/pip-ephem-wheel-cache-63vxzlng/wheels/ed/8b/bf/0105dc839fecf1fc8db14f7267a6ce5ee876324b58565b359f\n", "Successfully built jsl\n", "Installing collected packages: toolz, pillow, kiwisolver, gast, fonttools, dm-tree, cycler, cloudpickle, tensorflow-probability, matplotlib, dataclasses, chex, jsl\n", "Successfully installed chex-0.1.2 cloudpickle-2.0.0 cycler-0.11.0 dataclasses-0.6 dm-tree-0.1.6 fonttools-4.32.0 gast-0.5.3 jsl-0.0.0 kiwisolver-1.4.2 matplotlib-3.5.1 pillow-9.1.0 tensorflow-probability-0.16.0 toolz-0.11.2\n", "Note: you may need to restart the kernel to use updated packages.\n", "Collecting rich\n", " Downloading rich-12.2.0-py3-none-any.whl (229 kB)\n", "\u001b[K |████████████████████████████████| 229 kB 2.9 MB/s eta 0:00:01\n", "\u001b[?25hRequirement already satisfied: typing-extensions<5.0,>=4.0.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from rich) (4.1.1)\n", "Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /opt/anaconda3/envs/scripts/lib/python3.8/site-packages (from rich) (2.11.2)\n", "Collecting commonmark<0.10.0,>=0.9.0\n", " Using cached commonmark-0.9.1-py2.py3-none-any.whl (51 kB)\n", "Installing collected packages: commonmark, rich\n", "Successfully installed commonmark-0.9.1 rich-12.2.0\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "# Install necessary libraries\n", "\n", "try:\n", " import jax\n", "except:\n", " # For cuda version, see https://github.com/google/jax#installation\n", " %pip install --upgrade \"jax[cpu]\" \n", " import jax\n", "\n", "try:\n", " import jsl\n", "except:\n", " %pip install git+https://github.com/probml/jsl\n", " import jsl\n", "\n", "try:\n", " import rich\n", "except:\n", " %pip install rich\n", " import rich\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Import standard libraries\n", "\n", "import abc\n", "from dataclasses import dataclass\n", "import functools\n", "import itertools\n", "\n", "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import lax, vmap, jit, grad\n", "from jax.scipy.special import logit\n", "from jax.nn import softmax\n", "from functools import partial\n", "from jax.random import PRNGKey, split\n", "\n", "import inspect\n", "import inspect as py_inspect\n", "from rich import inspect as r_inspect\n", "from rich import print as r_print\n", "\n", "def print_source(fname):\n", " r_print(py_inspect.getsource(fname))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Utility code" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "\n", "\n", "def normalize(u, axis=0, eps=1e-15):\n", " '''\n", " Normalizes the values within the axis in a way that they sum up to 1.\n", " Parameters\n", " ----------\n", " u : array\n", " axis : int\n", " eps : float\n", " Threshold for the alpha values\n", " Returns\n", " -------\n", " * array\n", " Normalized version of the given matrix\n", " * array(seq_len, n_hidden) :\n", " The values of the normalizer\n", " '''\n", " u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))\n", " c = u.sum(axis=axis)\n", " c = jnp.where(c == 0, 1, c)\n", " return u / c, c" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:casino-ex)=\n", "## Example: Casino HMM\n", "\n", "We first create the \"Ocassionally dishonest casino\" model from {cite}`Durbin98`.\n", "\n", "```{figure} /figures/casino.png\n", ":scale: 50%\n", ":name: casino-fig\n", "\n", "Illustration of the casino HMM.\n", "```\n", "\n", "There are 2 hidden states, each of which emit 6 possible observations." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] } ], "source": [ "# state transition matrix\n", "A = np.array([\n", " [0.95, 0.05],\n", " [0.10, 0.90]\n", "])\n", "\n", "# observation matrix\n", "B = np.array([\n", " [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die\n", " [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die\n", "])\n", "\n", "pi, _ = normalize(np.array([1, 1]))\n", "pi = np.array(pi)\n", "\n", "\n", "(nstates, nobs) = np.shape(B)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's make a little data structure to store all the parameters.\n", "We use NamedTuple rather than dataclass, since we assume these are immutable.\n", "(Also, standard python dataclass does not work well with JAX, which requires parameters to be\n", "pytrees, as discussed in https://github.com/google/jax/issues/2371)." ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "HMM(trans_mat=array([[0.95, 0.05],\n", " [0.1 , 0.9 ]]), obs_mat=array([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,\n", " 0.16666667],\n", " [0.1 , 0.1 , 0.1 , 0.1 , 0.1 ,\n", " 0.5 ]]), init_dist=array([0.5, 0.5], dtype=float32))\n", "\n", "HMM(trans_mat=DeviceArray([[0.95, 0.05],\n", " [0.1 , 0.9 ]], dtype=float32), obs_mat=DeviceArray([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,\n", " 0.16666667],\n", " [0.1 , 0.1 , 0.1 , 0.1 , 0.1 ,\n", " 0.5 ]], dtype=float32), init_dist=DeviceArray([0.5, 0.5], dtype=float32))\n", "\n" ] } ], "source": [ "Array = Union[np.array, jnp.array]\n", "\n", "class HMM(NamedTuple):\n", " trans_mat: Array # A : (n_states, n_states)\n", " obs_mat: Array # B : (n_states, n_obs)\n", " init_dist: Array # pi : (n_states)\n", "\n", "params_np = HMM(A, B, pi)\n", "print(params_np)\n", "print(type(params_np.trans_mat))\n", "\n", "\n", "params = jax.tree_map(lambda x: jnp.array(x), params_np)\n", "print(params)\n", "print(type(params.trans_mat))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sampling from the joint\n", "\n", "Let's write code to sample from this model. \n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Numpy version\n", "\n", "First we code it in numpy using a for loop." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "def hmm_sample_np(params, seq_len, random_state=0):\n", " np.random.seed(random_state)\n", " trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n", " n_states, n_obs = obs_mat.shape\n", " state_seq = np.zeros(seq_len, dtype=int)\n", " obs_seq = np.zeros(seq_len, dtype=int)\n", " for t in range(seq_len):\n", " if t==0:\n", " zt = np.random.choice(n_states, p=init_dist)\n", " else:\n", " zt = np.random.choice(n_states, p=trans_mat[zt])\n", " yt = np.random.choice(n_obs, p=obs_mat[zt])\n", " state_seq[t] = zt\n", " obs_seq[t] = yt\n", "\n", " return state_seq, obs_seq" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", "[4 1 0 2 3 4 5 4 3 1 5 4 5 0 5 2 5 3 5 4 5 5 4 2 1 4 1 0 0 4 2 2 3 3 3 0 4\n", " 0 2 4 3 2 5 5 3 5 3 1 3 3 3 2 3 5 5 0 4 4 5 0 0 1 3 5 1 5 0 1 2 4 0 0 0 4\n", " 0 5 1 4 3 5 4 5 0 2 3 5 2 4 1 2 1 0 4 3 5 0 4 5 1 5]\n" ] } ], "source": [ "seq_len = 100\n", "state_seq, obs_seq = hmm_sample_np(params_np, seq_len, random_state=1)\n", "print(state_seq)\n", "print(obs_seq)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### JAX version\n", "\n", "Now let's write a JAX version using jax.lax.scan (for the inter-dependent states) and vmap (for the observations).\n", "This is harder to read than the numpy version, but faster." ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [], "source": [ "#@partial(jit, static_argnums=(1,))\n", "def markov_chain_sample(rng_key, init_dist, trans_mat, seq_len):\n", " n_states = len(init_dist)\n", "\n", " def draw_state(prev_state, key):\n", " state = jax.random.choice(key, n_states, p=trans_mat[prev_state])\n", " return state, state\n", "\n", " rng_key, rng_state = jax.random.split(rng_key, 2)\n", " keys = jax.random.split(rng_state, seq_len - 1)\n", " initial_state = jax.random.choice(rng_key, n_states, p=init_dist)\n", " final_state, states = jax.lax.scan(draw_state, initial_state, keys)\n", " state_seq = jnp.append(jnp.array([initial_state]), states)\n", "\n", " return state_seq" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [], "source": [ "#@partial(jit, static_argnums=(1,))\n", "def hmm_sample(rng_key, params, seq_len):\n", "\n", " trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n", " n_states, n_obs = obs_mat.shape\n", " rng_key, rng_obs = jax.random.split(rng_key, 2)\n", " state_seq = markov_chain_sample(rng_key, init_dist, trans_mat, seq_len)\n", "\n", " def draw_obs(z, key):\n", " obs = jax.random.choice(key, n_obs, p=obs_mat[z])\n", " return obs\n", "\n", " keys = jax.random.split(rng_obs, seq_len)\n", " obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)\n", " \n", " return state_seq, obs_seq" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "#@partial(jit, static_argnums=(1,))\n", "def hmm_sample2(rng_key, params, seq_len):\n", "\n", " trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n", " n_states, n_obs = obs_mat.shape\n", "\n", " def draw_state(prev_state, key):\n", " state = jax.random.choice(key, n_states, p=trans_mat[prev_state])\n", " return state, state\n", "\n", " rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3)\n", " keys = jax.random.split(rng_state, seq_len - 1)\n", " initial_state = jax.random.choice(rng_key, n_states, p=init_dist)\n", " final_state, states = jax.lax.scan(draw_state, initial_state, keys)\n", " state_seq = jnp.append(jnp.array([initial_state]), states)\n", "\n", " def draw_obs(z, key):\n", " obs = jax.random.choice(key, n_obs, p=obs_mat[z])\n", " return obs\n", "\n", " keys = jax.random.split(rng_obs, seq_len)\n", " obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)\n", "\n", " return state_seq, obs_seq" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1\n", " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", "[5 5 2 2 0 0 0 1 3 3 2 2 5 1 5 1 0 2 2 4 2 5 1 5 5 0 0 4 2 4 3 2 3 4 1 0 5\n", " 2 2 2 1 4 3 2 2 2 4 1 0 3 5 2 5 1 4 2 5 2 5 0 5 4 4 4 2 2 0 4 5 2 2 0 1 5\n", " 1 3 4 5 1 5 0 5 1 5 1 2 4 5 3 4 5 4 0 4 0 2 4 5 3 3]\n" ] } ], "source": [ "\n", "key = PRNGKey(2)\n", "seq_len = 100\n", "\n", "state_seq, obs_seq = hmm_sample(key, params, seq_len)\n", "print(state_seq)\n", "print(obs_seq)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Check correctness by computing empirical pairwise statistics\n", "\n", "We will compute the number of i->j transitions, and check that it is close to the true \n", "A[i,j] transition probabilites." ] }, { "cell_type": "code", "execution_count": 107, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0 0 1 1 1 1 0 0 1 1 1 0 1 0 0 1 1 1 1 0 1 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 1\n", " 1 0 0 0 0 1 0 1 0 0 0 0 1 0 0 1 1 0 1 1 0 1 1 0 1 1 1 0 0 1 1 0 1 0 0 1 0\n", " 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 1 1 1 0 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0\n", " 0 0 1 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 1 1 0 1 1 0 0 0 0 0 0 1 0\n", " 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 1 1 1 0 1 1 0 0 0 0 0 1 0 0 0 0\n", " 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 1 1 1 0 0 0 1 1 0 0 0\n", " 0 0 0 1 1 1 0 0 0 0 1 0 0 1 1 1 0 1 1 1 1 1 0 1 1 0 0 0 1 1 0 1 0 0 1 0 0\n", " 0 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 1 0 0 1 1 1 1\n", " 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 1 0 1 0 0 0 1\n", " 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0\n", " 0 0 0 0 0 0 0 1 0 0 1 1 1 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 1 0\n", " 1 0 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0\n", " 0 0 0 0 1 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 1 1 0 0 1 1 0 0 1\n", " 1 0 0 0 0 0 0 0 1 0 0 1 1 0 0 0 0 1 1]\n", "[[244. 93.]\n", " [ 92. 70.]]\n", "[[0.7240356 0.27596438]\n", " [0.56790125 0.43209878]]\n" ] } ], "source": [ "import collections\n", "def compute_counts(state_seq, nstates):\n", " wseq = np.array(state_seq)\n", " word_pairs = [pair for pair in zip(wseq[:-1], wseq[1:])]\n", " counter_pairs = collections.Counter(word_pairs)\n", " counts = np.zeros((nstates, nstates))\n", " for (k,v) in counter_pairs.items():\n", " counts[k[0], k[1]] = v\n", " return counts\n", "\n", "def normalize_counts(counts):\n", " ncounts = vmap(lambda v: normalize(v)[0], in_axes=0)(counts)\n", " return ncounts\n", "\n", "init_dist = jnp.array([1.0, 0.0])\n", "trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]])\n", "rng_key = jax.random.PRNGKey(0)\n", "seq_len = 500\n", "state_seq = markov_chain_sample(rng_key, init_dist, trans_mat, seq_len)\n", "print(state_seq)\n", "\n", "counts = compute_counts(state_seq, nstates=2)\n", "print(counts)\n", "\n", "trans_mat_empirical = normalize_counts(counts)\n", "print(trans_mat_empirical)\n", "\n", "assert jnp.allclose(trans_mat, trans_mat_empirical, atol=1e-1)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 4 }