hmm_filter.html 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867
  1. <!DOCTYPE html>
  2. <html>
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>HMM filtering (forwards algorithm) &#8212; State Space Models: A Modern Approach</title>
  7. <link href="../../_static/css/theme.css" rel="stylesheet">
  8. <link href="../../_static/css/index.ff1ffe594081f20da1ef19478df9384b.css" rel="stylesheet">
  9. <link rel="stylesheet"
  10. href="../../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  11. <link rel="preload" as="font" type="font/woff2" crossorigin
  12. href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  13. <link rel="preload" as="font" type="font/woff2" crossorigin
  14. href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
  15. <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" />
  16. <link rel="stylesheet" type="text/css" href="../../_static/sphinx-book-theme.css?digest=c3fdc42140077d1ad13ad2f1588a4309" />
  17. <link rel="stylesheet" type="text/css" href="../../_static/togglebutton.css" />
  18. <link rel="stylesheet" type="text/css" href="../../_static/copybutton.css" />
  19. <link rel="stylesheet" type="text/css" href="../../_static/mystnb.css" />
  20. <link rel="stylesheet" type="text/css" href="../../_static/sphinx-thebe.css" />
  21. <link rel="stylesheet" type="text/css" href="../../_static/panels-main.c949a650a448cc0ae9fd3441c0e17fb0.css" />
  22. <link rel="stylesheet" type="text/css" href="../../_static/panels-variables.06eb56fa6e07937060861dad626602ad.css" />
  23. <link rel="preload" as="script" href="../../_static/js/index.be7d3bbb2ef33a8344ce.js">
  24. <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
  25. <script src="../../_static/jquery.js"></script>
  26. <script src="../../_static/underscore.js"></script>
  27. <script src="../../_static/doctools.js"></script>
  28. <script src="../../_static/clipboard.min.js"></script>
  29. <script src="../../_static/copybutton.js"></script>
  30. <script>let toggleHintShow = 'Click to show';</script>
  31. <script>let toggleHintHide = 'Click to hide';</script>
  32. <script>let toggleOpenOnPrint = 'true';</script>
  33. <script src="../../_static/togglebutton.js"></script>
  34. <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
  35. <script src="../../_static/sphinx-book-theme.d59cb220de22ca1c485ebbdc042f0030.js"></script>
  36. <script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"
  37. const thebe_selector = ".thebe,.cell"
  38. const thebe_selector_input = "pre"
  39. const thebe_selector_output = ".output, .cell_output"
  40. </script>
  41. <script async="async" src="../../_static/sphinx-thebe.js"></script>
  42. <script>window.MathJax = {"tex": {"macros": {"covMat": "\\boldsymbol{\\Sigma}", "data": "\\mathcal{D}", "defeq": "\\triangleq", "diag": "\\mathrm{diag}", "discreteState": "s", "dotstar": "\\odot", "dynamicsFn": "\\mathbf{f}", "floor": ["\\lfloor#1\\rfloor", 1], "gainMatrix": "\\mathbf{K}", "gainMatrixReverse": "\\mathbf{G}", "gauss": "\\mathcal{N}", "gaussInfo": "\\mathcal{N}_{\\text{info}}", "hidden": "\\mathbf{x}", "hiddenScalar": "x", "hmmInit": "\\boldsymbol{\\pi}", "hmmInitScalar": "\\pi", "hmmObs": "\\mathbf{B}", "hmmObsScalar": "B", "hmmTrans": "\\mathbf{A}", "hmmTransScalar": "A", "infoMat": "\\precMat", "input": "\\mathbf{u}", "inputs": "\\input", "inv": ["{#1}^{-1}", 1], "keyword": ["\\textbf{#1}", 1], "ldsDyn": "\\mathbf{F}", "ldsDynIn": "\\mathbf{B}", "initMean": "\\boldsymbol{\\mean}_0", "initCov": "\\boldsymbol{\\covMat}_0", "ldsObs": "\\mathbf{H}", "ldsObsIn": "\\mathbf{D}", "ldsTrans": "\\ldsDyn", "ldsTransIn": "\\ldsDynIn", "obsCov": "\\mathbf{R}", "obsNoise": "\\boldsymbol{r}", "map": "\\mathrm{map}", "measurementFn": "\\mathbf{h}", "mean": "\\boldsymbol{\\mu}", "mle": "\\mathrm{mle}", "nlatents": "n_x", "nhidden": "\\nlatents", "ninputs": "n_u", "nobs": "n_y", "nsymbols": "n_y", "nstates": "n_s", "obs": "\\mathbf{y}", "obsScalar": "y", "observed": "\\obs", "obsFn": "\\measurementFn", "params": "\\boldsymbol{\\theta}", "precMean": "\\boldsymbol{\\eta}", "precMat": "\\boldsymbol{\\Lambda}", "real": "\\mathbb{R}", "sigmoid": "\\sigma", "softmax": "\\boldsymbol{\\sigma}", "trans": "\\mathsf{T}", "transpose": ["{#1}^{\\trans}", 1], "transCov": "\\mathbf{Q}", "transNoise": "\\mathbf{q}", "valpha": "\\boldsymbol{\\alpha}", "vbeta": "\\boldsymbol{\\beta}", "vdelta": "\\boldsymbol{\\delta}", "vepsilon": "\\boldsymbol{\\epsilon}", "vlambda": "\\boldsymbol{\\lambda}", "vLambda": "\\boldsymbol{\\Lambda}", "vmu": "\\boldsymbol{\\mu}", "vpi": "\\boldsymbol{\\pi}", "vsigma": "\\boldsymbol{\\sigma}", "vSigma": "\\boldsymbol{\\Sigma}", "vone": 
"\\boldsymbol{1}", "vzero": "\\boldsymbol{0}"}}, "options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
  43. <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
  44. <link rel="index" title="Index" href="../../genindex.html" />
  45. <link rel="search" title="Search" href="../../search.html" />
  46. <link rel="next" title="HMM smoothing (forwards-backwards algorithm)" href="hmm_smoother.html" />
  47. <link rel="prev" title="Hidden Markov Models" href="hmm_index.html" />
  48. <meta name="viewport" content="width=device-width, initial-scale=1" />
  49. <meta name="docsearch:language" content="None">
  50. <!-- Google Analytics -->
  51. </head>
  52. <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80">
  53. <div class="container-fluid" id="banner"></div>
  54. <div class="container-xl">
  55. <div class="row">
  56. <div class="col-12 col-md-3 bd-sidebar site-navigation show" id="site-navigation">
  57. <div class="navbar-brand-box">
  58. <a class="navbar-brand text-wrap" href="../../index.html">
  59. <h1 class="site-logo" id="site-title">State Space Models: A Modern Approach</h1>
  60. </a>
  61. </div><form class="bd-search d-flex align-items-center" action="../../search.html" method="get">
  62. <i class="icon fas fa-search"></i>
  63. <input type="search" class="form-control" name="q" id="search-input" placeholder="Search this book..." aria-label="Search this book..." autocomplete="off" >
  64. </form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
  65. <div class="bd-toc-item active">
  66. <ul class="nav bd-sidenav">
  67. <li class="toctree-l1">
  68. <a class="reference internal" href="../../root.html">
  69. State Space Models: A Modern Approach
  70. </a>
  71. </li>
  72. </ul>
  73. <ul class="current nav bd-sidenav">
  74. <li class="toctree-l1 has-children">
  75. <a class="reference internal" href="../ssm/ssm_index.html">
  76. State Space Models
  77. </a>
  78. <input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  79. <label for="toctree-checkbox-1">
  80. <i class="fas fa-chevron-down">
  81. </i>
  82. </label>
  83. <ul>
  84. <li class="toctree-l2">
  85. <a class="reference internal" href="../ssm/ssm_intro.html">
  86. What are State Space Models?
  87. </a>
  88. </li>
  89. <li class="toctree-l2">
  90. <a class="reference internal" href="../ssm/hmm.html">
  91. Hidden Markov Models
  92. </a>
  93. </li>
  94. <li class="toctree-l2">
  95. <a class="reference internal" href="../ssm/lds.html">
  96. Linear Gaussian SSMs
  97. </a>
  98. </li>
  99. <li class="toctree-l2">
  100. <a class="reference internal" href="../ssm/nlds.html">
  101. Nonlinear Gaussian SSMs
  102. </a>
  103. </li>
  104. <li class="toctree-l2">
  105. <a class="reference internal" href="../ssm/inference.html">
  106. States estimation (inference)
  107. </a>
  108. </li>
  109. <li class="toctree-l2">
  110. <a class="reference internal" href="../ssm/learning.html">
  111. Parameter estimation (learning)
  112. </a>
  113. </li>
  114. </ul>
  115. </li>
  116. <li class="toctree-l1 current active has-children">
  117. <a class="reference internal" href="hmm_index.html">
  118. Hidden Markov Models
  119. </a>
  120. <input checked="" class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  121. <label for="toctree-checkbox-2">
  122. <i class="fas fa-chevron-down">
  123. </i>
  124. </label>
  125. <ul class="current">
  126. <li class="toctree-l2 current active">
  127. <a class="current reference internal" href="#">
  128. HMM filtering (forwards algorithm)
  129. </a>
  130. </li>
  131. <li class="toctree-l2">
  132. <a class="reference internal" href="hmm_smoother.html">
  133. HMM smoothing (forwards-backwards algorithm)
  134. </a>
  135. </li>
  136. <li class="toctree-l2">
  137. <a class="reference internal" href="hmm_viterbi.html">
  138. Viterbi algorithm
  139. </a>
  140. </li>
  141. <li class="toctree-l2">
  142. <a class="reference internal" href="hmm_parallel.html">
  143. Parallel HMM smoothing
  144. </a>
  145. </li>
  146. <li class="toctree-l2">
  147. <a class="reference internal" href="hmm_sampling.html">
  148. Forwards-filtering backwards-sampling algorithm
  149. </a>
  150. </li>
  151. </ul>
  152. </li>
  153. <li class="toctree-l1 has-children">
  154. <a class="reference internal" href="../lgssm/lgssm_index.html">
  155. Linear-Gaussian SSMs
  156. </a>
  157. <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  158. <label for="toctree-checkbox-3">
  159. <i class="fas fa-chevron-down">
  160. </i>
  161. </label>
  162. <ul>
  163. <li class="toctree-l2">
  164. <a class="reference internal" href="../lgssm/kalman_filter.html">
  165. Kalman filtering
  166. </a>
  167. </li>
  168. <li class="toctree-l2">
  169. <a class="reference internal" href="../lgssm/kalman_smoother.html">
  170. Kalman (RTS) smoother
  171. </a>
  172. </li>
  173. <li class="toctree-l2">
  174. <a class="reference internal" href="../lgssm/kalman_parallel.html">
  175. Parallel Kalman Smoother
  176. </a>
  177. </li>
  178. <li class="toctree-l2">
  179. <a class="reference internal" href="../lgssm/kalman_sampling.html">
  180. Forwards-filtering backwards sampling
  181. </a>
  182. </li>
  183. </ul>
  184. </li>
  185. <li class="toctree-l1 has-children">
  186. <a class="reference internal" href="../extended/extended_index.html">
  187. Extended (linearized) methods
  188. </a>
  189. <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  190. <label for="toctree-checkbox-4">
  191. <i class="fas fa-chevron-down">
  192. </i>
  193. </label>
  194. <ul>
  195. <li class="toctree-l2">
  196. <a class="reference internal" href="../extended/extended_filter.html">
  197. Extended Kalman filtering
  198. </a>
  199. </li>
  200. <li class="toctree-l2">
  201. <a class="reference internal" href="../extended/extended_smoother.html">
  202. Extended Kalman smoother
  203. </a>
  204. </li>
  205. <li class="toctree-l2">
  206. <a class="reference internal" href="../extended/extended_parallel.html">
  207. Parallel extended Kalman smoothing
  208. </a>
  209. </li>
  210. </ul>
  211. </li>
  212. <li class="toctree-l1 has-children">
  213. <a class="reference internal" href="../unscented/unscented_index.html">
  214. Unscented methods
  215. </a>
  216. <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  217. <label for="toctree-checkbox-5">
  218. <i class="fas fa-chevron-down">
  219. </i>
  220. </label>
  221. <ul>
  222. <li class="toctree-l2">
  223. <a class="reference internal" href="../unscented/unscented_filter.html">
  224. Unscented filtering
  225. </a>
  226. </li>
  227. <li class="toctree-l2">
  228. <a class="reference internal" href="../unscented/unscented_smoother.html">
  229. Unscented smoothing
  230. </a>
  231. </li>
  232. </ul>
  233. </li>
  234. <li class="toctree-l1">
  235. <a class="reference internal" href="../quadrature/quadrature_index.html">
  236. Quadrature and cubature methods
  237. </a>
  238. </li>
  239. <li class="toctree-l1">
  240. <a class="reference internal" href="../postlin/postlin_index.html">
  241. Posterior linearization
  242. </a>
  243. </li>
  244. <li class="toctree-l1">
  245. <a class="reference internal" href="../adf/adf_index.html">
  246. Assumed Density Filtering
  247. </a>
  248. </li>
  249. <li class="toctree-l1">
  250. <a class="reference internal" href="../vi/vi_index.html">
  251. Variational inference
  252. </a>
  253. </li>
  254. <li class="toctree-l1">
  255. <a class="reference internal" href="../pf/pf_index.html">
  256. Particle filtering
  257. </a>
  258. </li>
  259. <li class="toctree-l1">
  260. <a class="reference internal" href="../smc/smc_index.html">
  261. Sequential Monte Carlo
  262. </a>
  263. </li>
  264. <li class="toctree-l1">
  265. <a class="reference internal" href="../learning/learning_index.html">
  266. Offline parameter estimation (learning)
  267. </a>
  268. </li>
  269. <li class="toctree-l1">
  270. <a class="reference internal" href="../tracking/tracking_index.html">
  271. Multi-target tracking
  272. </a>
  273. </li>
  274. <li class="toctree-l1">
  275. <a class="reference internal" href="../ensemble/ensemble_index.html">
  276. Data assimilation using Ensemble Kalman filter
  277. </a>
  278. </li>
  279. <li class="toctree-l1">
  280. <a class="reference internal" href="../bnp/bnp_index.html">
  281. Bayesian non-parametric SSMs
  282. </a>
  283. </li>
  284. <li class="toctree-l1">
  285. <a class="reference internal" href="../changepoint/changepoint_index.html">
  286. Changepoint detection
  287. </a>
  288. </li>
  289. <li class="toctree-l1">
  290. <a class="reference internal" href="../timeseries/timeseries_index.html">
  291. Timeseries forecasting
  292. </a>
  293. </li>
  294. <li class="toctree-l1">
  295. <a class="reference internal" href="../gp/gp_index.html">
  296. Markovian Gaussian processes
  297. </a>
  298. </li>
  299. <li class="toctree-l1">
  300. <a class="reference internal" href="../ode/ode_index.html">
  301. Differential equations and SSMs
  302. </a>
  303. </li>
  304. <li class="toctree-l1">
  305. <a class="reference internal" href="../control/control_index.html">
  306. Optimal control
  307. </a>
  308. </li>
  309. <li class="toctree-l1">
  310. <a class="reference internal" href="../../bib.html">
  311. Bibliography
  312. </a>
  313. </li>
  314. </ul>
  315. </div>
  316. </nav> <!-- To handle the deprecated key -->
  317. <div class="navbar_extra_footer">
  318. Powered by <a href="https://jupyterbook.org">Jupyter Book</a>
  319. </div>
  320. </div>
  321. <main class="col py-md-3 pl-md-4 bd-content overflow-auto" role="main">
  322. <div class="topbar container-xl fixed-top">
  323. <div class="topbar-contents row">
  324. <div class="col-12 col-md-3 bd-topbar-whitespace site-navigation show"></div>
  325. <div class="col pl-md-4 topbar-main">
  326. <button id="navbar-toggler" class="navbar-toggler ml-0" type="button" data-toggle="collapse"
  327. data-toggle="tooltip" data-placement="bottom" data-target=".site-navigation" aria-controls="navbar-menu"
  328. aria-expanded="true" aria-label="Toggle navigation" aria-controls="site-navigation"
  329. title="Toggle navigation" data-toggle="tooltip" data-placement="left">
  330. <i class="fas fa-bars"></i>
  331. <i class="fas fa-arrow-left"></i>
  332. <i class="fas fa-arrow-up"></i>
  333. </button>
  334. <div class="dropdown-buttons-trigger">
  335. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn" aria-label="Download this page"><i
  336. class="fas fa-download"></i></button>
  337. <div class="dropdown-buttons">
  338. <!-- ipynb file if we had a myst markdown file -->
  339. <!-- Download raw file -->
  340. <a class="dropdown-buttons" href="../../_sources/chapters/hmm/hmm_filter.ipynb"><button type="button"
  341. class="btn btn-secondary topbarbtn" title="Download source file" data-toggle="tooltip"
  342. data-placement="left">.ipynb</button></a>
  343. <!-- Download PDF via print -->
  344. <button type="button" id="download-print" class="btn btn-secondary topbarbtn" title="Print to PDF"
  345. onclick="printPdf(this)" data-toggle="tooltip" data-placement="left">.pdf</button>
  346. </div>
  347. </div>
  348. <!-- Source interaction buttons -->
  349. <div class="dropdown-buttons-trigger">
  350. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  351. aria-label="Connect with source repository"><i class="fab fa-github"></i></button>
  352. <div class="dropdown-buttons sourcebuttons">
  353. <a class="repository-button"
  354. href="https://github.com/probml/ssm-book"><button type="button" class="btn btn-secondary topbarbtn"
  355. data-toggle="tooltip" data-placement="left" title="Source repository"><i
  356. class="fab fa-github"></i>repository</button></a>
  357. <a class="issues-button"
  358. href="https://github.com/probml/ssm-book/issues/new?title=Issue%20on%20page%20%2Fchapters/hmm/hmm_filter.html&body=Your%20issue%20content%20here."><button
  359. type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip" data-placement="left"
  360. title="Open an issue"><i class="fas fa-lightbulb"></i>open issue</button></a>
  361. </div>
  362. </div>
  363. <!-- Full screen (wrap in <a> to have style consistency -->
  364. <a class="full-screen-button"><button type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip"
  365. data-placement="bottom" onclick="toggleFullScreen()" aria-label="Fullscreen mode"
  366. title="Fullscreen mode"><i
  367. class="fas fa-expand"></i></button></a>
  368. <!-- Launch buttons -->
  369. <div class="dropdown-buttons-trigger">
  370. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  371. aria-label="Launch interactive content"><i class="fas fa-rocket"></i></button>
  372. <div class="dropdown-buttons">
  373. <a class="binder-button" href="https://mybinder.org/v2/gh/probml/ssm-book/main?urlpath=tree/chapters/hmm/hmm_filter.ipynb"><button type="button"
  374. class="btn btn-secondary topbarbtn" title="Launch Binder" data-toggle="tooltip"
  375. data-placement="left"><img class="binder-button-logo"
  376. src="../../_static/images/logo_binder.svg"
  377. alt="Interact on binder">Binder</button></a>
  378. <a class="colab-button" href="https://colab.research.google.com/github/probml/ssm-book/blob/main/chapters/hmm/hmm_filter.ipynb"><button type="button" class="btn btn-secondary topbarbtn"
  379. title="Launch Colab" data-toggle="tooltip" data-placement="left"><img class="colab-button-logo"
  380. src="../../_static/images/logo_colab.png"
  381. alt="Interact on Colab">Colab</button></a>
  382. </div>
  383. </div>
  384. </div>
  385. <!-- Table of contents -->
  386. <div class="d-none d-md-block col-md-2 bd-toc show noprint">
  387. <div class="tocsection onthispage pt-5 pb-3">
  388. <i class="fas fa-list"></i> Contents
  389. </div>
  390. <nav id="bd-toc-nav" aria-label="Page">
  391. <ul class="visible nav section-nav flex-column">
  392. <li class="toc-h2 nav-item toc-entry">
  393. <a class="reference internal nav-link" href="#introduction">
  394. Introduction
  395. </a>
  396. </li>
  397. <li class="toc-h2 nav-item toc-entry">
  398. <a class="reference internal nav-link" href="#example">
  399. Example
  400. </a>
  401. </li>
  402. <li class="toc-h2 nav-item toc-entry">
  403. <a class="reference internal nav-link" href="#normalization-constants">
  404. Normalization constants
  405. </a>
  406. </li>
  407. <li class="toc-h2 nav-item toc-entry">
  408. <a class="reference internal nav-link" href="#naive-implementation">
  409. Naive implementation
  410. </a>
  411. </li>
  412. <li class="toc-h2 nav-item toc-entry">
  413. <a class="reference internal nav-link" href="#numerically-stable-implementation">
  414. Numerically stable implementation
  415. </a>
  416. </li>
  417. </ul>
  418. </nav>
  419. </div>
  420. </div>
  421. </div>
  422. <div id="main-content" class="row">
  423. <div class="col-12 col-md-9 pl-md-3 pr-md-0">
  424. <!-- Table of contents that is only displayed when printing the page -->
  425. <div id="jb-print-docs-body" class="onlyprint">
  426. <h1>HMM filtering (forwards algorithm)</h1>
  427. <!-- Table of contents -->
  428. <div id="print-main-content">
  429. <div id="jb-print-toc">
  430. <div>
  431. <h2> Contents </h2>
  432. </div>
  433. <nav aria-label="Page">
  434. <ul class="visible nav section-nav flex-column">
  435. <li class="toc-h2 nav-item toc-entry">
  436. <a class="reference internal nav-link" href="#introduction">
  437. Introduction
  438. </a>
  439. </li>
  440. <li class="toc-h2 nav-item toc-entry">
  441. <a class="reference internal nav-link" href="#example">
  442. Example
  443. </a>
  444. </li>
  445. <li class="toc-h2 nav-item toc-entry">
  446. <a class="reference internal nav-link" href="#normalization-constants">
  447. Normalization constants
  448. </a>
  449. </li>
  450. <li class="toc-h2 nav-item toc-entry">
  451. <a class="reference internal nav-link" href="#naive-implementation">
  452. Naive implementation
  453. </a>
  454. </li>
  455. <li class="toc-h2 nav-item toc-entry">
  456. <a class="reference internal nav-link" href="#numerically-stable-implementation">
  457. Numerically stable implementation
  458. </a>
  459. </li>
  460. </ul>
  461. </nav>
  462. </div>
  463. </div>
  464. </div>
  465. <div>
  466. <div class="tex2jax_ignore mathjax_ignore section" id="hmm-filtering-forwards-algorithm">
  467. <span id="sec-forwards"></span><h1>HMM filtering (forwards algorithm)<a class="headerlink" href="#hmm-filtering-forwards-algorithm" title="Permalink to this headline">¶</a></h1>
  468. <div class="cell docutils container">
  469. <div class="cell_input docutils container">
  470. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1">### Import standard libraries</span>
  471. <span class="kn">import</span> <span class="nn">abc</span>
  472. <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
  473. <span class="kn">import</span> <span class="nn">functools</span>
  474. <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
  475. <span class="kn">import</span> <span class="nn">itertools</span>
  476. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
  477. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  478. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
  479. <span class="kn">import</span> <span class="nn">jax</span>
  480. <span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
  481. <span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">lax</span><span class="p">,</span> <span class="n">vmap</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">grad</span>
  482. <span class="c1">#from jax.scipy.special import logit</span>
  483. <span class="c1">#from jax.nn import softmax</span>
  484. <span class="kn">import</span> <span class="nn">jax.random</span> <span class="k">as</span> <span class="nn">jr</span>
  485. <span class="kn">import</span> <span class="nn">distrax</span>
  486. <span class="kn">import</span> <span class="nn">optax</span>
  487. <span class="kn">import</span> <span class="nn">jsl</span>
  488. <span class="kn">import</span> <span class="nn">ssm_jax</span>
  489. </pre></div>
  490. </div>
  491. </div>
  492. </div>
  493. <div class="section" id="introduction">
  494. <h2>Introduction<a class="headerlink" href="#introduction" title="Permalink to this headline">¶</a></h2>
  495. <p>The <span class="math notranslate nohighlight">\(\keyword{Bayes filter}\)</span> is an algorithm for recursively computing
  496. the belief state
  497. <span class="math notranslate nohighlight">\(p(\hidden_t|\obs_{1:t})\)</span> given
  498. the prior belief from the previous step,
  499. <span class="math notranslate nohighlight">\(p(\hidden_{t-1}|\obs_{1:t-1})\)</span>,
  500. the new observation <span class="math notranslate nohighlight">\(\obs_t\)</span>,
  501. and the model.
  502. This can be done using <span class="math notranslate nohighlight">\(\keyword{sequential Bayesian updating}\)</span>.
  503. For a dynamical model, this reduces to the
  504. <span class="math notranslate nohighlight">\(\keyword{predict-update}\)</span> cycle described below.</p>
  505. <p>The <span class="math notranslate nohighlight">\(\keyword{prediction step}\)</span> is just the <span class="math notranslate nohighlight">\(\keyword{Chapman-Kolmogorov equation}\)</span>:</p>
  506. <div class="amsmath math notranslate nohighlight" id="equation-02a2ec0d-8c24-4270-a3c5-b68a6d04b9be">
  507. <span class="eqno">(21)<a class="headerlink" href="#equation-02a2ec0d-8c24-4270-a3c5-b68a6d04b9be" title="Permalink to this equation">¶</a></span>\[\begin{align}
  508. p(\hidden_t|\obs_{1:t-1})
  509. = \int p(\hidden_t|\hidden_{t-1}) p(\hidden_{t-1}|\obs_{1:t-1}) d\hidden_{t-1}
  510. \end{align}\]</div>
  511. <p>The prediction step computes
  512. the <span class="math notranslate nohighlight">\(\keyword{one-step-ahead predictive distribution}\)</span>
  513. for the latent state, which converts
  514. the posterior from the previous time step into the prior
  515. for the current step.</p>
  516. <p>The <span class="math notranslate nohighlight">\(\keyword{update step}\)</span>
  517. is just Bayes rule:</p>
  518. <div class="amsmath math notranslate nohighlight" id="equation-be8910aa-7596-460f-9268-280f22fdce5e">
  519. <span class="eqno">(22)<a class="headerlink" href="#equation-be8910aa-7596-460f-9268-280f22fdce5e" title="Permalink to this equation">¶</a></span>\[\begin{align}
  520. p(\hidden_t|\obs_{1:t}) = \frac{1}{Z_t}
  521. p(\obs_t|\hidden_t) p(\hidden_t|\obs_{1:t-1})
  522. \end{align}\]</div>
  523. <p>where the normalization constant is</p>
  524. <div class="amsmath math notranslate nohighlight" id="equation-18e83263-589e-4859-84ea-9721705c1f4e">
  525. <span class="eqno">(23)<a class="headerlink" href="#equation-18e83263-589e-4859-84ea-9721705c1f4e" title="Permalink to this equation">¶</a></span>\[\begin{align}
  526. Z_t = \int p(\obs_t|\hidden_t) p(\hidden_t|\obs_{1:t-1}) d\hidden_{t}
  527. = p(\obs_t|\obs_{1:t-1})
  528. \end{align}\]</div>
  529. <p>Note that we can derive the log marginal likelihood from these normalization constants
  530. as follows:</p>
  531. <div class="math notranslate nohighlight" id="equation-eqn-logz">
  532. <span class="eqno">(24)<a class="headerlink" href="#equation-eqn-logz" title="Permalink to this equation">¶</a></span>\[\log p(\obs_{1:T})
  533. = \sum_{t=1}^{T} \log p(\obs_t|\obs_{1:t-1})
  534. = \sum_{t=1}^{T} \log Z_t\]</div>
  535. <p>When the latent states <span class="math notranslate nohighlight">\(\hidden_t\)</span> are discrete, as in an HMM,
  536. the above integrals become sums.
  537. In particular, suppose we define
  538. the <span class="math notranslate nohighlight">\(\keyword{belief state}\)</span> as <span class="math notranslate nohighlight">\(\alpha_t(j) \defeq p(\hidden_t=j|\obs_{1:t})\)</span>,
  539. the <span class="math notranslate nohighlight">\(\keyword{local evidence}\)</span> (or <span class="math notranslate nohighlight">\(\keyword{local likelihood}\)</span>)
  540. as <span class="math notranslate nohighlight">\(\lambda_t(j) \defeq p(\obs_t|\hidden_t=j)\)</span>,
  541. and the transition matrix as
  542. <span class="math notranslate nohighlight">\(\hmmTrans(i,j) = p(\hidden_t=j|\hidden_{t-1}=i)\)</span>.
  543. Then the predict step becomes</p>
  544. <div class="math notranslate nohighlight" id="equation-eqn-predictivehmm">
  545. <span class="eqno">(25)<a class="headerlink" href="#equation-eqn-predictivehmm" title="Permalink to this equation">¶</a></span>\[\alpha_{t|t-1}(j) \defeq p(\hidden_t=j|\obs_{1:t-1})
  546. = \sum_i \alpha_{t-1}(i) \hmmTrans(i,j)\]</div>
  547. <p>and the update step becomes</p>
  548. <div class="math notranslate nohighlight" id="equation-eqn-fwdseqn">
  549. <span class="eqno">(26)<a class="headerlink" href="#equation-eqn-fwdseqn" title="Permalink to this equation">¶</a></span>\[\alpha_t(j)
  550. = \frac{1}{Z_t} \lambda_t(j) \alpha_{t|t-1}(j)
  551. = \frac{1}{Z_t} \lambda_t(j) \left[\sum_i \alpha_{t-1}(i) \hmmTrans(i,j) \right]\]</div>
  552. <p>where
  553. the normalization constant for each time step is given by</p>
  554. <div class="math notranslate nohighlight" id="equation-eqn-hmmz">
  555. <span class="eqno">(27)<a class="headerlink" href="#equation-eqn-hmmz" title="Permalink to this equation">¶</a></span>\[\begin{split}\begin{align}
  556. Z_t \defeq p(\obs_t|\obs_{1:t-1})
  557. &amp;= \sum_{j=1}^K p(\obs_t|\hidden_t=j) p(\hidden_t=j|\obs_{1:t-1}) \\
  558. &amp;= \sum_{j=1}^K \lambda_t(j) \alpha_{t|t-1}(j)
  559. \end{align}\end{split}\]</div>
  560. <p>Since all the quantities are finite-length vectors and matrices,
  561. we can implement the whole procedure using matrix-vector multiplication:</p>
  562. <div class="math notranslate nohighlight" id="equation-eqn-fwdsalgomatrixform">
  563. <span class="eqno">(28)<a class="headerlink" href="#equation-eqn-fwdsalgomatrixform" title="Permalink to this equation">¶</a></span>\[\valpha_t =\text{normalize}\left(
  564. \vlambda_t \dotstar (\hmmTrans^{\trans} \valpha_{t-1}) \right)\]</div>
  565. <p>where <span class="math notranslate nohighlight">\(\dotstar\)</span> represents
  566. elementwise vector multiplication,
  567. and the <span class="math notranslate nohighlight">\(\text{normalize}\)</span> function just ensures its argument sums to one.</p>
  568. </div>
  569. <div class="section" id="example">
  570. <h2>Example<a class="headerlink" href="#example" title="Permalink to this headline">¶</a></h2>
  571. <p>In <a class="reference internal" href="../ssm/inference.html#sec-casino-inference"><span class="std std-ref">Example: inference in the casino HMM</span></a>
  572. we illustrate
  573. filtering for the casino HMM,
  574. applied to a random sequence <span class="math notranslate nohighlight">\(\obs_{1:T}\)</span> of length <span class="math notranslate nohighlight">\(T=300\)</span>.
  575. In blue, we plot the probability that the die is in the loaded (vs. fair) state,
  576. based on the evidence seen so far.
  577. The gray bars indicate time intervals during which the generative
  578. process actually switched to the loaded die.
  579. We see that the probability generally increases in the right places.</p>
  580. </div>
  581. <div class="section" id="normalization-constants">
  582. <h2>Normalization constants<a class="headerlink" href="#normalization-constants" title="Permalink to this headline">¶</a></h2>
  583. <p>In most publications on HMMs,
  584. such as <span id="id1">[<a class="reference internal" href="../../bib.html#id32" title="L. R. Rabiner. A tutorial on Hidden Markov Models and selected applications in speech recognition. Proc. of the IEEE, 77(2):257–286, 1989.">Rab89</a>]</span>,
  585. the forwards message is defined
  586. as the following unnormalized joint probability:</p>
  587. <div class="math notranslate nohighlight">
  588. \[\alpha'_t(j) = p(\hidden_t=j,\obs_{1:t})
  589. = \lambda_t(j) \left[\sum_i \alpha'_{t-1}(i) \hmmTrans(i,j) \right]\]</div>
  590. <p>In this book we define the forwards message as the normalized
  591. conditional probability</p>
  592. <div class="math notranslate nohighlight">
  593. \[\alpha_t(j) = p(\hidden_t=j|\obs_{1:t})
  594. = \frac{1}{Z_t} \lambda_t(j) \left[\sum_i \alpha_{t-1}(i) \hmmTrans(i,j) \right]\]</div>
  595. <p>where <span class="math notranslate nohighlight">\(Z_t = p(\obs_t|\obs_{1:t-1})\)</span>.</p>
  596. <p>The “traditional” unnormalized form has several problems.
  597. First, it rapidly suffers from numerical underflow,
  598. since the probability of
  599. the joint event that <span class="math notranslate nohighlight">\((\hidden_t=j,\obs_{1:t})\)</span>
  600. is vanishingly small.
  601. To see why, suppose the observations are independent of the states.
  602. In this case, the unnormalized joint has the form</p>
  603. <div class="amsmath math notranslate nohighlight" id="equation-2f9dfff7-1a23-4760-98f4-68f254fa0022">
  604. <span class="eqno">(29)<a class="headerlink" href="#equation-2f9dfff7-1a23-4760-98f4-68f254fa0022" title="Permalink to this equation">¶</a></span>\[\begin{align}
  605. p(\hidden_t=j,\obs_{1:t}) = p(\hidden_t=j)\prod_{i=1}^t p(\obs_i)
  606. \end{align}\]</div>
  607. <p>which becomes exponentially small with <span class="math notranslate nohighlight">\(t\)</span>, because we multiply
  608. many probabilities which are less than one.
  609. Second, the unnormalized probability is less interpretable,
  610. since it is a joint distribution over states and observations,
  611. rather than a conditional probability of states given observations.
  612. Third, the unnormalized joint form is harder to approximate
  613. than the normalized form.
  614. Of course,
  615. the two definitions only differ by a
  616. multiplicative constant
  617. <span id="id2">[<a class="reference internal" href="../../bib.html#id35" title="Pierre A Devijver. Baum's forward-backward algorithm revisited. Pattern Recognition Letters, 3(6):369–373, 1985.">Dev85</a>]</span>,
  618. so the algorithmic difference is just
  619. one line of code (namely the presence or absence of a call to the <code class="docutils literal notranslate"><span class="pre">normalize</span></code> function).</p>
  620. </div>
  621. <div class="section" id="naive-implementation">
  622. <h2>Naive implementation<a class="headerlink" href="#naive-implementation" title="Permalink to this headline">¶</a></h2>
  623. <p>Below we give a simple numpy implementation of the forwards algorithm.
  624. We assume the HMM uses categorical observations, for simplicity.</p>
  625. <div class="cell docutils container">
  626. <div class="cell_input docutils container">
  627. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">normalize_np</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-15</span><span class="p">):</span>  <span class="c1"># Normalize u to sum to one along axis; returns (normalized u, sums).</span>
  628.     <span class="n">u</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">u</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">u</span> <span class="o">&lt;</span> <span class="n">eps</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">u</span><span class="p">))</span>  <span class="c1"># entries below eps (incl. negatives) become eps; exact zeros stay 0</span>
  629.     <span class="n">c</span> <span class="o">=</span> <span class="n">u</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="n">axis</span><span class="p">)</span>
  630.     <span class="n">c</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">c</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>  <span class="c1"># an all-zero slice is divided by 1, not 0</span>
  631.     <span class="k">return</span> <span class="n">u</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">c</span><span class="p">,</span> <span class="n">axis</span><span class="p">),</span> <span class="n">c</span>  <span class="c1"># re-insert the summed axis so any axis broadcasts correctly</span>
  632. <span class="k">def</span> <span class="nf">hmm_forwards_np</span><span class="p">(</span><span class="n">trans_mat</span><span class="p">,</span> <span class="n">obs_mat</span><span class="p">,</span> <span class="n">init_dist</span><span class="p">,</span> <span class="n">obs_seq</span><span class="p">):</span>
  633.     <span class="sd">&quot;&quot;&quot;Normalized forwards filter for an HMM with categorical observations.</span>
  634. <span class="sd">    trans_mat[i, j] = p(z_t=j | z_{t-1}=i) and obs_mat[j, k] = p(x_t=k | z_t=j).</span>
  635. <span class="sd">    Returns (log p(x_{1:T}), alpha_hist), where alpha_hist[t, j] = p(z_t=j | x_{1:t}).&quot;&quot;&quot;</span>
  636.     <span class="n">n_states</span> <span class="o">=</span> <span class="n">obs_mat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  637.     <span class="n">seq_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">obs_seq</span><span class="p">)</span>
  638.     <span class="n">alpha_hist</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">n_states</span><span class="p">))</span>
  639.     <span class="n">log_normalizer</span> <span class="o">=</span> <span class="mf">0.0</span>  <span class="c1"># stays 0.0 (log-probability of nothing) if obs_seq is empty</span>
  640.     <span class="n">alpha_n</span> <span class="o">=</span> <span class="n">init_dist</span>  <span class="c1"># predictive p(z_t | x_{1:t-1}); the prior at t=0</span>
  641.     <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">):</span>
  642.         <span class="n">alpha_n</span><span class="p">,</span> <span class="n">zn</span> <span class="o">=</span> <span class="n">normalize_np</span><span class="p">(</span><span class="n">obs_mat</span><span class="p">[:,</span> <span class="n">obs_seq</span><span class="p">[</span><span class="n">t</span><span class="p">]]</span> <span class="o">*</span> <span class="n">alpha_n</span><span class="p">)</span>  <span class="c1"># update step, eq. (26)</span>
  643.         <span class="n">alpha_hist</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">alpha_n</span>
  644.         <span class="n">log_normalizer</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">zn</span><span class="p">)</span>  <span class="c1"># accumulate log Z_t, eq. (24)</span>
  645.         <span class="n">alpha_n</span> <span class="o">=</span> <span class="n">trans_mat</span><span class="o">.</span><span class="n">T</span> <span class="o">@</span> <span class="n">alpha_n</span>  <span class="c1"># predict step, eq. (25)</span>
  646.     <span class="k">return</span> <span class="n">log_normalizer</span><span class="p">,</span> <span class="n">alpha_hist</span>
  647. </pre></div>
  648. </div>
  649. </div>
  650. </div>
  651. </div>
  652. <div class="section" id="numerically-stable-implementation">
  653. <h2>Numerically stable implementation<a class="headerlink" href="#numerically-stable-implementation" title="Permalink to this headline">¶</a></h2>
  654. <p>In practice it is more numerically stable to compute
  655. the log likelihoods <span class="math notranslate nohighlight">\(\ell_t(j) = \log p(\obs_t|\hidden_t=j)\)</span>,
  656. rather than the likelihoods <span class="math notranslate nohighlight">\(\lambda_t(j) = p(\obs_t|\hidden_t=j)\)</span>.
  657. In this case, we can perform the posterior update in a numerically stable way as follows.
  658. Define <span class="math notranslate nohighlight">\(L_t = \max_j \ell_t(j)\)</span> and</p>
  659. <div class="amsmath math notranslate nohighlight" id="equation-1cf0a390-376b-4da3-873f-daf7e1489cbb">
  660. <span class="eqno">(30)<a class="headerlink" href="#equation-1cf0a390-376b-4da3-873f-daf7e1489cbb" title="Permalink to this equation">¶</a></span>\[\begin{align}
  661. \tilde{p}(\hidden_t=j,\obs_t|\obs_{1:t-1})
  662. &amp;\defeq p(\hidden_t=j|\obs_{1:t-1}) p(\obs_t|\hidden_t=j) e^{-L_t} \\
  663. &amp;= p(\hidden_t=j|\obs_{1:t-1}) e^{\ell_t(j) - L_t}
  664. \end{align}\]</div>
  665. <p>Then we have</p>
  666. <div class="amsmath math notranslate nohighlight" id="equation-562d5a1d-b724-4093-85d2-ec62f15c5421">
  667. <span class="eqno">(31)<a class="headerlink" href="#equation-562d5a1d-b724-4093-85d2-ec62f15c5421" title="Permalink to this equation">¶</a></span>\[\begin{align}
  668. p(\hidden_t=j|\obs_t,\obs_{1:t-1})
  669. &amp;= \frac{1}{\tilde{Z}_t} \tilde{p}(\hidden_t=j,\obs_t|\obs_{1:t-1}) \\
  670. \tilde{Z}_t &amp;= \sum_j \tilde{p}(\hidden_t=j,\obs_t|\obs_{1:t-1})
  671. = p(\obs_t|\obs_{1:t-1}) e^{-L_t} \\
  672. \log Z_t &amp;= \log p(\obs_t|\obs_{1:t-1}) = \log \tilde{Z}_t + L_t
  673. \end{align}\]</div>
  674. <p>Below we show some JAX code that implements this core operation.</p>
  675. <div class="cell docutils container">
  676. <div class="cell_input docutils container">
  677. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">_condition_on</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">ll</span><span class="p">):</span>
  678.     <span class="sd">&quot;&quot;&quot;Return (probs * exp(ll)) normalized, plus log sum_j probs[j] exp(ll[j]), computed by shifting ll by its max.&quot;&quot;&quot;</span>
  679.     <span class="n">shift</span> <span class="o">=</span> <span class="n">ll</span><span class="o">.</span><span class="n">max</span><span class="p">()</span>  <span class="c1"># L_t in eq. (30); exp(ll - shift) &lt;= 1 cannot overflow</span>
  680.     <span class="n">weighted</span> <span class="o">=</span> <span class="n">probs</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">ll</span> <span class="o">-</span> <span class="n">shift</span><span class="p">)</span>
  681.     <span class="n">total</span> <span class="o">=</span> <span class="n">weighted</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>  <span class="c1"># tilde Z_t in eq. (31)</span>
  682.     <span class="n">log_total</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">total</span><span class="p">)</span> <span class="o">+</span> <span class="n">shift</span>  <span class="c1"># log Z_t = log tilde Z_t + L_t</span>
  683.     <span class="k">return</span> <span class="n">weighted</span> <span class="o">/</span> <span class="n">total</span><span class="p">,</span> <span class="n">log_total</span>
  684. </pre></div>
  685. </div>
  686. </div>
  687. </div>
  688. <p>With the above function, we can implement a more numerically stable version of the forwards filter
  689. that works for any likelihood function, as shown below. At each step it takes the prior predictive distribution
  690. <span class="math notranslate nohighlight">\(\alpha_{t|t-1}\)</span>,
  691. stored in <code class="docutils literal notranslate"><span class="pre">predicted_probs</span></code>, and conditions it on the log-likelihoods <span class="math notranslate nohighlight">\(\ell_t\)</span> for that time step to get the
  692. posterior, <span class="math notranslate nohighlight">\(\alpha_t\)</span>, stored in <code class="docutils literal notranslate"><span class="pre">filtered_probs</span></code>,
  693. which is then converted to the prediction for the next state, <span class="math notranslate nohighlight">\(\alpha_{t+1|t}\)</span>. Consequently, row <span class="math notranslate nohighlight">\(t\)</span> of the returned <code class="docutils literal notranslate"><span class="pre">predicted_probs</span></code> is <span class="math notranslate nohighlight">\(\alpha_{t+1|t}\)</span>, not <span class="math notranslate nohighlight">\(\alpha_{t|t-1}\)</span>.</p>
  694. <div class="cell docutils container">
  695. <div class="cell_input docutils container">
  696. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">_predict</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">A</span><span class="p">):</span>
  697.     <span class="sd">&quot;&quot;&quot;Predict step, eq. (25): entry j of the result is sum_i probs[i] * A[i, j].&quot;&quot;&quot;</span>
  698.     <span class="k">return</span> <span class="n">A</span><span class="o">.</span><span class="n">T</span> <span class="o">@</span> <span class="n">probs</span>
  699. <span class="k">def</span> <span class="nf">hmm_filter</span><span class="p">(</span><span class="n">initial_distribution</span><span class="p">,</span>
  700.                <span class="n">transition_matrix</span><span class="p">,</span>
  701.                <span class="n">log_likelihoods</span><span class="p">):</span>
  702.     <span class="sd">&quot;&quot;&quot;Numerically stable HMM forwards filter built on lax.scan.</span>
  703. <span class="sd">    Returns (log_normalizer, filtered_probs, predicted_probs): filtered_probs[t] is alpha_t and</span>
  704. <span class="sd">    predicted_probs[t] is alpha_{t+1|t}. A 3-D transition_matrix is indexed by time step t.&quot;&quot;&quot;</span>
  705.     <span class="n">time_varying</span> <span class="o">=</span> <span class="n">transition_matrix</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span>
  706.     <span class="k">def</span> <span class="nf">_step</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">t</span><span class="p">):</span>
  707.         <span class="n">total_log_norm</span><span class="p">,</span> <span class="n">prior_probs</span> <span class="o">=</span> <span class="n">carry</span>
  708.         <span class="c1"># Use A_t when transitions vary over time, otherwise the single shared matrix</span>
  709.         <span class="n">A</span> <span class="o">=</span> <span class="n">transition_matrix</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="k">if</span> <span class="n">time_varying</span> <span class="k">else</span> <span class="n">transition_matrix</span>
  710.         <span class="c1"># Update: condition alpha_{t|t-1} on ell_t without overflowing (see _condition_on)</span>
  711.         <span class="n">posterior_probs</span><span class="p">,</span> <span class="n">step_log_norm</span> <span class="o">=</span> <span class="n">_condition_on</span><span class="p">(</span><span class="n">prior_probs</span><span class="p">,</span> <span class="n">log_likelihoods</span><span class="p">[</span><span class="n">t</span><span class="p">])</span>
  712.         <span class="n">total_log_norm</span> <span class="o">=</span> <span class="n">total_log_norm</span> <span class="o">+</span> <span class="n">step_log_norm</span>
  713.         <span class="n">next_prior</span> <span class="o">=</span> <span class="n">_predict</span><span class="p">(</span><span class="n">posterior_probs</span><span class="p">,</span> <span class="n">A</span><span class="p">)</span>  <span class="c1"># alpha_{t+1|t}</span>
  714.         <span class="k">return</span> <span class="p">(</span><span class="n">total_log_norm</span><span class="p">,</span> <span class="n">next_prior</span><span class="p">),</span> <span class="p">(</span><span class="n">posterior_probs</span><span class="p">,</span> <span class="n">next_prior</span><span class="p">)</span>
  715.     <span class="n">init_carry</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">initial_distribution</span><span class="p">)</span>
  716.     <span class="p">(</span><span class="n">log_normalizer</span><span class="p">,</span> <span class="n">_</span><span class="p">),</span> <span class="p">(</span><span class="n">filtered_probs</span><span class="p">,</span> <span class="n">predicted_probs</span><span class="p">)</span> <span class="o">=</span> <span class="n">lax</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span>
  717.         <span class="n">_step</span><span class="p">,</span> <span class="n">init_carry</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">log_likelihoods</span><span class="p">)))</span>
  718.     <span class="k">return</span> <span class="n">log_normalizer</span><span class="p">,</span> <span class="n">filtered_probs</span><span class="p">,</span> <span class="n">predicted_probs</span>
  719. </pre></div>
  720. </div>
  721. </div>
  722. </div>
  723. <p>TODO: check equivalence of these two implementations!</p>
  724. </div>
  725. </div>
  726. <script type="text/x-thebe-config">
  727. {
  728. requestKernel: true,
  729. binderOptions: {
  730. repo: "binder-examples/jupyter-stacks-datascience",
  731. ref: "master",
  732. },
  733. codeMirrorConfig: {
  734. theme: "abcdef",
  735. mode: "python"
  736. },
  737. kernelOptions: {
  738. kernelName: "python3",
  739. path: "./chapters/hmm"
  740. },
  741. predefinedOutput: true
  742. }
  743. </script>
  744. <script>kernelName = 'python3'</script>
  745. </div>
  746. <!-- Previous / next buttons -->
  747. <div class='prev-next-area'>
  748. <a class='left-prev' id="prev-link" href="hmm_index.html" title="previous page">
  749. <i class="fas fa-angle-left"></i>
  750. <div class="prev-next-info">
  751. <p class="prev-next-subtitle">previous</p>
  752. <p class="prev-next-title">Hidden Markov Models</p>
  753. </div>
  754. </a>
  755. <a class='right-next' id="next-link" href="hmm_smoother.html" title="next page">
  756. <div class="prev-next-info">
  757. <p class="prev-next-subtitle">next</p>
  758. <p class="prev-next-title">HMM smoothing (forwards-backwards algorithm)</p>
  759. </div>
  760. <i class="fas fa-angle-right"></i>
  761. </a>
  762. </div>
  763. </div>
  764. </div>
  765. <footer class="footer">
  766. <p>
  767. By Kevin Murphy, Scott Linderman, et al.<br/>
  768. &copy; Copyright 2021.<br/>
  769. </p>
  770. </footer>
  771. </main>
  772. </div>
  773. </div>
  774. <script src="../../_static/js/index.be7d3bbb2ef33a8344ce.js"></script>
  775. </body>
  776. </html>