{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# meta-data does not work yet in VScode\n",
    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
    "\n",
    "{\n",
    "    \"tags\": [\n",
    "        \"hide-cell\"\n",
    "    ]\n",
    "}\n",
    "\n",
    "\n",
    "### Install necessary libraries\n",
    "\n",
    "# Only a failed import should trigger an install; a bare `except:` would also\n",
    "# swallow KeyboardInterrupt and unrelated errors raised while importing.\n",
    "try:\n",
    "    import jax\n",
    "except ImportError:\n",
    "    # For cuda version, see https://github.com/google/jax#installation\n",
    "    %pip install --upgrade \"jax[cpu]\"\n",
    "    import jax\n",
    "\n",
    "try:\n",
    "    import distrax\n",
    "except ImportError:\n",
    "    %pip install --upgrade distrax\n",
    "    import distrax\n",
    "\n",
    "try:\n",
    "    import jsl\n",
    "except ImportError:\n",
    "    %pip install git+https://github.com/probml/jsl\n",
    "    import jsl\n",
    "\n",
    "try:\n",
    "    import rich\n",
    "except ImportError:\n",
    "    %pip install rich\n",
    "    import rich\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "{\n",
    "    \"tags\": [\n",
    "        \"hide-cell\"\n",
    "    ]\n",
    "}\n",
    "\n",
    "\n",
    "### Import standard libraries\n",
    "\n",
    "import abc\n",
    "from dataclasses import dataclass\n",
    "import functools\n",
    "import itertools\n",
    "\n",
    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import lax, vmap, jit, grad\n",
    "from jax.scipy.special import logit\n",
    "from jax.nn import softmax\n",
    "from functools import partial\n",
    "from jax.random import PRNGKey, split\n",
    "\n",
    "import inspect\n",
    "import inspect as py_inspect\n",
    "import rich\n",
    "from rich import inspect as r_inspect\n",
    "from rich import print as r_print\n",
    "\n",
    "def print_source(fname):\n",
    "    \"\"\"Print the source code of `fname` to the rich console, verbatim.\n",
    "\n",
    "    Args:\n",
    "        fname: The object whose source is shown (function, method, class, ...).\n",
    "            Despite the name this is the object itself, not a string; it is\n",
    "            passed straight to `inspect.getsource`.\n",
    "\n",
    "    Raises:\n",
    "        TypeError, OSError: Propagated from `inspect.getsource` when the\n",
    "            source of `fname` cannot be retrieved.\n",
    "    \"\"\"\n",
    "    source = py_inspect.getsource(fname)\n",
    "    # Source code is not console markup. With markup enabled, rich parses\n",
    "    # bracketed expressions such as `dist[state]` as style tags and silently\n",
    "    # drops them from the output; `:name:` sequences can turn into emoji.\n",
    "    rich.get_console().print(source, markup=False, emoji=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{math}\n",
    "\n",
    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
    "\n",
    "\\newcommand{\\real}{\\mathbb{R}}\n",
    "\n",
    "% Numbers\n",
"\\newcommand{\\vzero}{\\boldsymbol{0}}\n", "\\newcommand{\\vone}{\\boldsymbol{1}}\n", "\n", "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n", "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n", "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n", "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n", "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n", "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n", "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n", "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n", "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n", "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n", "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n", "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n", "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n", "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n", "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n", "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n", "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n", "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n", "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n", "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n", "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n", "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n", "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n", "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n", "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n", "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n", "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n", "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n", "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n", "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n", "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n", "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n", "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n", "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n", "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n", "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n", "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n", 
"\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n", "\\newcommand{\\vsigmoid}{\\vsigma}\n", "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n", "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n", "\n", "\n", "% Lower Roman (Vectors)\n", "\\newcommand{\\va}{\\mathbf{a}}\n", "\\newcommand{\\vb}{\\mathbf{b}}\n", "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n", "\\newcommand{\\vc}{\\mathbf{c}}\n", "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n", "\\newcommand{\\vd}{\\mathbf{d}}\n", "\\newcommand{\\ve}{\\mathbf{e}}\n", "\\newcommand{\\vf}{\\mathbf{f}}\n", "\\newcommand{\\vg}{\\mathbf{g}}\n", "\\newcommand{\\vh}{\\mathbf{h}}\n", "%\\newcommand{\\myvh}{\\mathbf{h}}\n", "\\newcommand{\\vi}{\\mathbf{i}}\n", "\\newcommand{\\vj}{\\mathbf{j}}\n", "\\newcommand{\\vk}{\\mathbf{k}}\n", "\\newcommand{\\vl}{\\mathbf{l}}\n", "\\newcommand{\\vm}{\\mathbf{m}}\n", "\\newcommand{\\vn}{\\mathbf{n}}\n", "\\newcommand{\\vo}{\\mathbf{o}}\n", "\\newcommand{\\vp}{\\mathbf{p}}\n", "\\newcommand{\\vq}{\\mathbf{q}}\n", "\\newcommand{\\vr}{\\mathbf{r}}\n", "\\newcommand{\\vs}{\\mathbf{s}}\n", "\\newcommand{\\vt}{\\mathbf{t}}\n", "\\newcommand{\\vu}{\\mathbf{u}}\n", "\\newcommand{\\vv}{\\mathbf{v}}\n", "\\newcommand{\\vw}{\\mathbf{w}}\n", "\\newcommand{\\vws}{\\vw_s}\n", "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n", "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n", "\\newcommand{\\vwh}{\\hat{\\vw}}\n", "\\newcommand{\\vx}{\\mathbf{x}}\n", "%\\newcommand{\\vx}{\\mathbf{x}}\n", "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n", "\\newcommand{\\vy}{\\mathbf{y}}\n", "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n", "\\newcommand{\\vz}{\\mathbf{z}}\n", "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n", "\n", "\n", "% Upper Roman (Matrices)\n", "\\newcommand{\\vA}{\\mathbf{A}}\n", "\\newcommand{\\vB}{\\mathbf{B}}\n", "\\newcommand{\\vC}{\\mathbf{C}}\n", "\\newcommand{\\vD}{\\mathbf{D}}\n", "\\newcommand{\\vE}{\\mathbf{E}}\n", "\\newcommand{\\vF}{\\mathbf{F}}\n", "\\newcommand{\\vG}{\\mathbf{G}}\n", 
"\\newcommand{\\vH}{\\mathbf{H}}\n", "\\newcommand{\\vI}{\\mathbf{I}}\n", "\\newcommand{\\vJ}{\\mathbf{J}}\n", "\\newcommand{\\vK}{\\mathbf{K}}\n", "\\newcommand{\\vL}{\\mathbf{L}}\n", "\\newcommand{\\vM}{\\mathbf{M}}\n", "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n", "\\newcommand{\\vN}{\\mathbf{N}}\n", "\\newcommand{\\vO}{\\mathbf{O}}\n", "\\newcommand{\\vP}{\\mathbf{P}}\n", "\\newcommand{\\vQ}{\\mathbf{Q}}\n", "\\newcommand{\\vR}{\\mathbf{R}}\n", "\\newcommand{\\vS}{\\mathbf{S}}\n", "\\newcommand{\\vT}{\\mathbf{T}}\n", "\\newcommand{\\vU}{\\mathbf{U}}\n", "\\newcommand{\\vV}{\\mathbf{V}}\n", "\\newcommand{\\vW}{\\mathbf{W}}\n", "\\newcommand{\\vX}{\\mathbf{X}}\n", "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n", "\\newcommand{\\vXs}{\\vX_{s}}\n", "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n", "\\newcommand{\\vY}{\\mathbf{Y}}\n", "\\newcommand{\\vZ}{\\mathbf{Z}}\n", "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n", "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n", "\n", "\n", "%%%%\n", "\\newcommand{\\hidden}{\\vz}\n", "\\newcommand{\\hid}{\\hidden}\n", "\\newcommand{\\observed}{\\vy}\n", "\\newcommand{\\obs}{\\observed}\n", "\\newcommand{\\inputs}{\\vu}\n", "\\newcommand{\\input}{\\inputs}\n", "\n", "\\newcommand{\\hmmTrans}{\\vA}\n", "\\newcommand{\\hmmObs}{\\vB}\n", "\\newcommand{\\hmmInit}{\\vpi}\n", "\\newcommand{\\hmmhid}{\\hidden}\n", "\\newcommand{\\hmmobs}{\\obs}\n", "\n", "\\newcommand{\\ldsDyn}{\\vA}\n", "\\newcommand{\\ldsObs}{\\vC}\n", "\\newcommand{\\ldsDynIn}{\\vB}\n", "\\newcommand{\\ldsObsIn}{\\vD}\n", "\\newcommand{\\ldsDynNoise}{\\vQ}\n", "\\newcommand{\\ldsObsNoise}{\\vR}\n", "\n", "\\newcommand{\\ssmDynFn}{f}\n", "\\newcommand{\\ssmObsFn}{h}\n", "\n", "\n", "%%%\n", "\\newcommand{\\gauss}{\\mathcal{N}}\n", "\n", "\\newcommand{\\diag}{\\mathrm{diag}}\n", "```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:hmm-intro)=\n", "# Hidden Markov Models\n", "\n", "In this section, we discuss the\n", "hidden Markov model or HMM,\n", "which is a 
state space model in which the hidden states\n", "are discrete, so $\\hmmhid_t \\in \\{1,\\ldots, K\\}$.\n", "The observations may be discrete,\n", "$\\hmmobs_t \\in \\{1,\\ldots, C\\}$,\n", "or continuous,\n", "$\\hmmobs_t \\in \\real^D$,\n", "or some combination,\n", "as we illustrate below.\n", "More details can be found in e.g., \n", "{cite}`Rabiner89,Fraser08,Cappe05`.\n", "For an interactive introduction,\n", "see https://nipunbatra.github.io/hmm/." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:casino)=\n", "## Example: Casino HMM\n", "\n", "To illustrate HMMs with a categorical observation model,\n", "we consider the \"Occasionally dishonest casino\" model from {cite}`Durbin98`.\n", "There are 2 hidden states, representing whether the die being used in the casino is fair or loaded.\n", "Each state defines a distribution over the 6 possible observations.\n", "\n", "The transition model is denoted by\n", "```{math}\n", "p(z_t=j|z_{t-1}=i) = \\hmmTrans_{ij}\n", "```\n", "Here the $i$'th row of $\\vA$ corresponds to the outgoing distribution from state $i$.\n", "This is a row stochastic matrix,\n", "meaning each row sums to one.\n", "We can visualize\n", "the non-zero entries in the transition matrix by creating a state transition diagram,\n", "as shown in \n", "{numref}`fig:casino`.\n", "\n", "```{figure} /figures/casino.png\n", ":scale: 50%\n", ":name: fig:casino\n", "\n", "Illustration of the casino HMM.\n", "```\n", "\n", "The observation model\n", "$p(\\obs_t|\\hidden_t=j)$ has the form\n", "```{math}\n", "p(\\obs_t=k|\\hidden_t=j) = \\hmmObs_{jk} \n", "```\n", "This is represented by the histograms associated with each\n", "state in {numref}`fig:casino`.\n", "\n", "Finally,\n", "the initial state distribution is denoted by\n", "```{math}\n", "p(z_1=j) = \\hmmInit_j\n", "```\n", "\n", "Collectively we denote all the parameters by $\\vtheta=(\\hmmTrans, \\hmmObs, \\hmmInit)$.\n", "\n", "Now let us implement this model in code." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# state transition matrix: A[i, j] = p(z_t=j | z_{t-1}=i), so each row sums to one\n", "A = np.array([\n", " [0.95, 0.05],\n", " [0.10, 0.90]\n", "])\n", "\n", "# observation matrix: B[j, k] = p(obs_t=k | z_t=j); one row per hidden state, one column per die face\n", "B = np.array([\n", " [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die\n", " [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die\n", "])\n", "\n", "# initial state distribution: pi[j] = p(z_1=j)\n", "pi = np.array([0.5, 0.5])\n", "\n", "# number of hidden states and of distinct observations, read off the shape of B\n", "(nstates, nobs) = np.shape(B)\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "import distrax\n", "from distrax import HMM\n", "\n", "\n", "# Wrap each parameter array in a categorical distribution and assemble the HMM\n", "hmm = HMM(trans_dist=distrax.Categorical(probs=A),\n", " init_dist=distrax.Categorical(probs=pi),\n", " obs_dist=distrax.Categorical(probs=B))\n", "\n", "print(hmm)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "Let's sample from the model. We will generate a sequence of latent states, $\\hid_{1:T}$,\n", "which we then convert to a sequence of observations, $\\obs_{1:T}$." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Printing sample observed/latent...\n", "x: 633665342652353616444236412331351246651613325161656366246242\n", "z: 222222211111111111111111111111111111111222111111112222211111\n" ] } ], "source": [ "\n", "\n", "\n", "# Draw one trajectory of hidden states (z) and observations (x) from the casino HMM\n", "seed = 314\n", "n_samples = 300\n", "z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)\n", "\n", "# Shift the samples by one for display (1-based states / die faces) and keep only the first 60 symbols\n", "z_hist_str = \"\".join((np.array(z_hist) + 1).astype(str))[:60]\n", "x_hist_str = \"\".join((np.array(x_hist) + 1).astype(str))[:60]\n", "\n", "print(\"Printing sample observed/latent...\")\n", "print(f\"x: {x_hist_str}\")\n", "print(f\"z: {z_hist_str}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is the source code for the sampling algorithm.\n", "\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
  def sample(self,\n",
       "             *,\n",
       "             seed: chex.PRNGKey,\n",
       "             seq_len: chex.Array) -> Tuple:\n",
       "    \"\"\"Sample from this HMM.\n",
       "\n",
       "    Samples an observation of given length according to this\n",
       "    Hidden Markov Model and gives the sequence of the hidden states\n",
       "    as well as the observation.\n",
       "\n",
       "    Args:\n",
       "      seed: Random key of shape (2,) and dtype uint32.\n",
       "      seq_len: The length of the observation sequence.\n",
       "\n",
       "    Returns:\n",
       "      Tuple of hidden state sequence, and observation sequence.\n",
       "    \"\"\"\n",
       "    rng_key, rng_init = jax.random.split(seed)\n",
       "    initial_state = self._init_dist.sample(seed=rng_init)\n",
       "\n",
       "    def draw_state(prev_state, key):\n",
       "      state = self._trans_dist[prev_state].sample(seed=key)\n",
       "      return state, state\n",
       "\n",
       "    rng_state, rng_obs = jax.random.split(rng_key)\n",
       "    keys = jax.random.split(rng_state, seq_len - 1)\n",
       "    _, states = jax.lax.scan(draw_state, initial_state, keys)\n",
       "    states = jnp.append(initial_state, states)\n",
       "\n",
       "    def draw_obs(state, key):\n",
       "      return self._obs_dist[state].sample(seed=key)\n",
       "\n",
       "    keys = jax.random.split(rng_obs, seq_len)\n",
       "    obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(states, keys)\n",
       "\n",
       "    return states, obs_seq\n",
       "\n",
       "
\n"
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "print_source(hmm.sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us check correctness by computing empirical pairwise statistics.\n",
    "\n",
    "We will compute the number of i->j latent state transitions, and check that it is close to the true\n",
    "A[i,j] transition probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[225. 92.]\n",
      " [ 92. 90.]]\n",
      "[[0.7097792 0.29022083]\n",
      " [0.50549453 0.4945055 ]]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "import collections\n",
    "\n",
    "\n",
    "def compute_counts(state_seq, nstates):\n",
    "    \"\"\"Count the i->j transitions between consecutive entries of `state_seq`.\n",
    "\n",
    "    Args:\n",
    "        state_seq: 1d sequence of integer state indices in `range(nstates)`.\n",
    "        nstates: Number of distinct states; sets the size of the result.\n",
    "\n",
    "    Returns:\n",
    "        Float numpy array of shape `(nstates, nstates)` whose entry `[i, j]`\n",
    "        is the number of times state `i` is immediately followed by state `j`.\n",
    "    \"\"\"\n",
    "    # Plain Python ints make well-behaved Counter keys and array indices.\n",
    "    states = np.asarray(state_seq).tolist()\n",
    "    pair_counts = collections.Counter(zip(states[:-1], states[1:]))\n",
    "    counts = np.zeros((nstates, nstates))\n",
    "    for (src, dst), n_pairs in pair_counts.items():\n",
    "        counts[src, dst] = n_pairs\n",
    "    return counts\n",
    "\n",
    "\n",
    "def normalize(u, axis=0, eps=1e-15):\n",
    "    \"\"\"Rescale `u` so that it sums to one along `axis`.\n",
    "\n",
    "    Args:\n",
    "        u: Array of weights.\n",
    "        axis: Axis along which the entries are summed and normalized.\n",
    "        eps: Non-zero entries below `eps` are raised to `eps` before\n",
    "            normalizing; exact zeros are left at zero.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(normalized, c)`: the rescaled array and the sums along `axis`.\n",
    "        Sums equal to zero are replaced by one, so an all-zero slice stays\n",
    "        zero instead of becoming NaN.\n",
    "    \"\"\"\n",
    "    u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))\n",
    "    # keepdims makes the division broadcast along `axis` for every axis,\n",
    "    # not only for axis=0; the returned sums keep their reduced shape.\n",
    "    c = u.sum(axis=axis, keepdims=True)\n",
    "    c = jnp.where(c == 0, 1, c)\n",
    "    return u / c, jnp.squeeze(c, axis=axis)\n",
    "\n",
    "\n",
    "def normalize_counts(counts):\n",
    "    \"\"\"Turn a matrix of transition counts into a row-stochastic matrix.\"\"\"\n",
    "    return normalize(jnp.asarray(counts), axis=1)[0]\n",
    "\n",
    "\n",
    "init_dist = jnp.array([1.0, 0.0])\n",
    "trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]])\n",
    "obs_mat = jnp.eye(2)\n",
    "\n",
    "hmm = HMM(trans_dist=distrax.Categorical(probs=trans_mat),\n",
    "          init_dist=distrax.Categorical(probs=init_dist),\n",
    "          obs_dist=distrax.Categorical(probs=obs_mat))\n",
    "\n",
    "# Explicit key for this check. The sample used to be seeded from the global\n",
    "# `seed` left behind by an earlier cell (314 on a top-to-bottom run), while\n",
    "# the key created here was never used.\n",
    "rng_key = jax.random.PRNGKey(314)\n",
    "seq_len = 500\n",
    "state_seq, _ = hmm.sample(seed=rng_key, seq_len=seq_len)\n",
    "\n",
    "counts = compute_counts(state_seq, nstates=2)\n",
    "print(counts)\n",
    "\n",
    "trans_mat_empirical = normalize_counts(counts)\n",
    "print(trans_mat_empirical)\n",
    "\n",
    "assert jnp.allclose(trans_mat, trans_mat_empirical,
atol=1e-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our primary goal will be to infer the latent state from the observations,\n", "so we can detect if the casino is being dishonest or not. This will\n", "affect how we choose to gamble our money.\n", "We discuss various ways to perform this inference below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:lillypad)=\n", "## Example: Lillypad HMM\n", "\n", "\n", "If $\\obs_t$ is continuous, it is common to use a Gaussian\n", "observation model:\n", "```{math}\n", "p(\\obs_t|\\hidden_t=j) = \\gauss(\\obs_t|\\vmu_j,\\vSigma_j)\n", "```\n", "This is sometimes called a Gaussian HMM.\n", "\n", "As a simple example, suppose we have an HMM with 3 hidden states,\n", "each of which generates a 2d Gaussian.\n", "We can represent these Gaussian distributions as 2d ellipses,\n", "as we show below.\n", "We call these \"lilly pads\", because of their shape.\n", "We can imagine a frog hopping from one lilly pad to another.\n", "(This analogy is due to the late Sam Roweis.)\n", "The frog will stay on a pad for a while (corresponding to remaining in the same\n", "discrete state $\\hidden_t$), and then jump to a new pad\n", "(corresponding to a transition to a new state).\n", "The data we see are just the 2d points (e.g., water droplets)\n", "coming from near the pad that the frog is currently on.\n", "Thus this model is like a Gaussian mixture model,\n", "in that it generates clusters of observations,\n", "except now there is temporal correlation between the data points.\n", "\n", "Let us now illustrate this model in code.\n", "\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Let us create the model\n", "\n", "initial_probs = jnp.array([0.3, 0.2, 0.5])\n", "\n", "# transition matrix\n", "A = jnp.array([\n", "[0.3, 0.4, 0.3],\n", "[0.1, 0.6, 0.3],\n", "[0.2, 0.3, 0.5]\n", "])\n", 
"\n", "# Observation model\n", "mu_collection = jnp.array([\n", "[0.3, 0.3],\n", "[0.8, 0.5],\n", "[0.3, 0.8]\n", "])\n", "\n", "S1 = jnp.array([[1.1, 0], [0, 0.3]])\n", "S2 = jnp.array([[0.3, -0.5], [-0.5, 1.3]])\n", "S3 = jnp.array([[0.8, 0.4], [0.4, 0.5]])\n", "cov_collection = jnp.array([S1, S2, S3]) / 60\n", "\n", "\n", "import tensorflow_probability as tfp\n", "\n", "if False:\n", " hmm = HMM(trans_dist=distrax.Categorical(probs=A),\n", " init_dist=distrax.Categorical(probs=initial_probs),\n", " obs_dist=distrax.MultivariateNormalFullCovariance(\n", " loc=mu_collection, covariance_matrix=cov_collection))\n", "else:\n", " hmm = HMM(trans_dist=distrax.Categorical(probs=A),\n", " init_dist=distrax.Categorical(probs=initial_probs),\n", " obs_dist=distrax.as_distribution(\n", " tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(loc=mu_collection,\n", " covariance_matrix=cov_collection)))\n", "\n", "print(hmm)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(50,)\n", "(50, 2)\n" ] } ], "source": [ "\n", "n_samples, seed = 50, 10\n", "samples_state, samples_obs = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)\n", "\n", "print(samples_state.shape)\n", "print(samples_obs.shape)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
"
      ]
     },
     "metadata": {
      "filenames": {
       "image/png": "/Users/kpmurphy/github/ssm-book/_build/jupyter_execute/chapters/ssm/hmm_17_0.png"
      },
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# Let's plot the observed data in 2d\n",
    "xmin, xmax = 0, 1\n",
    "ymin, ymax = 0, 1.2\n",
    "colors = [\"tab:green\", \"tab:blue\", \"tab:red\"]\n",
    "\n",
    "def plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax, step=1e-2):\n",
    "    \"\"\"Draw the emission densities of a 2d HMM together with a sampled trajectory.\n",
    "\n",
    "    Args:\n",
    "        hmm: HMM whose `obs_dist` provides `prob` and `mean`, with one\n",
    "            component per hidden state.\n",
    "        samples_obs: Sampled observations; `samples_obs.T` is unpacked into\n",
    "            the x and y coordinates of the trajectory.\n",
    "        samples_state: Hidden state index of each sample, used to pick its colour.\n",
    "        colors: One matplotlib colour per hidden state.\n",
    "        ax: Matplotlib axes to draw on.\n",
    "        xmin, xmax, ymin, ymax: Extent of the grid on which the densities are evaluated.\n",
    "        step: Spacing of that grid.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(ax, color_sample)`, where `color_sample` lists the colour\n",
    "        assigned to each sample.\n",
    "    \"\"\"\n",
    "    obs_dist = hmm.obs_dist\n",
    "    # Convert to plain ints so the state indices can index a Python list.\n",
    "    color_sample = [colors[int(state)] for state in np.asarray(samples_state)]\n",
    "\n",
    "    xs = jnp.arange(xmin, xmax, step)\n",
    "    ys = jnp.arange(ymin, ymax, step)\n",
    "\n",
    "    # density[i, j, k] is the pdf of component k evaluated at (xs[i], ys[j]).\n",
    "    prob_along_y = vmap(lambda x, y: obs_dist.prob(jnp.array([x, y])), in_axes=(None, 0))\n",
    "    density = vmap(prob_along_y, in_axes=(0, None))(xs, ys)\n",
    "\n",
    "    # Build the contour grid from the very axes the density was evaluated on,\n",
    "    # so the two can never disagree in shape or position.\n",
    "    grid_x, grid_y = np.meshgrid(np.asarray(xs), np.asarray(ys), indexing=\"ij\")\n",
    "\n",
    "    means = obs_dist.mean()\n",
    "    for k, color in enumerate(colors):\n",
    "        # A single contour line per component, where its density equals 1.\n",
    "        ax.contour(grid_x, grid_y, density[:, :, k], levels=[1], colors=color, linewidths=3)\n",
    "        ax.text(*(means[k] + 0.13), f\"$k$={k + 1}\", fontsize=13, horizontalalignment=\"right\")\n",
    "\n",
    "    ax.plot(*samples_obs.T, c=\"black\", alpha=0.3, zorder=1)\n",
    "    ax.scatter(*samples_obs.T, c=color_sample, s=30, zorder=2, alpha=0.8)\n",
    "\n",
    "    return ax, color_sample\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "_, color_sample = plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "
" ] }, "metadata": { "filenames": { "image/png": "/Users/kpmurphy/github/ssm-book/_build/jupyter_execute/chapters/ssm/hmm_18_1.png" }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Let's plot the hidden state sequence\n", "\n", "fig, ax = plt.subplots()\n", "# The step line traces the state index over time; the markers reuse the per-sample colours in color_sample\n", "ax.step(range(n_samples), samples_state, where=\"post\", c=\"black\", linewidth=1, alpha=0.3)\n", "ax.scatter(range(n_samples), samples_state, c=color_sample, zorder=3)\n" ] } ], "metadata": { "interpreter": { "hash": "6407c60499271029b671b4ff687c4ed4626355c45fd34c44476827f4be42c4d7" }, "kernelspec": { "display_name": "Python 3.9.2 ('spyder-dev')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }