|
@@ -4,171 +4,8 @@
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "```{math}\n",
|
|
|
- "\n",
|
|
|
- "\\newcommand{\\defeq}{\\triangleq}\n",
|
|
|
- "\\newcommand{\\trans}{{\\mkern-1.5mu\\mathsf{T}}}\n",
|
|
|
- "\\newcommand{\\transpose}[1]{{#1}^{\\trans}}\n",
|
|
|
- "\n",
|
|
|
- "\\newcommand{\\inv}[1]{{#1}^{-1}}\n",
|
|
|
- "\\DeclareMathOperator{\\dotstar}{\\odot}\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
|
|
|
- "\n",
|
|
|
- "\\newcommand{\\real}{\\mathbb{R}}\n",
|
|
|
- "\n",
|
|
|
- "% Numbers\n",
|
|
|
- "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
|
|
|
- "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
|
|
|
- "\n",
|
|
|
- "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
|
|
|
- "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
|
|
|
- "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
|
|
|
- "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
|
|
|
- "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
|
|
|
- "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
|
|
|
- "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
|
|
|
- "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
|
|
|
- "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
|
|
|
- "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
|
|
|
- "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
|
|
|
- "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
|
|
|
- "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
|
|
|
- "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
|
|
|
- "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
|
|
|
- "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
|
|
|
- "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
|
|
|
- "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
|
|
|
- "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
|
|
|
- "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
|
|
|
- "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
|
|
|
- "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
|
|
|
- "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
|
|
|
- "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
|
|
|
- "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
|
|
|
- "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
|
|
|
- "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
|
|
|
- "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
|
|
|
- "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
|
|
|
- "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
|
|
|
- "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
|
|
|
- "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
|
|
|
- "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
|
|
|
- "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
|
|
|
- "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
|
|
|
- "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
|
|
|
- "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
|
|
|
- "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
|
|
|
- "\\newcommand{\\vsigmoid}{\\vsigma}\n",
|
|
|
- "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
|
|
|
- "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "% Lower Roman (Vectors)\n",
|
|
|
- "\\newcommand{\\va}{\\mathbf{a}}\n",
|
|
|
- "\\newcommand{\\vb}{\\mathbf{b}}\n",
|
|
|
- "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
|
|
|
- "\\newcommand{\\vc}{\\mathbf{c}}\n",
|
|
|
- "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
|
|
|
- "\\newcommand{\\vd}{\\mathbf{d}}\n",
|
|
|
- "\\newcommand{\\ve}{\\mathbf{e}}\n",
|
|
|
- "\\newcommand{\\vf}{\\mathbf{f}}\n",
|
|
|
- "\\newcommand{\\vg}{\\mathbf{g}}\n",
|
|
|
- "\\newcommand{\\vh}{\\mathbf{h}}\n",
|
|
|
- "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
|
|
|
- "\\newcommand{\\vi}{\\mathbf{i}}\n",
|
|
|
- "\\newcommand{\\vj}{\\mathbf{j}}\n",
|
|
|
- "\\newcommand{\\vk}{\\mathbf{k}}\n",
|
|
|
- "\\newcommand{\\vl}{\\mathbf{l}}\n",
|
|
|
- "\\newcommand{\\vm}{\\mathbf{m}}\n",
|
|
|
- "\\newcommand{\\vn}{\\mathbf{n}}\n",
|
|
|
- "\\newcommand{\\vo}{\\mathbf{o}}\n",
|
|
|
- "\\newcommand{\\vp}{\\mathbf{p}}\n",
|
|
|
- "\\newcommand{\\vq}{\\mathbf{q}}\n",
|
|
|
- "\\newcommand{\\vr}{\\mathbf{r}}\n",
|
|
|
- "\\newcommand{\\vs}{\\mathbf{s}}\n",
|
|
|
- "\\newcommand{\\vt}{\\mathbf{t}}\n",
|
|
|
- "\\newcommand{\\vu}{\\mathbf{u}}\n",
|
|
|
- "\\newcommand{\\vv}{\\mathbf{v}}\n",
|
|
|
- "\\newcommand{\\vw}{\\mathbf{w}}\n",
|
|
|
- "\\newcommand{\\vws}{\\vw_s}\n",
|
|
|
- "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
|
|
|
- "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
|
|
|
- "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
|
|
|
- "\\newcommand{\\vx}{\\mathbf{x}}\n",
|
|
|
- "%\\newcommand{\\vx}{\\mathbf{x}}\n",
|
|
|
- "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
|
|
|
- "\\newcommand{\\vy}{\\mathbf{y}}\n",
|
|
|
- "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
|
|
|
- "\\newcommand{\\vz}{\\mathbf{z}}\n",
|
|
|
- "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "% Upper Roman (Matrices)\n",
|
|
|
- "\\newcommand{\\vA}{\\mathbf{A}}\n",
|
|
|
- "\\newcommand{\\vB}{\\mathbf{B}}\n",
|
|
|
- "\\newcommand{\\vC}{\\mathbf{C}}\n",
|
|
|
- "\\newcommand{\\vD}{\\mathbf{D}}\n",
|
|
|
- "\\newcommand{\\vE}{\\mathbf{E}}\n",
|
|
|
- "\\newcommand{\\vF}{\\mathbf{F}}\n",
|
|
|
- "\\newcommand{\\vG}{\\mathbf{G}}\n",
|
|
|
- "\\newcommand{\\vH}{\\mathbf{H}}\n",
|
|
|
- "\\newcommand{\\vI}{\\mathbf{I}}\n",
|
|
|
- "\\newcommand{\\vJ}{\\mathbf{J}}\n",
|
|
|
- "\\newcommand{\\vK}{\\mathbf{K}}\n",
|
|
|
- "\\newcommand{\\vL}{\\mathbf{L}}\n",
|
|
|
- "\\newcommand{\\vM}{\\mathbf{M}}\n",
|
|
|
- "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
|
|
|
- "\\newcommand{\\vN}{\\mathbf{N}}\n",
|
|
|
- "\\newcommand{\\vO}{\\mathbf{O}}\n",
|
|
|
- "\\newcommand{\\vP}{\\mathbf{P}}\n",
|
|
|
- "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
|
|
|
- "\\newcommand{\\vR}{\\mathbf{R}}\n",
|
|
|
- "\\newcommand{\\vS}{\\mathbf{S}}\n",
|
|
|
- "\\newcommand{\\vT}{\\mathbf{T}}\n",
|
|
|
- "\\newcommand{\\vU}{\\mathbf{U}}\n",
|
|
|
- "\\newcommand{\\vV}{\\mathbf{V}}\n",
|
|
|
- "\\newcommand{\\vW}{\\mathbf{W}}\n",
|
|
|
- "\\newcommand{\\vX}{\\mathbf{X}}\n",
|
|
|
- "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
|
|
|
- "\\newcommand{\\vXs}{\\vX_{s}}\n",
|
|
|
- "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
|
|
|
- "\\newcommand{\\vY}{\\mathbf{Y}}\n",
|
|
|
- "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
|
|
|
- "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
|
|
|
- "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "%%%%\n",
|
|
|
- "\\newcommand{\\hidden}{\\vz}\n",
|
|
|
- "\\newcommand{\\hid}{\\hidden}\n",
|
|
|
- "\\newcommand{\\observed}{\\vy}\n",
|
|
|
- "\\newcommand{\\obs}{\\observed}\n",
|
|
|
- "\\newcommand{\\inputs}{\\vu}\n",
|
|
|
- "\\newcommand{\\input}{\\inputs}\n",
|
|
|
- "\n",
|
|
|
- "\\newcommand{\\hmmTrans}{\\vA}\n",
|
|
|
- "\\newcommand{\\hmmObs}{\\vB}\n",
|
|
|
- "\\newcommand{\\hmmInit}{\\vpi}\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "\\newcommand{\\ldsDyn}{\\vA}\n",
|
|
|
- "\\newcommand{\\ldsObs}{\\vC}\n",
|
|
|
- "\\newcommand{\\ldsDynIn}{\\vB}\n",
|
|
|
- "\\newcommand{\\ldsObsIn}{\\vD}\n",
|
|
|
- "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
|
|
|
- "\\newcommand{\\ldsObsNoise}{\\vR}\n",
|
|
|
- "\n",
|
|
|
- "\\newcommand{\\ssmDynFn}{f}\n",
|
|
|
- "\\newcommand{\\ssmObsFn}{h}\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "%%%\n",
|
|
|
- "\\newcommand{\\gauss}{\\mathcal{N}}\n",
|
|
|
- "\n",
|
|
|
- "\\newcommand{\\diag}{\\mathrm{diag}}\n",
|
|
|
- "```\n"
|
|
|
+ "(sec:forwards)=\n",
|
|
|
+ "# HMM filtering (forwards algorithm)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -177,141 +14,85 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "# meta-data does not work yet in VScode\n",
|
|
|
- "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
|
|
|
- "\n",
|
|
|
- "{\n",
|
|
|
- " \"tags\": [\n",
|
|
|
- " \"hide-cell\"\n",
|
|
|
- " ]\n",
|
|
|
- "}\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "### Install necessary libraries\n",
|
|
|
- "\n",
|
|
|
- "try:\n",
|
|
|
- " import jax\n",
|
|
|
- "except:\n",
|
|
|
- " # For cuda version, see https://github.com/google/jax#installation\n",
|
|
|
- " %pip install --upgrade \"jax[cpu]\" \n",
|
|
|
- " import jax\n",
|
|
|
- "\n",
|
|
|
- "try:\n",
|
|
|
- " import distrax\n",
|
|
|
- "except:\n",
|
|
|
- " %pip install --upgrade distrax\n",
|
|
|
- " import distrax\n",
|
|
|
- "\n",
|
|
|
- "try:\n",
|
|
|
- " import jsl\n",
|
|
|
- "except:\n",
|
|
|
- " %pip install git+https://github.com/probml/jsl\n",
|
|
|
- " import jsl\n",
|
|
|
- "\n",
|
|
|
- "#try:\n",
|
|
|
- "# import ssm_jax\n",
|
|
|
- "##except:\n",
|
|
|
- "# %pip install git+https://github.com/probml/ssm-jax\n",
|
|
|
- "# import ssm_jax\n",
|
|
|
- "\n",
|
|
|
- "try:\n",
|
|
|
- " import rich\n",
|
|
|
- "except:\n",
|
|
|
- " %pip install rich\n",
|
|
|
- " import rich\n",
|
|
|
- "\n",
|
|
|
- "\n"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 4,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "{\n",
|
|
|
- " \"tags\": [\n",
|
|
|
- " \"hide-cell\"\n",
|
|
|
- " ]\n",
|
|
|
- "}\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
"### Import standard libraries\n",
|
|
|
"\n",
|
|
|
"import abc\n",
|
|
|
"from dataclasses import dataclass\n",
|
|
|
"import functools\n",
|
|
|
+ "from functools import partial\n",
|
|
|
"import itertools\n",
|
|
|
- "\n",
|
|
|
- "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
|
|
|
- "\n",
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
"import numpy as np\n",
|
|
|
- "\n",
|
|
|
+ "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
|
|
|
"\n",
|
|
|
"import jax\n",
|
|
|
"import jax.numpy as jnp\n",
|
|
|
"from jax import lax, vmap, jit, grad\n",
|
|
|
- "from jax.scipy.special import logit\n",
|
|
|
- "from jax.nn import softmax\n",
|
|
|
- "from functools import partial\n",
|
|
|
- "from jax.random import PRNGKey, split\n",
|
|
|
+ "#from jax.scipy.special import logit\n",
|
|
|
+ "#from jax.nn import softmax\n",
|
|
|
+ "import jax.random as jr\n",
|
|
|
"\n",
|
|
|
- "import inspect\n",
|
|
|
- "import inspect as py_inspect\n",
|
|
|
- "import rich\n",
|
|
|
- "from rich import inspect as r_inspect\n",
|
|
|
- "from rich import print as r_print\n",
|
|
|
"\n",
|
|
|
- "def print_source(fname):\n",
|
|
|
- " r_print(py_inspect.getsource(fname))"
|
|
|
+ "\n",
|
|
|
+ "import distrax\n",
|
|
|
+ "import optax\n",
|
|
|
+ "\n",
|
|
|
+ "import jsl\n",
|
|
|
+ "import ssm_jax"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "(sec:forwards)=\n",
|
|
|
- "# HMM filtering (forwards algorithm)\n",
|
|
|
+ "\n",
|
|
|
+ "## Introduction\n",
|
|
|
"\n",
|
|
|
"\n",
|
|
|
- "The **Bayes filter** is an algorithm for recursively computing\n",
|
|
|
+ "The $\\keyword{Bayes filter}$ is an algorithm for recursively computing\n",
|
|
|
"the belief state\n",
|
|
|
"$p(\\hidden_t|\\obs_{1:t})$ given\n",
|
|
|
"the prior belief from the previous step,\n",
|
|
|
"$p(\\hidden_{t-1}|\\obs_{1:t-1})$,\n",
|
|
|
"the new observation $\\obs_t$,\n",
|
|
|
"and the model.\n",
|
|
|
- "This can be done using **sequential Bayesian updating**.\n",
|
|
|
+ "This can be done using $\\keyword{sequential Bayesian updating}$.\n",
|
|
|
"For a dynamical model, this reduces to the\n",
|
|
|
- "**predict-update** cycle described below.\n",
|
|
|
+ "$\\keyword{predict-update}$ cycle described below.\n",
|
|
|
"\n",
|
|
|
- "\n",
|
|
|
- "The **prediction step** is just the **Chapman-Kolmogorov equation**:\n",
|
|
|
- "```{math}\n",
|
|
|
+ "The $\\keyword{prediction step}$ is just the $\\keyword{Chapman-Kolmogorov equation}$:\n",
|
|
|
+ "\\begin{align}\n",
|
|
|
"p(\\hidden_t|\\obs_{1:t-1})\n",
|
|
|
"= \\int p(\\hidden_t|\\hidden_{t-1}) p(\\hidden_{t-1}|\\obs_{1:t-1}) d\\hidden_{t-1}\n",
|
|
|
- "```\n",
|
|
|
+ "\\end{align}\n",
|
|
|
"The prediction step computes\n",
|
|
|
- "the one-step-ahead predictive distribution\n",
|
|
|
- "for the latent state, which updates\n",
|
|
|
- "the posterior from the previous time step into the prior\n",
|
|
|
+ "the $\\keyword{one-step-ahead predictive distribution}$\n",
|
|
|
+ "for the latent state, which converts\n",
|
|
|
+ "the posterior from the previous time step to become the prior\n",
|
|
|
"for the current step.\n",
|
|
|
"\n",
|
|
|
"\n",
|
|
|
- "The **update step**\n",
|
|
|
+ "The $\\keyword{update step}$\n",
|
|
|
"is just Bayes rule:\n",
|
|
|
- "```{math}\n",
|
|
|
+ "\\begin{align}\n",
|
|
|
"p(\\hidden_t|\\obs_{1:t}) = \\frac{1}{Z_t}\n",
|
|
|
"p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1})\n",
|
|
|
- "```\n",
|
|
|
+ "\\end{align}\n",
|
|
|
"where the normalization constant is\n",
|
|
|
- "```{math}\n",
|
|
|
+ "\\begin{align}\n",
|
|
|
"Z_t = \\int p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1}) d\\hidden_{t}\n",
|
|
|
"= p(\\obs_t|\\obs_{1:t-1})\n",
|
|
|
- "```\n",
|
|
|
+ "\\end{align}\n",
|
|
|
"\n",
|
|
|
+ "Note that we can derive the log marginal likelihood from these normalization constants\n",
|
|
|
+ "as follows:\n",
|
|
|
+ "```{math}\n",
|
|
|
+ ":label: eqn:logZ\n",
|
|
|
"\n",
|
|
|
+ "\\log p(\\obs_{1:T})\n",
|
|
|
+ "= \\sum_{t=1}^{T} \\log p(\\obs_t|\\obs_{1:t-1})\n",
|
|
|
+ "= \\sum_{t=1}^{T} \\log Z_t\n",
|
|
|
+ "```\n",
|
|
|
"\n"
|
|
|
]
|
|
|
},
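+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make the predict-update cycle concrete, here is a minimal numpy sketch for a\n",
+ "two-state model, using made-up transition and likelihood values; it also accumulates\n",
+ "the per-step log normalizers as in the equation above. The full HMM implementation\n",
+ "is developed later in this notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: run the predict-update cycle for a few steps (illustrative values only);\n",
+ "# summing the per-step log normalizers gives the log marginal likelihood.\n",
+ "A = np.array([[0.95, 0.05], [0.10, 0.90]])  # p(z_t | z_{t-1})\n",
+ "liks = np.array([[1 / 6, 1 / 2], [1 / 6, 1 / 10], [1 / 6, 1 / 2]])  # p(y_t | z_t) per step\n",
+ "belief = np.array([0.5, 0.5])\n",
+ "log_marginal = 0.0\n",
+ "for lik in liks:\n",
+ "    predicted = A.T @ belief  # prediction step (Chapman-Kolmogorov)\n",
+ "    Zt = (lik * predicted).sum()  # normalization constant Z_t = p(y_t | y_{1:t-1})\n",
+ "    belief = lik * predicted / Zt  # update step (Bayes rule)\n",
+ "    log_marginal += np.log(Zt)  # accumulate log Z_t\n",
+ "print(belief, log_marginal)"
+ ]
+ },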
|
|
@@ -323,10 +104,11 @@
|
|
|
"When the latent states $\\hidden_t$ are discrete, as in HMM,\n",
|
|
|
"the above integrals become sums.\n",
|
|
|
"In particular, suppose we define\n",
|
|
|
- "the belief state as $\\alpha_t(j) \\defeq p(\\hidden_t=j|\\obs_{1:t})$,\n",
|
|
|
- "the local evidence as $\\lambda_t(j) \\defeq p(\\obs_t|\\hidden_t=j)$,\n",
|
|
|
- "and the transition matrix\n",
|
|
|
- "$A(i,j) = p(\\hidden_t=j|\\hidden_{t-1}=i)$.\n",
|
|
|
+ "the $\\keyword{belief state}$ as $\\alpha_t(j) \\defeq p(\\hidden_t=j|\\obs_{1:t})$,\n",
|
|
|
+ "the $\\keyword{local evidence}$ (or $\\keyword{local likelihood}$)\n",
|
|
|
+ "as $\\lambda_t(j) \\defeq p(\\obs_t|\\hidden_t=j)$,\n",
|
|
|
+ "and the transition matrix as\n",
|
|
|
+ "$\\hmmTrans(i,j) = p(\\hidden_t=j|\\hidden_{t-1}=i)$.\n",
|
|
|
"Then the predict step becomes\n",
|
|
|
"```{math}\n",
|
|
|
":label: eqn:predictiveHMM\n",
|
|
@@ -338,7 +120,7 @@
|
|
|
":label: eqn:fwdsEqn\n",
|
|
|
"\\alpha_t(j)\n",
|
|
|
"= \\frac{1}{Z_t} \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
|
|
|
- "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) A(i,j) \\right]\n",
|
|
|
+ "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) \\hmmTrans(i,j) \\right]\n",
|
|
|
"```\n",
|
|
|
"where\n",
|
|
|
"the normalization constant for each time step is given by\n",
|
|
@@ -350,20 +132,33 @@
|
|
|
"&= \\sum_{j=1}^K \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
|
|
|
"\\end{align}\n",
|
|
|
"```\n",
|
|
|
+ "\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
"\n",
|
|
|
"Since all the quantities are finite length vectors and matrices,\n",
|
|
|
- "we can write the update equation\n",
|
|
|
- "in matrix-vector notation as follows:\n",
|
|
|
+ "we can implement the whole procedure using matrix vector multoplication:\n",
|
|
|
"```{math}\n",
|
|
|
+ ":label: eqn:fwdsAlgoMatrixForm\n",
|
|
|
"\\valpha_t =\\text{normalize}\\left(\n",
|
|
|
- "\\vlambda_t \\dotstar (\\vA^{\\trans} \\valpha_{t-1}) \\right)\n",
|
|
|
- "\\label{eqn:fwdsAlgoMatrixForm}\n",
|
|
|
+ "\\vlambda_t \\dotstar (\\hmmTrans^{\\trans} \\valpha_{t-1}) \\right)\n",
|
|
|
"```\n",
|
|
|
"where $\\dotstar$ represents\n",
|
|
|
"elementwise vector multiplication,\n",
|
|
|
- "and the $\\text{normalize}$ function just ensures its argument sums to one.\n",
|
|
|
+ "and the $\\text{normalize}$ function just ensures its argument sums to one.\n"
|
|
|
+ ]
|
|
|
+ },
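+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a sanity check, the matrix-vector update above is a one-liner in numpy.\n",
+ "The transition matrix, observation matrix, and observation sequence below are\n",
+ "made-up values, used only to illustrate the shape of the computation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Iterate alpha_t = normalize(lambda_t * (A^T alpha_{t-1})) over a toy sequence.\n",
+ "A = np.array([[0.9, 0.1], [0.2, 0.8]])  # transition matrix\n",
+ "B = np.array([[0.5, 0.5], [0.1, 0.9]])  # B[j, k] = p(y_t = k | z_t = j)\n",
+ "alpha = np.array([0.5, 0.5])  # initial belief state\n",
+ "for y in [0, 1, 1, 0]:\n",
+ "    u = B[:, y] * (A.T @ alpha)  # elementwise product with the predicted belief\n",
+ "    alpha = u / u.sum()  # normalize so the belief sums to one\n",
+ "print(alpha)"
+ ]
+ },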
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Example\n",
|
|
|
"\n",
|
|
|
- "In {ref}(sec:casino-inference)\n",
|
|
|
+ "In {ref}`sec:casino-inference`\n",
|
|
|
"we illustrate\n",
|
|
|
"filtering for the casino HMM,\n",
|
|
|
"applied to a random sequence $\\obs_{1:T}$ of length $T=300$.\n",
|
|
@@ -371,156 +166,209 @@
|
|
|
"based on the evidence seen so far.\n",
|
|
|
"The gray bars indicate time intervals during which the generative\n",
|
|
|
"process actually switched to the loaded dice.\n",
|
|
|
- "We see that the probability generally increases in the right places.\n"
|
|
|
+ "We see that the probability generally increases in the right places."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Normalization constants\n",
|
|
|
+ "\n",
|
|
|
+ "In most publications on HMMs,\n",
|
|
|
+ "such as {cite}`Rabiner89`,\n",
|
|
|
+ "the forwards message is defined\n",
|
|
|
+ "as the following unnormalized joint probability:\n",
|
|
|
+ "```{math}\n",
|
|
|
+ "\\alpha'_t(j) = p(\\hidden_t=j,\\obs_{1:t}) \n",
|
|
|
+ "= \\lambda_t(j) \\left[\\sum_i \\alpha'_{t-1}(i) A(i,j) \\right]\n",
|
|
|
+ "```\n",
|
|
|
+ "In this book we define the forwards message as the normalized\n",
|
|
|
+ "conditional probability\n",
|
|
|
+ "```{math}\n",
|
|
|
+ "\\alpha_t(j) = p(\\hidden_t=j|\\obs_{1:t}) \n",
|
|
|
+ "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) A(i,j) \\right]\n",
|
|
|
+ "```\n",
|
|
|
+ "where $Z_t = p(\\obs_t|\\obs_{1:t-1})$.\n",
|
|
|
+ "\n",
|
|
|
+ "The \"traditional\" unnormalized form has several problems.\n",
|
|
|
+ "First, it rapidly suffers from numerical underflow,\n",
|
|
|
+ "since the probability of\n",
|
|
|
+ "the joint event that $(\\hidden_t=j,\\obs_{1:t})$\n",
|
|
|
+ "is vanishingly small. \n",
|
|
|
+ "To see why, suppose the observations are independent of the states.\n",
|
|
|
+ "In this case, the unnormalized joint has the form\n",
|
|
|
+ "\\begin{align}\n",
|
|
|
+ "p(\\hidden_t=j,\\obs_{1:t}) = p(\\hidden_t=j)\\prod_{i=1}^t p(\\obs_i)\n",
|
|
|
+ "\\end{align}\n",
|
|
|
+ "which becomes exponentially small with $t$, because we multiply\n",
|
|
|
+ "many probabilities which are less than one.\n",
|
|
|
+ "Second, the unnormalized probability is less interpretable,\n",
|
|
|
+ "since it is a joint distribution over states and observations,\n",
|
|
|
+ "rather than a conditional probability of states given observations.\n",
|
|
|
+ "Third, the unnormalized joint form is harder to approximate\n",
|
|
|
+ "than the normalized form.\n",
|
|
|
+ "Of course,\n",
|
|
|
+ "the two definitions only differ by a\n",
|
|
|
+ "multiplicative constant\n",
|
|
|
+ "{cite}`Devijver85`,\n",
|
|
|
+ "so the algorithmic difference is just\n",
|
|
|
+ "one line of code (namely the presence or absence of a call to the `normalize` function).\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "\n"
|
|
|
]
|
|
|
},
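+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The underflow problem is easy to demonstrate numerically. In the sketch below we\n",
+ "multiply 400 made-up likelihood values of 0.1 each: the product underflows to zero\n",
+ "in float64, while the equivalent sum of logs is perfectly well behaved."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "probs = np.full(400, 0.1)  # made-up per-step likelihoods\n",
+ "print(np.prod(probs))  # 0.0, since 1e-400 underflows float64\n",
+ "print(np.log(probs).sum())  # about -921.0, no problem in log space"
+ ]
+ },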
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "Here is a JAX implementation of the forwards algorithm."
|
|
|
+ "## Naive implementation\n",
|
|
|
+ "\n",
|
|
|
+ "Below we give a simple numpy implementation of the forwards algorithm.\n",
|
|
|
+ "We assume the HMM uses categorical observations, for simplicity.\n",
|
|
|
+ "\n"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 5,
|
|
|
+ "execution_count": null,
|
|
|
"metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/html": [
|
|
|
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">@jit\n",
|
|
|
- "def <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">hmm_forwards_jax</span><span style=\"font-weight: bold\">(</span>params, obs_seq, <span style=\"color: #808000; text-decoration-color: #808000\">length</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span><span style=\"font-weight: bold\">)</span>:\n",
|
|
|
- " <span style=\"color: #008000; text-decoration-color: #008000\">''</span>'\n",
|
|
|
- " Calculates a belief state\n",
|
|
|
- "\n",
|
|
|
- " Parameters\n",
|
|
|
- " ----------\n",
|
|
|
- " params : HMMJax\n",
|
|
|
- " Hidden Markov Model\n",
|
|
|
- "\n",
|
|
|
- " obs_seq: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">array</span><span style=\"font-weight: bold\">(</span>seq_len<span style=\"font-weight: bold\">)</span>\n",
|
|
|
- " History of observable events\n",
|
|
|
- "\n",
|
|
|
- " Returns\n",
|
|
|
- " -------\n",
|
|
|
- " * float\n",
|
|
|
- " The loglikelihood giving <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">log</span><span style=\"font-weight: bold\">(</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">p</span><span style=\"font-weight: bold\">(</span>x|model<span style=\"font-weight: bold\">))</span>\n",
|
|
|
- "\n",
|
|
|
- " * <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">array</span><span style=\"font-weight: bold\">(</span>seq_len, n_hidden<span style=\"font-weight: bold\">)</span> :\n",
|
|
|
- " All alpha values found for each sample\n",
|
|
|
- " <span style=\"color: #008000; text-decoration-color: #008000\">''</span>'\n",
|
|
|
- " seq_len = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">len</span><span style=\"font-weight: bold\">(</span>obs_seq<span style=\"font-weight: bold\">)</span>\n",
|
|
|
- "\n",
|
|
|
- " if length is <span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>:\n",
|
|
|
- " length = seq_len\n",
|
|
|
- "\n",
|
|
|
- " trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
|
|
|
- "\n",
|
|
|
- " trans_mat = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>trans_mat<span style=\"font-weight: bold\">)</span>\n",
|
|
|
- " obs_mat = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>obs_mat<span style=\"font-weight: bold\">)</span>\n",
|
|
|
- " init_dist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>init_dist<span style=\"font-weight: bold\">)</span>\n",
|
|
|
- "\n",
|
|
|
- " n_states, n_obs = obs_mat.shape\n",
|
|
|
- "\n",
|
|
|
- " def <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">scan_fn</span><span style=\"font-weight: bold\">(</span>carry, t<span style=\"font-weight: bold\">)</span>:\n",
|
|
|
- " <span style=\"font-weight: bold\">(</span>alpha_prev, log_ll_prev<span style=\"font-weight: bold\">)</span> = carry\n",
|
|
|
- " alpha_n = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.where</span><span style=\"font-weight: bold\">(</span>t < length,\n",
|
|
|
- " obs_mat<span style=\"font-weight: bold\">[</span>:, obs_seq<span style=\"font-weight: bold\">]</span> * <span style=\"font-weight: bold\">(</span>alpha_prev<span style=\"font-weight: bold\">[</span>:, <span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span><span style=\"font-weight: bold\">]</span> * \n",
|
|
|
- "trans_mat<span style=\"font-weight: bold\">)</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">.sum</span><span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">axis</span>=<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"font-weight: bold\">)</span>,\n",
|
|
|
- " <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.zeros_like</span><span style=\"font-weight: bold\">(</span>alpha_prev<span style=\"font-weight: bold\">))</span>\n",
|
|
|
- "\n",
|
|
|
- " alpha_n, cn = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">normalize</span><span style=\"font-weight: bold\">(</span>alpha_n<span style=\"font-weight: bold\">)</span>\n",
|
|
|
- " carry = <span style=\"font-weight: bold\">(</span>alpha_n, <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.log</span><span style=\"font-weight: bold\">(</span>cn<span style=\"font-weight: bold\">)</span> + log_ll_prev<span style=\"font-weight: bold\">)</span>\n",
|
|
|
- "\n",
|
|
|
- " return carry, alpha_n\n",
|
|
|
- "\n",
|
|
|
- " # initial belief state\n",
|
|
|
- " alpha_0, c0 = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">normalize</span><span style=\"font-weight: bold\">(</span>init_dist * obs_mat<span style=\"font-weight: bold\">[</span>:, obs_seq<span style=\"font-weight: bold\">[</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"font-weight: bold\">]])</span>\n",
|
|
|
- "\n",
|
|
|
- " # setup scan loop\n",
|
|
|
- " init_state = <span style=\"font-weight: bold\">(</span>alpha_0, <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.log</span><span style=\"font-weight: bold\">(</span>c0<span style=\"font-weight: bold\">))</span>\n",
|
|
|
- " ts = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.arange</span><span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, seq_len<span style=\"font-weight: bold\">)</span>\n",
|
|
|
- " carry, alpha_hist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">lax.scan</span><span style=\"font-weight: bold\">(</span>scan_fn, init_state, ts<span style=\"font-weight: bold\">)</span>\n",
|
|
|
- "\n",
|
|
|
- " # post-process\n",
|
|
|
- " alpha_hist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.vstack</span><span style=\"font-weight: bold\">()</span>\n",
|
|
|
- " <span style=\"font-weight: bold\">(</span>alpha_final, log_ll<span style=\"font-weight: bold\">)</span> = carry\n",
|
|
|
- " return log_ll, alpha_hist\n",
|
|
|
- "\n",
|
|
|
- "</pre>\n"
|
|
|
- ],
|
|
|
- "text/plain": [
|
|
|
- "@jit\n",
|
|
|
- "def \u001b[1;35mhmm_forwards_jax\u001b[0m\u001b[1m(\u001b[0mparams, obs_seq, \u001b[33mlength\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m:\n",
|
|
|
- " \u001b[32m''\u001b[0m'\n",
|
|
|
- " Calculates a belief state\n",
|
|
|
- "\n",
|
|
|
- " Parameters\n",
|
|
|
- " ----------\n",
|
|
|
- " params : HMMJax\n",
|
|
|
- " Hidden Markov Model\n",
|
|
|
- "\n",
|
|
|
- " obs_seq: \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0mseq_len\u001b[1m)\u001b[0m\n",
|
|
|
- " History of observable events\n",
|
|
|
- "\n",
|
|
|
- " Returns\n",
|
|
|
- " -------\n",
|
|
|
- " * float\n",
|
|
|
- " The loglikelihood giving \u001b[1;35mlog\u001b[0m\u001b[1m(\u001b[0m\u001b[1;35mp\u001b[0m\u001b[1m(\u001b[0mx|model\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
|
|
|
- "\n",
|
|
|
- " * \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0mseq_len, n_hidden\u001b[1m)\u001b[0m :\n",
|
|
|
- " All alpha values found for each sample\n",
|
|
|
- " \u001b[32m''\u001b[0m'\n",
|
|
|
- " seq_len = \u001b[1;35mlen\u001b[0m\u001b[1m(\u001b[0mobs_seq\u001b[1m)\u001b[0m\n",
|
|
|
- "\n",
|
|
|
- " if length is \u001b[3;35mNone\u001b[0m:\n",
|
|
|
- " length = seq_len\n",
|
|
|
- "\n",
|
|
|
- " trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
|
|
|
- "\n",
|
|
|
- " trans_mat = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0mtrans_mat\u001b[1m)\u001b[0m\n",
|
|
|
- " obs_mat = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0mobs_mat\u001b[1m)\u001b[0m\n",
|
|
|
- " init_dist = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0minit_dist\u001b[1m)\u001b[0m\n",
|
|
|
- "\n",
|
|
|
- " n_states, n_obs = obs_mat.shape\n",
|
|
|
- "\n",
|
|
|
- " def \u001b[1;35mscan_fn\u001b[0m\u001b[1m(\u001b[0mcarry, t\u001b[1m)\u001b[0m:\n",
|
|
|
- " \u001b[1m(\u001b[0malpha_prev, log_ll_prev\u001b[1m)\u001b[0m = carry\n",
|
|
|
- " alpha_n = \u001b[1;35mjnp.where\u001b[0m\u001b[1m(\u001b[0mt < length,\n",
|
|
|
- " obs_mat\u001b[1m[\u001b[0m:, obs_seq\u001b[1m]\u001b[0m * \u001b[1m(\u001b[0malpha_prev\u001b[1m[\u001b[0m:, \u001b[3;35mNone\u001b[0m\u001b[1m]\u001b[0m * \n",
|
|
|
- "trans_mat\u001b[1m)\u001b[0m\u001b[1;35m.sum\u001b[0m\u001b[1m(\u001b[0m\u001b[33maxis\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m,\n",
|
|
|
- " \u001b[1;35mjnp.zeros_like\u001b[0m\u001b[1m(\u001b[0malpha_prev\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
|
|
|
- "\n",
|
|
|
- " alpha_n, cn = \u001b[1;35mnormalize\u001b[0m\u001b[1m(\u001b[0malpha_n\u001b[1m)\u001b[0m\n",
|
|
|
- " carry = \u001b[1m(\u001b[0malpha_n, \u001b[1;35mjnp.log\u001b[0m\u001b[1m(\u001b[0mcn\u001b[1m)\u001b[0m + log_ll_prev\u001b[1m)\u001b[0m\n",
|
|
|
- "\n",
|
|
|
- " return carry, alpha_n\n",
|
|
|
- "\n",
|
|
|
- " # initial belief state\n",
|
|
|
- " alpha_0, c0 = \u001b[1;35mnormalize\u001b[0m\u001b[1m(\u001b[0minit_dist * obs_mat\u001b[1m[\u001b[0m:, obs_seq\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n",
|
|
|
- "\n",
|
|
|
- " # setup scan loop\n",
|
|
|
- " init_state = \u001b[1m(\u001b[0malpha_0, \u001b[1;35mjnp.log\u001b[0m\u001b[1m(\u001b[0mc0\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
|
|
|
- " ts = \u001b[1;35mjnp.arange\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m, seq_len\u001b[1m)\u001b[0m\n",
|
|
|
- " carry, alpha_hist = \u001b[1;35mlax.scan\u001b[0m\u001b[1m(\u001b[0mscan_fn, init_state, ts\u001b[1m)\u001b[0m\n",
|
|
|
- "\n",
|
|
|
- " # post-process\n",
|
|
|
- " alpha_hist = \u001b[1;35mjnp.vstack\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n",
|
|
|
- " \u001b[1m(\u001b[0malpha_final, log_ll\u001b[1m)\u001b[0m = carry\n",
|
|
|
- " return log_ll, alpha_hist\n",
|
|
|
- "\n"
|
|
|
- ]
|
|
|
- },
|
|
|
- "metadata": {},
|
|
|
- "output_type": "display_data"
|
|
|
- }
|
|
|
- ],
|
|
|
+ "outputs": [],
|
|
|
"source": [
|
|
|
- "import jsl.hmm.hmm_lib as hmm_lib\n",
|
|
|
- "print_source(hmm_lib.hmm_forwards_jax)\n",
|
|
|
- "#https://github.com/probml/JSL/blob/main/jsl/hmm/hmm_lib.py#L189\n",
|
|
|
- "\n"
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def normalize_np(u, axis=0, eps=1e-15):\n",
|
|
|
+ " u = np.where(u == 0, 0, np.where(u < eps, eps, u))\n",
|
|
|
+ " c = u.sum(axis=axis)\n",
|
|
|
+ " c = np.where(c == 0, 1, c)\n",
|
|
|
+ " return u / c, c\n",
|
|
|
+ "\n",
|
|
|
+ "def hmm_forwards_np(trans_mat, obs_mat, init_dist, obs_seq):\n",
|
|
|
+ " n_states, n_obs = obs_mat.shape\n",
|
|
|
+ " seq_len = len(obs_seq)\n",
|
|
|
+ "\n",
|
|
|
+ " alpha_hist = np.zeros((seq_len, n_states))\n",
|
|
|
+ " ll_hist = np.zeros(seq_len) # loglikelihood history\n",
|
|
|
+ "\n",
|
|
|
+ " alpha_n = init_dist * obs_mat[:, obs_seq[0]]\n",
|
|
|
+ " alpha_n, cn = normalize_np(alpha_n)\n",
|
|
|
+ "\n",
|
|
|
+ " alpha_hist[0] = alpha_n\n",
|
|
|
+ " log_normalizer = np.log(cn)\n",
|
|
|
+ "\n",
|
|
|
+ " for t in range(1, seq_len):\n",
|
|
|
+ " alpha_n = obs_mat[:, obs_seq[t]] * (alpha_n[:, None] * trans_mat).sum(axis=0)\n",
|
|
|
+ " alpha_n, zn = normalize_np(alpha_n)\n",
|
|
|
+ "\n",
|
|
|
+ " alpha_hist[t] = alpha_n\n",
|
|
|
+ " log_normalizer = np.log(zn) + log_normalizer\n",
|
|
|
+ "\n",
|
|
|
+ " return log_normalizer, alpha_hist"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Numerically stable implementation \n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "In practice it is more numerically stable to compute\n",
|
|
|
+ "the log likelihoods $\\ell_t(j) = \\log p(\\obs_t|\\hidden_t=j)$,\n",
|
|
|
+ "rather than the likelioods $\\lambda_t(j) = p(\\obs_t|\\hidden_t=j)$.\n",
|
|
|
+ "In this case, we can perform the posterior updating in a numerically stable way as follows.\n",
|
|
|
+ "Define $L_t = \\max_j \\ell_t(j)$ and\n",
|
|
|
+ "\\begin{align}\n",
|
|
|
+ "\\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n",
|
|
|
+ "&\\defeq p(\\hidden_t=j|\\obs_{1:t-1}) p(\\obs_t|\\hidden_t=j) e^{-L_t} \\\\\n",
|
|
|
+ " &= p(\\hidden_t=j|\\obs_{1:t-1}) e^{\\ell_t(j) - L_t}\n",
|
|
|
+ "\\end{align}\n",
|
|
|
+ "Then we have\n",
|
|
|
+ "\\begin{align}\n",
|
|
|
+ "p(\\hidden_t=j|\\obs_t,\\obs_{1:t-1})\n",
|
|
|
+ " &= \\frac{1}{\\tilde{Z}_t} \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1}) \\\\\n",
|
|
|
+ "\\tilde{Z}_t &= \\sum_j \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n",
|
|
|
+ "= p(\\obs_t|\\obs_{1:t-1}) e^{-L_t} \\\\\n",
|
|
|
+ "\\log Z_t &= \\log p(\\obs_t|\\obs_{1:t-1}) = \\log \\tilde{Z}_t + L_t\n",
|
|
|
+ "\\end{align}\n",
|
|
|
+ "\n",
|
|
|
+ "Below we show some JAX code that implements this core operation.\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "\n",
|
|
|
+ "def _condition_on(probs, ll):\n",
|
|
|
+ " ll_max = ll.max()\n",
|
|
|
+ " new_probs = probs * jnp.exp(ll - ll_max)\n",
|
|
|
+ " norm = new_probs.sum()\n",
|
|
|
+ " new_probs /= norm\n",
|
|
|
+ " log_norm = jnp.log(norm) + ll_max\n",
|
|
|
+ " return new_probs, log_norm"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "With the above function, we can implement a more numerically stable version of the forwards filter,\n",
|
|
|
+ "that works for any likelihood function, as shown below. It takes in the prior predictive distribution,\n",
|
|
|
+ "$\\alpha_{t|t-1}$,\n",
|
|
|
+ "stored in `predicted_probs`, and conditions them on the log-likelihood for each time step $\\ell_t$ to get the\n",
|
|
|
+ "posterior, $\\alpha_t$, stored in `filtered_probs`,\n",
|
|
|
+ "which is then converted to the prediction for the next state, $\\alpha_{t+1|t}$."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def _predict(probs, A):\n",
|
|
|
+ " return A.T @ probs\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def hmm_filter(initial_distribution,\n",
|
|
|
+ " transition_matrix,\n",
|
|
|
+ " log_likelihoods):\n",
|
|
|
+ " def _step(carry, t):\n",
|
|
|
+ " log_normalizer, predicted_probs = carry\n",
|
|
|
+ "\n",
|
|
|
+ " # Get parameters for time t\n",
|
|
|
+ " get = lambda x: x[t] if x.ndim == 3 else x\n",
|
|
|
+ " A = get(transition_matrix)\n",
|
|
|
+ " ll = log_likelihoods[t]\n",
|
|
|
+ "\n",
|
|
|
+ " # Condition on emissions at time t, being careful not to overflow\n",
|
|
|
+ " filtered_probs, log_norm = _condition_on(predicted_probs, ll)\n",
|
|
|
+ " # Update the log normalizer\n",
|
|
|
+ " log_normalizer += log_norm\n",
|
|
|
+ " # Predict the next state\n",
|
|
|
+ " predicted_probs = _predict(filtered_probs, A)\n",
|
|
|
+ "\n",
|
|
|
+ " return (log_normalizer, predicted_probs), (filtered_probs, predicted_probs)\n",
|
|
|
+ "\n",
|
|
|
+ " num_timesteps = len(log_likelihoods)\n",
|
|
|
+ " carry = (0.0, initial_distribution)\n",
|
|
|
+ " (log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(\n",
|
|
|
+ " _step, carry, jnp.arange(num_timesteps))\n",
|
|
|
+ " return log_normalizer, filtered_probs, predicted_probs"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "\n",
|
|
|
+ "TODO: check equivalence of these two implementations!"
|
|
|
]
|
|
|
- }
+ },
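+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below is a first pass at such a check: a sketch comparing the two implementations\n",
+ "on a single set of made-up parameters. The log normalizers and the filtered\n",
+ "posteriors should agree up to floating-point tolerance."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trans_mat = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]])\n",
+ "obs_mat = np.array([[0.5, 0.5], [0.9, 0.1], [0.2, 0.8]])\n",
+ "init_dist = np.array([0.4, 0.3, 0.3])\n",
+ "obs_seq = np.array([0, 1, 1, 0, 1])\n",
+ "\n",
+ "log_Z_np, alphas_np = hmm_forwards_np(trans_mat, obs_mat, init_dist, obs_seq)\n",
+ "\n",
+ "log_liks = jnp.log(jnp.array(obs_mat)[:, obs_seq].T)  # shape (seq_len, n_states)\n",
+ "log_Z_jax, filtered_probs, _ = hmm_filter(jnp.array(init_dist), jnp.array(trans_mat), log_liks)\n",
+ "\n",
+ "print(np.allclose(log_Z_np, log_Z_jax, atol=1e-5))\n",
+ "print(np.allclose(alphas_np, filtered_probs, atol=1e-5))"
+ ]
+ }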
|
|
|
],
|