hmm.html 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985
  1. <!DOCTYPE html>
  2. <html lang="en">
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>Hidden Markov Models &#8212; State Space Models: A Modern Approach</title>
  7. <link href="../../_static/css/theme.css" rel="stylesheet">
  8. <link href="../../_static/css/index.ff1ffe594081f20da1ef19478df9384b.css" rel="stylesheet">
  9. <link rel="stylesheet"
  10. href="../../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  11. <link rel="preload" as="font" type="font/woff2" crossorigin
  12. href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  13. <link rel="preload" as="font" type="font/woff2" crossorigin
  14. href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
  15. <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" />
  16. <link rel="stylesheet" type="text/css" href="../../_static/sphinx-book-theme.css?digest=c3fdc42140077d1ad13ad2f1588a4309" />
  17. <link rel="stylesheet" type="text/css" href="../../_static/togglebutton.css" />
  18. <link rel="stylesheet" type="text/css" href="../../_static/copybutton.css" />
  19. <link rel="stylesheet" type="text/css" href="../../_static/mystnb.css" />
  20. <link rel="stylesheet" type="text/css" href="../../_static/sphinx-thebe.css" />
  21. <link rel="stylesheet" type="text/css" href="../../_static/panels-main.c949a650a448cc0ae9fd3441c0e17fb0.css" />
  22. <link rel="stylesheet" type="text/css" href="../../_static/panels-variables.06eb56fa6e07937060861dad626602ad.css" />
  23. <link rel="preload" as="script" href="../../_static/js/index.be7d3bbb2ef33a8344ce.js">
  24. <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
  25. <script src="../../_static/jquery.js"></script>
  26. <script src="../../_static/underscore.js"></script>
  27. <script src="../../_static/doctools.js"></script>
  28. <script src="../../_static/clipboard.min.js"></script>
  29. <script src="../../_static/copybutton.js"></script>
  30. <script>let toggleHintShow = 'Click to show';</script>
  31. <script>let toggleHintHide = 'Click to hide';</script>
  32. <script>let toggleOpenOnPrint = 'true';</script>
  33. <script src="../../_static/togglebutton.js"></script>
  34. <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
  35. <script src="../../_static/sphinx-book-theme.d59cb220de22ca1c485ebbdc042f0030.js"></script>
  36. <script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"
  37. const thebe_selector = ".thebe,.cell"
  38. const thebe_selector_input = "pre"
  39. const thebe_selector_output = ".output, .cell_output"
  40. </script>
  41. <script async="async" src="../../_static/sphinx-thebe.js"></script>
  42. <script>window.MathJax = {"tex": {"macros": {"covMat": "\\boldsymbol{\\Sigma}", "data": "\\mathcal{D}", "defeq": "\\triangleq", "diag": "\\mathrm{diag}", "discreteState": "s", "dotstar": "\\odot", "dynamicsFn": "\\mathbf{f}", "floor": ["\\lfloor#1\\rfloor", 1], "gainMatrix": "\\mathbf{K}", "gainMatrixReverse": "\\mathbf{G}", "gauss": "\\mathcal{N}", "gaussInfo": "\\mathcal{N}_{\\text{info}}", "hidden": "\\mathbf{x}", "hiddenScalar": "x", "hmmInit": "\\boldsymbol{\\pi}", "hmmInitScalar": "\\pi", "hmmObs": "\\mathbf{B}", "hmmObsScalar": "B", "hmmTrans": "\\mathbf{A}", "hmmTransScalar": "A", "infoMat": "\\precMat", "input": "\\mathbf{u}", "inputs": "\\input", "inv": ["{#1}^{-1}", 1], "keyword": ["\\textbf{#1}", 1], "ldsDyn": "\\mathbf{F}", "ldsDynIn": "\\mathbf{B}", "initMean": "\\boldsymbol{\\mean}_0", "initCov": "\\boldsymbol{\\covMat}_0", "ldsObs": "\\mathbf{H}", "ldsObsIn": "\\mathbf{D}", "ldsTrans": "\\ldsDyn", "ldsTransIn": "\\ldsDynIn", "obsCov": "\\mathbf{R}", "obsNoise": "\\boldsymbol{r}", "map": "\\mathrm{map}", "measurementFn": "\\mathbf{h}", "mean": "\\boldsymbol{\\mu}", "mle": "\\mathrm{mle}", "nlatents": "n_x", "nhidden": "\\nlatents", "ninputs": "n_u", "nobs": "n_y", "nsymbols": "n_y", "nstates": "n_s", "obs": "\\mathbf{y}", "obsScalar": "y", "observed": "\\obs", "obsFn": "\\measurementFn", "params": "\\boldsymbol{\\theta}", "precMean": "\\boldsymbol{\\eta}", "precMat": "\\boldsymbol{\\Lambda}", "real": "\\mathbb{R}", "sigmoid": "\\sigma", "softmax": "\\boldsymbol{\\sigma}", "trans": "\\mathsf{T}", "transpose": ["{#1}^{\\trans}", 1], "transCov": "\\mathbf{Q}", "transNoise": "\\mathbf{q}", "valpha": "\\boldsymbol{\\alpha}", "vbeta": "\\boldsymbol{\\beta}", "vdelta": "\\boldsymbol{\\delta}", "vepsilon": "\\boldsymbol{\\epsilon}", "vlambda": "\\boldsymbol{\\lambda}", "vLambda": "\\boldsymbol{\\Lambda}", "vmu": "\\boldsymbol{\\mu}", "vpi": "\\boldsymbol{\\pi}", "vsigma": "\\boldsymbol{\\sigma}", "vSigma": "\\boldsymbol{\\Sigma}", "vone": 
"\\boldsymbol{1}", "vzero": "\\boldsymbol{0}"}}, "options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
  43. <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
  44. <link rel="index" title="Index" href="../../genindex.html" />
  45. <link rel="search" title="Search" href="../../search.html" />
  46. <link rel="next" title="Linear Gaussian SSMs" href="lds.html" />
  47. <link rel="prev" title="What are State Space Models?" href="ssm_intro.html" />
  48. <meta name="viewport" content="width=device-width, initial-scale=1" />
  49. <meta name="docsearch:language" content="en">
  50. <!-- Google Analytics -->
  51. </head>
  52. <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80">
  53. <div class="container-fluid" id="banner"></div>
  54. <div class="container-xl">
  55. <div class="row">
  56. <div class="col-12 col-md-3 bd-sidebar site-navigation show" id="site-navigation">
  57. <div class="navbar-brand-box">
  58. <a class="navbar-brand text-wrap" href="../../index.html">
  59. <h1 class="site-logo" id="site-title">State Space Models: A Modern Approach</h1>
  60. </a>
  61. </div><form class="bd-search d-flex align-items-center" action="../../search.html" method="get">
  62. <i class="icon fas fa-search"></i>
  63. <input type="search" class="form-control" name="q" id="search-input" placeholder="Search this book..." aria-label="Search this book..." autocomplete="off" >
  64. </form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
  65. <div class="bd-toc-item active">
  66. <ul class="nav bd-sidenav">
  67. <li class="toctree-l1">
  68. <a class="reference internal" href="../../root.html">
  69. State Space Models: A Modern Approach
  70. </a>
  71. </li>
  72. </ul>
  73. <ul class="current nav bd-sidenav">
  74. <li class="toctree-l1 current active has-children">
  75. <a class="reference internal" href="ssm_index.html">
  76. State Space Models
  77. </a>
  78. <input checked="" class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  79. <label for="toctree-checkbox-1">
  80. <i class="fas fa-chevron-down">
  81. </i>
  82. </label>
  83. <ul class="current">
  84. <li class="toctree-l2">
  85. <a class="reference internal" href="ssm_intro.html">
  86. What are State Space Models?
  87. </a>
  88. </li>
  89. <li class="toctree-l2 current active">
  90. <a class="current reference internal" href="#">
  91. Hidden Markov Models
  92. </a>
  93. </li>
  94. <li class="toctree-l2">
  95. <a class="reference internal" href="lds.html">
  96. Linear Gaussian SSMs
  97. </a>
  98. </li>
  99. <li class="toctree-l2">
  100. <a class="reference internal" href="nlds.html">
  101. Nonlinear Gaussian SSMs
  102. </a>
  103. </li>
  104. <li class="toctree-l2">
  105. <a class="reference internal" href="inference.html">
  106. State estimation (inference)
  107. </a>
  108. </li>
  109. <li class="toctree-l2">
  110. <a class="reference internal" href="learning.html">
  111. Parameter estimation (learning)
  112. </a>
  113. </li>
  114. </ul>
  115. </li>
  116. <li class="toctree-l1 has-children">
  117. <a class="reference internal" href="../hmm/hmm_index.html">
  118. Hidden Markov Models
  119. </a>
  120. <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  121. <label for="toctree-checkbox-2">
  122. <i class="fas fa-chevron-down">
  123. </i>
  124. </label>
  125. <ul>
  126. <li class="toctree-l2">
  127. <a class="reference internal" href="../hmm/hmm_filter.html">
  128. HMM filtering (forwards algorithm)
  129. </a>
  130. </li>
  131. <li class="toctree-l2">
  132. <a class="reference internal" href="../hmm/hmm_smoother.html">
  133. HMM smoothing (forwards-backwards algorithm)
  134. </a>
  135. </li>
  136. <li class="toctree-l2">
  137. <a class="reference internal" href="../hmm/hmm_viterbi.html">
  138. Viterbi algorithm
  139. </a>
  140. </li>
  141. <li class="toctree-l2">
  142. <a class="reference internal" href="../hmm/hmm_parallel.html">
  143. Parallel HMM smoothing
  144. </a>
  145. </li>
  146. <li class="toctree-l2">
  147. <a class="reference internal" href="../hmm/hmm_sampling.html">
  148. Forwards-filtering backwards-sampling algorithm
  149. </a>
  150. </li>
  151. </ul>
  152. </li>
  153. <li class="toctree-l1 has-children">
  154. <a class="reference internal" href="../lgssm/lgssm_index.html">
  155. Linear-Gaussian SSMs
  156. </a>
  157. <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  158. <label for="toctree-checkbox-3">
  159. <i class="fas fa-chevron-down">
  160. </i>
  161. </label>
  162. <ul>
  163. <li class="toctree-l2">
  164. <a class="reference internal" href="../lgssm/kalman_filter.html">
  165. Kalman filtering
  166. </a>
  167. </li>
  168. <li class="toctree-l2">
  169. <a class="reference internal" href="../lgssm/kalman_smoother.html">
  170. Kalman (RTS) smoother
  171. </a>
  172. </li>
  173. <li class="toctree-l2">
  174. <a class="reference internal" href="../lgssm/kalman_parallel.html">
  175. Parallel Kalman Smoother
  176. </a>
  177. </li>
  178. <li class="toctree-l2">
  179. <a class="reference internal" href="../lgssm/kalman_sampling.html">
  180. Forwards-filtering backwards sampling
  181. </a>
  182. </li>
  183. </ul>
  184. </li>
  185. <li class="toctree-l1 has-children">
  186. <a class="reference internal" href="../extended/extended_index.html">
  187. Extended (linearized) methods
  188. </a>
  189. <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  190. <label for="toctree-checkbox-4">
  191. <i class="fas fa-chevron-down">
  192. </i>
  193. </label>
  194. <ul>
  195. <li class="toctree-l2">
  196. <a class="reference internal" href="../extended/extended_filter.html">
  197. Extended Kalman filtering
  198. </a>
  199. </li>
  200. <li class="toctree-l2">
  201. <a class="reference internal" href="../extended/extended_smoother.html">
  202. Extended Kalman smoother
  203. </a>
  204. </li>
  205. <li class="toctree-l2">
  206. <a class="reference internal" href="../extended/extended_parallel.html">
  207. Parallel extended Kalman smoothing
  208. </a>
  209. </li>
  210. </ul>
  211. </li>
  212. <li class="toctree-l1 has-children">
  213. <a class="reference internal" href="../unscented/unscented_index.html">
  214. Unscented methods
  215. </a>
  216. <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  217. <label for="toctree-checkbox-5">
  218. <i class="fas fa-chevron-down">
  219. </i>
  220. </label>
  221. <ul>
  222. <li class="toctree-l2">
  223. <a class="reference internal" href="../unscented/unscented_filter.html">
  224. Unscented filtering
  225. </a>
  226. </li>
  227. <li class="toctree-l2">
  228. <a class="reference internal" href="../unscented/unscented_smoother.html">
  229. Unscented smoothing
  230. </a>
  231. </li>
  232. </ul>
  233. </li>
  234. <li class="toctree-l1">
  235. <a class="reference internal" href="../quadrature/quadrature_index.html">
  236. Quadrature and cubature methods
  237. </a>
  238. </li>
  239. <li class="toctree-l1">
  240. <a class="reference internal" href="../postlin/postlin_index.html">
  241. Posterior linearization
  242. </a>
  243. </li>
  244. <li class="toctree-l1">
  245. <a class="reference internal" href="../adf/adf_index.html">
  246. Assumed Density Filtering
  247. </a>
  248. </li>
  249. <li class="toctree-l1">
  250. <a class="reference internal" href="../vi/vi_index.html">
  251. Variational inference
  252. </a>
  253. </li>
  254. <li class="toctree-l1">
  255. <a class="reference internal" href="../pf/pf_index.html">
  256. Particle filtering
  257. </a>
  258. </li>
  259. <li class="toctree-l1">
  260. <a class="reference internal" href="../smc/smc_index.html">
  261. Sequential Monte Carlo
  262. </a>
  263. </li>
  264. <li class="toctree-l1">
  265. <a class="reference internal" href="../learning/learning_index.html">
  266. Offline parameter estimation (learning)
  267. </a>
  268. </li>
  269. <li class="toctree-l1">
  270. <a class="reference internal" href="../tracking/tracking_index.html">
  271. Multi-target tracking
  272. </a>
  273. </li>
  274. <li class="toctree-l1">
  275. <a class="reference internal" href="../ensemble/ensemble_index.html">
  276. Data assimilation using Ensemble Kalman filter
  277. </a>
  278. </li>
  279. <li class="toctree-l1">
  280. <a class="reference internal" href="../bnp/bnp_index.html">
  281. Bayesian non-parametric SSMs
  282. </a>
  283. </li>
  284. <li class="toctree-l1">
  285. <a class="reference internal" href="../changepoint/changepoint_index.html">
  286. Changepoint detection
  287. </a>
  288. </li>
  289. <li class="toctree-l1">
  290. <a class="reference internal" href="../timeseries/timeseries_index.html">
  291. Timeseries forecasting
  292. </a>
  293. </li>
  294. <li class="toctree-l1">
  295. <a class="reference internal" href="../gp/gp_index.html">
  296. Markovian Gaussian processes
  297. </a>
  298. </li>
  299. <li class="toctree-l1">
  300. <a class="reference internal" href="../ode/ode_index.html">
  301. Differential equations and SSMs
  302. </a>
  303. </li>
  304. <li class="toctree-l1">
  305. <a class="reference internal" href="../control/control_index.html">
  306. Optimal control
  307. </a>
  308. </li>
  309. <li class="toctree-l1">
  310. <a class="reference internal" href="../../bib.html">
  311. Bibliography
  312. </a>
  313. </li>
  314. </ul>
  315. </div>
  316. </nav> <!-- To handle the deprecated key -->
  317. <div class="navbar_extra_footer">
  318. Powered by <a href="https://jupyterbook.org">Jupyter Book</a>
  319. </div>
  320. </div>
  321. <main class="col py-md-3 pl-md-4 bd-content overflow-auto" role="main">
  322. <div class="topbar container-xl fixed-top">
  323. <div class="topbar-contents row">
  324. <div class="col-12 col-md-3 bd-topbar-whitespace site-navigation show"></div>
  325. <div class="col pl-md-4 topbar-main">
  326. <button id="navbar-toggler" class="navbar-toggler ml-0" type="button" data-toggle="collapse"
  327. data-placement="bottom" data-target=".site-navigation"
  328. aria-expanded="true" aria-label="Toggle navigation" aria-controls="site-navigation"
  329. title="Toggle navigation">
  330. <i class="fas fa-bars"></i>
  331. <i class="fas fa-arrow-left"></i>
  332. <i class="fas fa-arrow-up"></i>
  333. </button>
  334. <div class="dropdown-buttons-trigger">
  335. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn" aria-label="Download this page"><i
  336. class="fas fa-download"></i></button>
  337. <div class="dropdown-buttons">
  338. <!-- ipynb file if we had a myst markdown file -->
  339. <!-- Download raw file -->
  340. <a class="dropdown-buttons" href="../../_sources/chapters/ssm/hmm.ipynb"><button type="button"
  341. class="btn btn-secondary topbarbtn" title="Download source file" data-toggle="tooltip"
  342. data-placement="left">.ipynb</button></a>
  343. <!-- Download PDF via print -->
  344. <button type="button" id="download-print" class="btn btn-secondary topbarbtn" title="Print to PDF"
  345. onclick="printPdf(this)" data-toggle="tooltip" data-placement="left">.pdf</button>
  346. </div>
  347. </div>
  348. <!-- Source interaction buttons -->
  349. <div class="dropdown-buttons-trigger">
  350. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  351. aria-label="Connect with source repository"><i class="fab fa-github"></i></button>
  352. <div class="dropdown-buttons sourcebuttons">
  353. <a class="repository-button"
  354. href="https://github.com/probml/ssm-book"><button type="button" class="btn btn-secondary topbarbtn"
  355. data-toggle="tooltip" data-placement="left" title="Source repository"><i
  356. class="fab fa-github"></i>repository</button></a>
  357. <a class="issues-button"
  358. href="https://github.com/probml/ssm-book/issues/new?title=Issue%20on%20page%20%2Fchapters/ssm/hmm.html&body=Your%20issue%20content%20here."><button
  359. type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip" data-placement="left"
  360. title="Open an issue"><i class="fas fa-lightbulb"></i>open issue</button></a>
  361. </div>
  362. </div>
  363. <!-- Full screen (wrap in <a> to have style consistency -->
  364. <a class="full-screen-button"><button type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip"
  365. data-placement="bottom" onclick="toggleFullScreen()" aria-label="Fullscreen mode"
  366. title="Fullscreen mode"><i
  367. class="fas fa-expand"></i></button></a>
  368. <!-- Launch buttons -->
  369. <div class="dropdown-buttons-trigger">
  370. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  371. aria-label="Launch interactive content"><i class="fas fa-rocket"></i></button>
  372. <div class="dropdown-buttons">
  373. <a class="binder-button" href="https://mybinder.org/v2/gh/probml/ssm-book/main?urlpath=tree/chapters/ssm/hmm.ipynb"><button type="button"
  374. class="btn btn-secondary topbarbtn" title="Launch Binder" data-toggle="tooltip"
  375. data-placement="left"><img class="binder-button-logo"
  376. src="../../_static/images/logo_binder.svg"
  377. alt="Interact on binder">Binder</button></a>
  378. <a class="colab-button" href="https://colab.research.google.com/github/probml/ssm-book/blob/main/chapters/ssm/hmm.ipynb"><button type="button" class="btn btn-secondary topbarbtn"
  379. title="Launch Colab" data-toggle="tooltip" data-placement="left"><img class="colab-button-logo"
  380. src="../../_static/images/logo_colab.png"
  381. alt="Interact on Colab">Colab</button></a>
  382. </div>
  383. </div>
  384. </div>
  385. <!-- Table of contents -->
  386. <div class="d-none d-md-block col-md-2 bd-toc show noprint">
  387. <div class="tocsection onthispage pt-5 pb-3">
  388. <i class="fas fa-list"></i> Contents
  389. </div>
  390. <nav id="bd-toc-nav" aria-label="Page">
  391. <ul class="visible nav section-nav flex-column">
  392. <li class="toc-h2 nav-item toc-entry">
  393. <a class="reference internal nav-link" href="#example-casino-hmm">
  394. Example: Casino HMM
  395. </a>
  396. </li>
  397. <li class="toc-h2 nav-item toc-entry">
  398. <a class="reference internal nav-link" href="#example-lillypad-hmm">
  399. Example: Lillypad HMM
  400. </a>
  401. </li>
  402. </ul>
  403. </nav>
  404. </div>
  405. </div>
  406. </div>
  407. <div id="main-content" class="row">
  408. <div class="col-12 col-md-9 pl-md-3 pr-md-0">
  409. <!-- Table of contents that is only displayed when printing the page -->
  410. <div id="jb-print-docs-body" class="onlyprint">
  411. <h1>Hidden Markov Models</h1>
  412. <!-- Table of contents -->
  413. <div id="print-main-content">
  414. <div id="jb-print-toc">
  415. <div>
  416. <h2> Contents </h2>
  417. </div>
  418. <nav aria-label="Page">
  419. <ul class="visible nav section-nav flex-column">
  420. <li class="toc-h2 nav-item toc-entry">
  421. <a class="reference internal nav-link" href="#example-casino-hmm">
  422. Example: Casino HMM
  423. </a>
  424. </li>
  425. <li class="toc-h2 nav-item toc-entry">
  426. <a class="reference internal nav-link" href="#example-lillypad-hmm">
  427. Example: Lillypad HMM
  428. </a>
  429. </li>
  430. </ul>
  431. </nav>
  432. </div>
  433. </div>
  434. </div>
  435. <div>
  436. <div class="tex2jax_ignore mathjax_ignore section" id="hidden-markov-models">
  437. <span id="sec-hmm-intro"></span><h1>Hidden Markov Models<a class="headerlink" href="#hidden-markov-models" title="Permalink to this headline">¶</a></h1>
  438. <p>In this section, we discuss the
  439. hidden Markov model or HMM,
  440. which is a state space model in which the hidden states
  441. are discrete, so <span class="math notranslate nohighlight">\(\hidden_t \in \{1,\ldots, \nstates\}\)</span>.
  442. The observations may be discrete,
  443. <span class="math notranslate nohighlight">\(\obs_t \in \{1,\ldots, \nsymbols\}\)</span>,
  444. or continuous,
  445. <span class="math notranslate nohighlight">\(\obs_t \in \real^{\nobs}\)</span>,
  446. or some combination,
  447. as we illustrate below.
  448. More details can be found in, e.g.,
  449. <span id="id1">[<a class="reference internal" href="../../bib.html#id34" title="O. Cappe, E. Moulines, and T. Ryden. Inference in Hidden Markov Models. Springer, 2005.">CMR05</a>, <a class="reference internal" href="../../bib.html#id33" title="A. Fraser. Hidden Markov Models and Dynamical Systems. SIAM Press, 2008.">Fra08</a>, <a class="reference internal" href="../../bib.html#id32" title="L. R. Rabiner. A tutorial on Hidden Markov Models and selected applications in speech recognition. Proc. of the IEEE, 77(2):257–286, 1989.">Rab89</a>]</span>.
  450. For an interactive introduction,
  451. see <a class="reference external" href="https://nipunbatra.github.io/hmm/">https://nipunbatra.github.io/hmm/</a>.</p>
  452. <div class="cell tag_hide-cell docutils container">
  453. <div class="cell_input docutils container">
  454. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1">### Import standard libraries</span>
  460. <span class="kn">import</span> <span class="nn">abc</span>
  461. <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
  462. <span class="kn">import</span> <span class="nn">functools</span>
  463. <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
  464. <span class="kn">import</span> <span class="nn">itertools</span>
  465. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
  466. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  467. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
  468. <span class="kn">import</span> <span class="nn">jax</span>
  469. <span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
  470. <span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">lax</span><span class="p">,</span> <span class="n">vmap</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">grad</span>
  471. <span class="c1">#from jax.scipy.special import logit</span>
  472. <span class="c1">#from jax.nn import softmax</span>
  473. <span class="kn">import</span> <span class="nn">jax.random</span> <span class="k">as</span> <span class="nn">jr</span>
  474. <span class="kn">import</span> <span class="nn">distrax</span>
  475. <span class="kn">import</span> <span class="nn">optax</span>
  476. <span class="kn">import</span> <span class="nn">jsl</span>
  477. <span class="kn">import</span> <span class="nn">ssm_jax</span>
  478. <span class="kn">import</span> <span class="nn">inspect</span>
  479. <span class="kn">import</span> <span class="nn">inspect</span> <span class="k">as</span> <span class="nn">py_inspect</span>
  480. <span class="kn">import</span> <span class="nn">rich</span>
  481. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="n">inspect</span> <span class="k">as</span> <span class="n">r_inspect</span>
  482. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="nb">print</span> <span class="k">as</span> <span class="n">r_print</span>
  483. <span class="k">def</span> <span class="nf">print_source</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span>
  484. <span class="n">r_print</span><span class="p">(</span><span class="n">py_inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">fname</span><span class="p">))</span>
  485. </pre></div>
  486. </div>
  487. </div>
  488. </div>
  489. <div class="section" id="example-casino-hmm">
  490. <h2>Example: Casino HMM<a class="headerlink" href="#example-casino-hmm" title="Permalink to this headline">¶</a></h2>
  491. <p>To illustrate HMMs with a categorical observation model,
  492. we consider the “Occasionally dishonest casino” model from <span id="id2">[<a class="reference internal" href="../../bib.html#id3" title="R. Durbin, S. Eddy, A. Krogh, and G. Mitchison. Biological Sequence Analysis: Probabilistic Models of Proteins and Nucleic Acids. Cambridge University Press, 1998.">DEKM98</a>]</span>.
  493. There are 2 hidden states, representing whether the die being used in the casino is fair or loaded.
  494. Each state defines a distribution over the 6 possible observations.</p>
  495. <p>The transition model is denoted by</p>
  496. <div class="math notranslate nohighlight">
  497. \[p(\hidden_t=j|\hidden_{t-1}=i) = \hmmTransScalar_{ij}\]</div>
  498. <p>Here the <span class="math notranslate nohighlight">\(i\)</span>’th row of <span class="math notranslate nohighlight">\(\hmmTrans\)</span> corresponds to the outgoing distribution from state <span class="math notranslate nohighlight">\(i\)</span>.
  499. This is a row stochastic matrix,
  500. meaning each row sums to one.
  501. We can visualize
  502. the non-zero entries in the transition matrix by creating a state transition diagram,
  503. as shown in
  504. <a class="reference internal" href="#fig-casino"><span class="std std-numref">Fig. 3</span></a>.</p>
  505. <div class="figure align-default" id="fig-casino">
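  506. <a class="reference internal image-reference" href="../../_images/casino.png"><img alt="State transition diagram of the casino HMM, showing the fair and loaded states with a histogram of the observation distribution for each state" src="../../_images/casino.png" style="width: 208.5px; height: 142.5px;" /></a>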
  507. <p class="caption"><span class="caption-number">Fig. 3 </span><span class="caption-text">Illustration of the casino HMM.</span><a class="headerlink" href="#fig-casino" title="Permalink to this image">¶</a></p>
  508. </div>
  509. <p>The observation model
  510. <span class="math notranslate nohighlight">\(p(\obs_t|\hidden_t=j)\)</span> has the form</p>
  511. <div class="math notranslate nohighlight">
  512. \[p(\obs_t=k|\hidden_t=j) = \hmmObsScalar_{jk} \]</div>
  513. <p>This is represented by the histograms associated with each
  514. state in <a class="reference internal" href="#fig-casino"><span class="std std-numref">Fig. 3</span></a>.</p>
  515. <p>Finally,
  516. the initial state distribution is denoted by</p>
  517. <div class="math notranslate nohighlight">
  518. \[p(\hidden_1=j) = \hmmInitScalar_j\]</div>
  519. <p>Collectively we denote all the parameters by <span class="math notranslate nohighlight">\(\params=(\hmmTrans, \hmmObs, \hmmInit)\)</span>.</p>
  520. <p>Now let us implement this model in code.</p>
  521. <div class="cell docutils container">
  522. <div class="cell_input docutils container">
  523. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># state transition matrix</span>
  524. <span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  525. <span class="p">[</span><span class="mf">0.95</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">],</span>
  526. <span class="p">[</span><span class="mf">0.10</span><span class="p">,</span> <span class="mf">0.90</span><span class="p">]</span>
  527. <span class="p">])</span>
  528. <span class="c1"># observation matrix</span>
  529. <span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  530. <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">],</span> <span class="c1"># fair die</span>
  531. <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="o">/</span><span class="mi">10</span><span class="p">]</span> <span class="c1"># loaded die</span>
  532. <span class="p">])</span>
  533. <span class="n">pi</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">])</span>
  534. <span class="p">(</span><span class="n">nstates</span><span class="p">,</span> <span class="n">nobs</span><span class="p">)</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
  535. </pre></div>
  536. </div>
  537. </div>
  538. </div>
  539. <div class="cell docutils container">
  540. <div class="cell_input docutils container">
  541. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">distrax</span>
  542. <span class="kn">from</span> <span class="nn">distrax</span> <span class="kn">import</span> <span class="n">HMM</span>
  543. <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
  544. <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">pi</span><span class="p">),</span>
  545. <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">B</span><span class="p">))</span>
  546. <span class="nb">print</span><span class="p">(</span><span class="n">hmm</span><span class="p">)</span>
  547. </pre></div>
  548. </div>
  549. </div>
  550. <div class="cell_output docutils container">
  551. <div class="output stderr highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  552. </pre></div>
  553. </div>
  554. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>&lt;distrax._src.utils.hmm.HMM object at 0x7fbca906ad60&gt;
  555. </pre></div>
  556. </div>
  557. </div>
  558. </div>
  559. <p>Let’s sample from the model. We will generate a sequence of latent states, <span class="math notranslate nohighlight">\(\hidden_{1:T}\)</span>,
  560. and then, conditioned on these states, sample a sequence of observations, <span class="math notranslate nohighlight">\(\obs_{1:T}\)</span>.</p>
  561. <div class="cell docutils container">
  562. <div class="cell_input docutils container">
  563. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">seed</span> <span class="o">=</span> <span class="mi">314</span>
  564. <span class="n">n_samples</span> <span class="o">=</span> <span class="mi">300</span>
  565. <span class="n">z_hist</span><span class="p">,</span> <span class="n">x_hist</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">jr</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">n_samples</span><span class="p">)</span>
  566. <span class="n">z_hist_str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">z_hist</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">str</span><span class="p">))[:</span><span class="mi">60</span><span class="p">]</span>
  567. <span class="n">x_hist_str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">str</span><span class="p">))[:</span><span class="mi">60</span><span class="p">]</span>
  568. <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Printing sample observed/latent...&quot;</span><span class="p">)</span>
  569. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;x: </span><span class="si">{</span><span class="n">x_hist_str</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  570. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;z: </span><span class="si">{</span><span class="n">z_hist_str</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  571. </pre></div>
  572. </div>
  573. </div>
  574. <div class="cell_output docutils container">
  575. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Printing sample observed/latent...
  576. x: 633665342652353616444236412331351246651613325161656366246242
  577. z: 222222211111111111111111111111111111111222111111112222211111
  578. </pre></div>
  579. </div>
  580. </div>
  581. </div>
  582. <p>Below is the source code for the sampling algorithm.</p>
  583. <div class="cell docutils container">
  584. <div class="cell_input docutils container">
  585. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">print_source</span><span class="p">(</span><span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">)</span>
  586. </pre></div>
  587. </div>
  588. </div>
  589. <div class="cell_output docutils container">
  590. <div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"> def sample<span style="font-weight: bold">(</span>self,
  591. *,
  592. seed: chex.PRNGKey,
  593. seq_len: chex.Array<span style="font-weight: bold">)</span> -&gt; Tuple:
  594. <span style="color: #008000; text-decoration-color: #008000">""</span>"Sample from this HMM.
  595. Samples an observation of given length according to this
  596. Hidden Markov Model and gives the sequence of the hidden states
  597. as well as the observation.
  598. Args:
  599. seed: Random key of shape <span style="font-weight: bold">(</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">2</span>,<span style="font-weight: bold">)</span> and dtype uint32.
  600. seq_len: The length of the observation sequence.
  601. Returns:
  602. Tuple of hidden state sequence, and observation sequence.
  603. <span style="color: #008000; text-decoration-color: #008000">""</span>"
  604. rng_key, rng_init = jax.random.split<span style="font-weight: bold">(</span>seed<span style="font-weight: bold">)</span>
  605. initial_state = self._init_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">rng_init</span><span style="font-weight: bold">)</span>
  606. def draw_state<span style="font-weight: bold">(</span>prev_state, key<span style="font-weight: bold">)</span>:
  607. state = self._trans_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key</span><span style="font-weight: bold">)</span>
  608. return state, state
  609. rng_state, rng_obs = jax.random.split<span style="font-weight: bold">(</span>rng_key<span style="font-weight: bold">)</span>
  610. keys = jax.random.split<span style="font-weight: bold">(</span>rng_state, seq_len - <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">1</span><span style="font-weight: bold">)</span>
  611. _, states = jax.lax.scan<span style="font-weight: bold">(</span>draw_state, initial_state, keys<span style="font-weight: bold">)</span>
  612. states = jnp.append<span style="font-weight: bold">(</span>initial_state, states<span style="font-weight: bold">)</span>
  613. def draw_obs<span style="font-weight: bold">(</span>state, key<span style="font-weight: bold">)</span>:
  614. return self._obs_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key</span><span style="font-weight: bold">)</span>
  615. keys = jax.random.split<span style="font-weight: bold">(</span>rng_obs, seq_len<span style="font-weight: bold">)</span>
  616. obs_seq = jax.vmap<span style="font-weight: bold">(</span>draw_obs, <span style="color: #808000; text-decoration-color: #808000">in_axes</span>=<span style="font-weight: bold">(</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">0</span>, <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">0</span><span style="font-weight: bold">))(</span>states, keys<span style="font-weight: bold">)</span>
  617. return states, obs_seq
  618. </pre>
  619. </div></div>
  620. </div>
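  621. <p>Let us check correctness by computing empirical pairwise statistics.</p>
  622. <p>We will count the number of i-&gt;j latent state transitions in a sampled sequence, normalize each row of counts, and check that the result is close to the true
  623. transition probabilities (here, for a simple two-state HMM with transition matrix <code>trans_mat</code>).</p>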
  624. <div class="cell docutils container">
  625. <div class="cell_input docutils container">
  626. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">collections</span>
  627. <span class="k">def</span> <span class="nf">compute_counts</span><span class="p">(</span><span class="n">state_seq</span><span class="p">,</span> <span class="n">nstates</span><span class="p">):</span>
  628. <span class="n">wseq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">state_seq</span><span class="p">)</span>
  629. <span class="n">word_pairs</span> <span class="o">=</span> <span class="p">[</span><span class="n">pair</span> <span class="k">for</span> <span class="n">pair</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">wseq</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">wseq</span><span class="p">[</span><span class="mi">1</span><span class="p">:])]</span>
  630. <span class="n">counter_pairs</span> <span class="o">=</span> <span class="n">collections</span><span class="o">.</span><span class="n">Counter</span><span class="p">(</span><span class="n">word_pairs</span><span class="p">)</span>
  631. <span class="n">counts</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">nstates</span><span class="p">,</span> <span class="n">nstates</span><span class="p">))</span>
  632. <span class="k">for</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span><span class="n">v</span><span class="p">)</span> <span class="ow">in</span> <span class="n">counter_pairs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  633. <span class="n">counts</span><span class="p">[</span><span class="n">k</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">k</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span> <span class="o">=</span> <span class="n">v</span>
  634. <span class="k">return</span> <span class="n">counts</span>
  635. <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-15</span><span class="p">):</span>
  636. <span class="n">u</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">u</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">u</span> <span class="o">&lt;</span> <span class="n">eps</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">u</span><span class="p">))</span>
  637. <span class="n">c</span> <span class="o">=</span> <span class="n">u</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="n">axis</span><span class="p">)</span>
  638. <span class="n">c</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">c</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
  639. <span class="k">return</span> <span class="n">u</span> <span class="o">/</span> <span class="n">c</span><span class="p">,</span> <span class="n">c</span>
  640. <span class="k">def</span> <span class="nf">normalize_counts</span><span class="p">(</span><span class="n">counts</span><span class="p">):</span>
  641. <span class="n">ncounts</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">v</span><span class="p">:</span> <span class="n">normalize</span><span class="p">(</span><span class="n">v</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> <span class="n">in_axes</span><span class="o">=</span><span class="mi">0</span><span class="p">)(</span><span class="n">counts</span><span class="p">)</span>
  642. <span class="k">return</span> <span class="n">ncounts</span>
  643. <span class="n">init_dist</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">])</span>
  644. <span class="n">trans_mat</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]])</span>
  645. <span class="n">obs_mat</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
  646. <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">trans_mat</span><span class="p">),</span>
  647. <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">init_dist</span><span class="p">),</span>
  648. <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">obs_mat</span><span class="p">))</span>
  649. <span class="n">rng_key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  650. <span class="n">seq_len</span> <span class="o">=</span> <span class="mi">500</span>
  651. <span class="n">state_seq</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">seq_len</span><span class="p">)</span>
  652. <span class="n">counts</span> <span class="o">=</span> <span class="n">compute_counts</span><span class="p">(</span><span class="n">state_seq</span><span class="p">,</span> <span class="n">nstates</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
  653. <span class="nb">print</span><span class="p">(</span><span class="n">counts</span><span class="p">)</span>
  654. <span class="n">trans_mat_empirical</span> <span class="o">=</span> <span class="n">normalize_counts</span><span class="p">(</span><span class="n">counts</span><span class="p">)</span>
  655. <span class="nb">print</span><span class="p">(</span><span class="n">trans_mat_empirical</span><span class="p">)</span>
  656. <span class="k">assert</span> <span class="n">jnp</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">trans_mat</span><span class="p">,</span> <span class="n">trans_mat_empirical</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-1</span><span class="p">)</span>
  657. </pre></div>
  658. </div>
  659. </div>
  660. <div class="cell_output docutils container">
  661. <div class="output traceback highlight-ipythontb notranslate"><div class="highlight"><pre><span></span><span class="gt">---------------------------------------------------------------------------</span>
  662. <span class="ne">NameError</span><span class="g g-Whitespace"> </span>Traceback (most recent call last)
  663. <span class="o">&lt;</span><span class="n">ipython</span><span class="o">-</span><span class="nb">input</span><span class="o">-</span><span class="mi">6</span><span class="o">-</span><span class="n">f054683fcd82</span><span class="o">&gt;</span> <span class="ow">in</span> <span class="o">&lt;</span><span class="n">module</span><span class="o">&gt;</span>
  664. <span class="g g-Whitespace"> </span><span class="mi">30</span> <span class="n">rng_key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  665. <span class="g g-Whitespace"> </span><span class="mi">31</span> <span class="n">seq_len</span> <span class="o">=</span> <span class="mi">500</span>
  666. <span class="ne">---&gt; </span><span class="mi">32</span> <span class="n">state_seq</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">seq_len</span><span class="p">)</span>
  667. <span class="g g-Whitespace"> </span><span class="mi">33</span>
  668. <span class="g g-Whitespace"> </span><span class="mi">34</span> <span class="n">counts</span> <span class="o">=</span> <span class="n">compute_counts</span><span class="p">(</span><span class="n">state_seq</span><span class="p">,</span> <span class="n">nstates</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
  669. <span class="ne">NameError</span>: name &#39;PRNGKey&#39; is not defined
  670. </pre></div>
  671. </div>
  672. </div>
  673. </div>
  674. <p>Our primary goal will be to infer the latent state from the observations,
  675. so that we can detect whether the casino is being dishonest. This will
  676. affect how we choose to gamble our money.
  677. We discuss various ways to perform this inference below.</p>
  678. </div>
  679. <div class="section" id="example-lillypad-hmm">
  680. <span id="sec-lillypad"></span><h2>Example: Lillypad HMM<a class="headerlink" href="#example-lillypad-hmm" title="Permalink to this headline">¶</a></h2>
  681. <p>If <span class="math notranslate nohighlight">\(\obs_t\)</span> is continuous, it is common to use a Gaussian
  682. observation model:</p>
  683. <div class="math notranslate nohighlight">
  684. \[p(\obs_t|\hidden_t=j) = \gauss(\obs_t|\vmu_j,\vSigma_j)\]</div>
  685. <p>This is sometimes called a Gaussian HMM.</p>
  686. <p>As a simple example, suppose we have an HMM with 3 hidden states,
  687. each of which generates a 2d Gaussian.
  688. We can represent these Gaussian distributions as 2d ellipses,
  689. which we plot below.
  690. We call these “lily pads”, because of their shape.
  691. We can imagine a frog hopping from one lily pad to another.
  692. (This analogy is due to the late Sam Roweis.)
  693. The frog will stay on a pad for a while (corresponding to remaining in the same
  694. discrete state <span class="math notranslate nohighlight">\(\hidden_t\)</span>), and then jump to a new pad
  695. (corresponding to a transition to a new state).
  696. The data we see are just the 2d points (e.g., water droplets)
  697. coming from near the pad that the frog is currently on.
  698. Thus this model is like a Gaussian mixture model,
  699. in that it generates clusters of observations,
  700. except now there is temporal correlation between the data points.</p>
  701. <p>Let us now illustrate this model in code.</p>
  702. <div class="cell docutils container">
  703. <div class="cell_input docutils container">
  704. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let us create the model</span>
  705. <span class="n">initial_probs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">])</span>
  706. <span class="c1"># transition matrix</span>
  707. <span class="n">A</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  708. <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
  709. <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
  710. <span class="p">[</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]</span>
  711. <span class="p">])</span>
  712. <span class="c1"># Observation model</span>
  713. <span class="n">mu_collection</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  714. <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
  715. <span class="p">[</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">],</span>
  716. <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">]</span>
  717. <span class="p">])</span>
  718. <span class="n">S1</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">1.1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">]])</span>
  719. <span class="n">S2</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">1.3</span><span class="p">]])</span>
  720. <span class="n">S3</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]])</span>
  721. <span class="n">cov_collection</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">S1</span><span class="p">,</span> <span class="n">S2</span><span class="p">,</span> <span class="n">S3</span><span class="p">])</span> <span class="o">/</span> <span class="mi">60</span>
  722. <span class="kn">import</span> <span class="nn">tensorflow_probability</span> <span class="k">as</span> <span class="nn">tfp</span>
  723. <span class="k">if</span> <span class="kc">False</span><span class="p">:</span>
  724. <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
  725. <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">initial_probs</span><span class="p">),</span>
  726. <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">MultivariateNormalFullCovariance</span><span class="p">(</span>
  727. <span class="n">loc</span><span class="o">=</span><span class="n">mu_collection</span><span class="p">,</span> <span class="n">covariance_matrix</span><span class="o">=</span><span class="n">cov_collection</span><span class="p">))</span>
  728. <span class="k">else</span><span class="p">:</span>
  729. <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
  730. <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">initial_probs</span><span class="p">),</span>
  731. <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">as_distribution</span><span class="p">(</span>
  732. <span class="n">tfp</span><span class="o">.</span><span class="n">substrates</span><span class="o">.</span><span class="n">jax</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">MultivariateNormalFullCovariance</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">mu_collection</span><span class="p">,</span>
  733. <span class="n">covariance_matrix</span><span class="o">=</span><span class="n">cov_collection</span><span class="p">)))</span>
  734. <span class="nb">print</span><span class="p">(</span><span class="n">hmm</span><span class="p">)</span>
  735. </pre></div>
  736. </div>
  737. </div>
  738. <div class="cell_output docutils container">
  739. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>&lt;distrax._src.utils.hmm.HMM object at 0x7fd658f7b7f0&gt;
  740. </pre></div>
  741. </div>
  742. </div>
  743. </div>
  744. <div class="cell docutils container">
  745. <div class="cell_input docutils container">
  746. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">n_samples</span><span class="p">,</span> <span class="n">seed</span> <span class="o">=</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">10</span>
  747. <span class="n">samples_state</span><span class="p">,</span> <span class="n">samples_obs</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">n_samples</span><span class="p">)</span>
  748. <span class="nb">print</span><span class="p">(</span><span class="n">samples_state</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
  749. <span class="nb">print</span><span class="p">(</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
  750. </pre></div>
  751. </div>
  752. </div>
  753. <div class="cell_output docutils container">
  754. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>(50,)
  755. (50, 2)
  756. </pre></div>
  757. </div>
  758. </div>
  759. </div>
  760. <div class="cell docutils container">
  761. <div class="cell_input docutils container">
  762. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let&#39;s plot the observed data in 2d</span>
  763. <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span>
  764. <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mf">1.2</span>
  765. <span class="n">colors</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;tab:green&quot;</span><span class="p">,</span> <span class="s2">&quot;tab:blue&quot;</span><span class="p">,</span> <span class="s2">&quot;tab:red&quot;</span><span class="p">]</span>
  766. <span class="k">def</span> <span class="nf">plot_2dhmm</span><span class="p">(</span><span class="n">hmm</span><span class="p">,</span> <span class="n">samples_obs</span><span class="p">,</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">colors</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">):</span>
  767. <span class="n">obs_dist</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">obs_dist</span>
  768. <span class="n">color_sample</span> <span class="o">=</span> <span class="p">[</span><span class="n">colors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">samples_state</span><span class="p">]</span>
  769. <span class="n">xs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
  770. <span class="n">ys</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
  771. <span class="n">v_prob</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">obs_dist</span><span class="o">.</span><span class="n">prob</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">])),</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
  772. <span class="n">z</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="n">v_prob</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="kc">None</span><span class="p">))(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">)</span>
  773. <span class="n">grid</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mgrid</span><span class="p">[</span><span class="n">xmin</span><span class="p">:</span><span class="n">xmax</span><span class="p">:</span><span class="n">step</span><span class="p">,</span> <span class="n">ymin</span><span class="p">:</span><span class="n">ymax</span><span class="p">:</span><span class="n">step</span><span class="p">]</span>
  774. <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">color</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">colors</span><span class="p">):</span>
  775. <span class="n">ax</span><span class="o">.</span><span class="n">contour</span><span class="p">(</span><span class="o">*</span><span class="n">grid</span><span class="p">,</span> <span class="n">z</span><span class="p">[:,</span> <span class="p">:,</span> <span class="n">k</span><span class="p">],</span> <span class="n">levels</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">colors</span><span class="o">=</span><span class="n">color</span><span class="p">,</span> <span class="n">linewidths</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
  776. <span class="n">ax</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="o">*</span><span class="p">(</span><span class="n">obs_dist</span><span class="o">.</span><span class="n">mean</span><span class="p">()[</span><span class="n">k</span><span class="p">]</span> <span class="o">+</span> <span class="mf">0.13</span><span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;$k$=</span><span class="si">{</span><span class="n">k</span> <span class="o">+</span> <span class="mi">1</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">13</span><span class="p">,</span> <span class="n">horizontalalignment</span><span class="o">=</span><span class="s2">&quot;right&quot;</span><span class="p">)</span>
  777. <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="o">*</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">&quot;black&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  778. <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="o">*</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">color_sample</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="mi">30</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
  779. <span class="k">return</span> <span class="n">ax</span><span class="p">,</span> <span class="n">color_sample</span>
  780. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  781. <span class="n">_</span><span class="p">,</span> <span class="n">color_sample</span> <span class="o">=</span> <span class="n">plot_2dhmm</span><span class="p">(</span><span class="n">hmm</span><span class="p">,</span> <span class="n">samples_obs</span><span class="p">,</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">colors</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">)</span>
  782. </pre></div>
  783. </div>
  784. </div>
  785. <div class="cell_output docutils container">
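  786. <img alt="Scatter plot of the 50 sampled 2D observations, joined in time order by a faint black line and colored by hidden state, over one density contour per state labelled k=1 to k=3" src="../../_images/hmm_15_0.png" />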
  787. </div>
  788. </div>
  789. <div class="cell docutils container">
  790. <div class="cell_input docutils container">
  791. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let&#39;s plot the hidden state sequence</span>
  792. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  793. <span class="n">ax</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">),</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">&quot;black&quot;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span>
  794. <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">),</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">color_sample</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
  795. </pre></div>
  796. </div>
  797. </div>
  798. <div class="cell_output docutils container">
  799. <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>&lt;matplotlib.collections.PathCollection at 0x7fd5e174b460&gt;
  800. </pre></div>
  801. </div>
  802. <img alt="Step plot of the sampled hidden state at each of the 50 time steps, with each point colored by its state" src="../../_images/hmm_16_1.png" />
  803. </div>
  804. </div>
  805. </div>
  806. </div>
  807. <script type="text/x-thebe-config">
  808. {
  809. requestKernel: true,
  810. binderOptions: {
  811. repo: "binder-examples/jupyter-stacks-datascience",
  812. ref: "master",
  813. },
  814. codeMirrorConfig: {
  815. theme: "abcdef",
  816. mode: "python"
  817. },
  818. kernelOptions: {
  819. kernelName: "python3",
  820. path: "./chapters/ssm"
  821. },
  822. predefinedOutput: true
  823. }
  824. </script>
  825. <script>kernelName = 'python3'</script>
  826. </div>
  827. <!-- Previous / next buttons -->
  828. <div class='prev-next-area'>
  829. <a class='left-prev' id="prev-link" href="ssm_intro.html" title="previous page">
  830. <i class="fas fa-angle-left"></i>
  831. <div class="prev-next-info">
  832. <p class="prev-next-subtitle">previous</p>
  833. <p class="prev-next-title">What are State Space Models?</p>
  834. </div>
  835. </a>
  836. <a class='right-next' id="next-link" href="lds.html" title="next page">
  837. <div class="prev-next-info">
  838. <p class="prev-next-subtitle">next</p>
  839. <p class="prev-next-title">Linear Gaussian SSMs</p>
  840. </div>
  841. <i class="fas fa-angle-right"></i>
  842. </a>
  843. </div>
  844. </div>
  845. </div>
  846. <footer class="footer">
  847. <p>
  848. By Kevin Murphy, Scott Linderman, et al.<br/>
  849. &copy; Copyright 2021.<br/>
  850. </p>
  851. </footer>
  852. </main>
  853. </div>
  854. </div>
  855. <script src="../../_static/js/index.be7d3bbb2ef33a8344ce.js"></script>
  856. </body>
  857. </html>