{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Implementing discrete HMMs in Numpy \n", "\n", "We start with a simple numpy implementation." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import numpy as np\n", "import jax.numpy as jnp\n", "\n", "from dataclasses import dataclass" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import jsl" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def normalize(u, axis=0, eps=1e-15):\n", " u = np.where(u == 0, 0, np.where(u < eps, eps, u))\n", " c = u.sum(axis=axis)\n", " c = np.where(c == 0, 1, c)\n", " return u / c, c" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We first create the \"Ocassionally dishonest casino\" model from {cite}`Durbin98`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "\n", "# state transition matrix\n", "A = np.array([\n", " [0.95, 0.05],\n", " [0.10, 0.90]\n", "])\n", "\n", "# observation matrix\n", "B = np.array([\n", " [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die\n", " [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die\n", "])\n", "\n", "pi = np.array([1, 1]) / 2\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's bundle the parameters into a structure." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class HMMNumpy:\n", " trans_mat: np.array # A : (n_states, n_states)\n", " obs_mat: np.array # B : (n_states, n_obs)\n", " init_dist: np.array # pi : (n_states)\n", " \n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "HMMNumpy(trans_mat=array([[0.95, 0.05],\n", " [0.1 , 0.9 ]]), obs_mat=array([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,\n", " 0.16666667],\n", " [0.1 , 0.1 , 0.1 , 0.1 , 0.1 ,\n", " 0.5 ]]), init_dist=array([0.5, 0.5]))\n" ] } ], "source": [ "(nstates, nobs) = jnp.shape(B)\n", "for i in range(nstates):\n", " A[i,:] = normalize(A[i,:])[0]\n", " B[i,:] = normalize(B[i,:])[0]\n", " \n", "params_numpy = HMMNumpy(A, B, pi)\n", "print(params_numpy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Function to sample a single sequence of hidden states and discrete observations." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def hmm_sample_numpy(params, seq_len, random_state=0):\n", "\n", " def sample_one_step_(hist, a, p):\n", " x_t = np.random.choice(a=a, p=p)\n", " return np.append(hist, [x_t]), x_t\n", "\n", " np.random.seed(random_state)\n", "\n", " trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n", " n_states, n_obs = obs_mat.shape\n", "\n", " state_seq = np.array([], dtype=int)\n", " obs_seq = np.array([], dtype=int)\n", "\n", " latent_states = np.arange(n_states)\n", " obs_states = np.arange(n_obs)\n", "\n", " state_seq, zt = sample_one_step_(state_seq, latent_states, init_dist)\n", " obs_seq, xt = sample_one_step_(obs_seq, obs_states, obs_mat[zt])\n", "\n", " for t in range(1, seq_len):\n", " #print(t, zt, trans_mat[zt])\n", " state_seq, zt = sample_one_step_(state_seq, latent_states, trans_mat[zt])\n", " obs_seq, xt = sample_one_step_(obs_seq, obs_states, obs_mat[zt])\n", "\n", " return state_seq, obs_seq" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 0 0 0]\n", "[5 5 5 5 3 5 5 0 4 5 5 5 5 5 4 5 5 3 3 4]\n" ] } ], "source": [ "seq_len = 20\n", "state_seq, obs_seq = hmm_sample_numpy(params_numpy, seq_len, random_state=0)\n", "print(state_seq)\n", "print(obs_seq)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 0 0 0]\n", "[5 5 5 5 3 5 5 0 4 5 5 5 5 5 4 5 5 3 3 4]\n" ] } ], "source": [ "#from jsl.hmm.hmm_numpy_lib import HMMNumpy, hmm_forwards_backwards_numpy, hmm_loglikelihood_numpy\n", "import jsl.hmm.hmm_numpy_lib as hmm_np\n", "\n", "state_seq, obs_seq = hmm_np.hmm_sample_numpy(params_numpy, seq_len, random_state=0)\n", "print(state_seq)\n", "print(obs_seq)\n", "\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }