hmm.html 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995
  1. <!DOCTYPE html>
  2. <html lang="en">
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>Hidden Markov Models &#8212; State Space Models: A Modern Approach</title>
  7. <link href="../../_static/css/theme.css" rel="stylesheet">
  8. <link href="../../_static/css/index.ff1ffe594081f20da1ef19478df9384b.css" rel="stylesheet">
  9. <link rel="stylesheet"
  10. href="../../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  11. <link rel="preload" as="font" type="font/woff2" crossorigin
  12. href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  13. <link rel="preload" as="font" type="font/woff2" crossorigin
  14. href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
  15. <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" />
  16. <link rel="stylesheet" type="text/css" href="../../_static/sphinx-book-theme.css?digest=c3fdc42140077d1ad13ad2f1588a4309" />
  17. <link rel="stylesheet" type="text/css" href="../../_static/togglebutton.css" />
  18. <link rel="stylesheet" type="text/css" href="../../_static/copybutton.css" />
  19. <link rel="stylesheet" type="text/css" href="../../_static/mystnb.css" />
  20. <link rel="stylesheet" type="text/css" href="../../_static/sphinx-thebe.css" />
  21. <link rel="stylesheet" type="text/css" href="../../_static/panels-main.c949a650a448cc0ae9fd3441c0e17fb0.css" />
  22. <link rel="stylesheet" type="text/css" href="../../_static/panels-variables.06eb56fa6e07937060861dad626602ad.css" />
  23. <link rel="preload" as="script" href="../../_static/js/index.be7d3bbb2ef33a8344ce.js">
  24. <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
  25. <script src="../../_static/jquery.js"></script>
  26. <script src="../../_static/underscore.js"></script>
  27. <script src="../../_static/doctools.js"></script>
  28. <script src="../../_static/clipboard.min.js"></script>
  29. <script src="../../_static/copybutton.js"></script>
  30. <script>let toggleHintShow = 'Click to show';</script>
  31. <script>let toggleHintHide = 'Click to hide';</script>
  32. <script>let toggleOpenOnPrint = 'true';</script>
  33. <script src="../../_static/togglebutton.js"></script>
  34. <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
  35. <script src="../../_static/sphinx-book-theme.d59cb220de22ca1c485ebbdc042f0030.js"></script>
  36. <script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"
  37. const thebe_selector = ".thebe,.cell"
  38. const thebe_selector_input = "pre"
  39. const thebe_selector_output = ".output, .cell_output"
  40. </script>
  41. <script async="async" src="../../_static/sphinx-thebe.js"></script>
  42. <link rel="index" title="Index" href="../../genindex.html" />
  43. <link rel="search" title="Search" href="../../search.html" />
  44. <link rel="next" title="HMM filtering (forwards algorithm)" href="hmm_filter.html" />
  45. <link rel="prev" title="Inference in discrete SSMs" href="hmm_index.html" />
  46. <meta name="viewport" content="width=device-width, initial-scale=1" />
  47. <meta name="docsearch:language" content="en">
  48. <!-- Google Analytics -->
  49. </head>
  50. <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80">
  51. <div class="container-fluid" id="banner"></div>
  52. <div class="container-xl">
  53. <div class="row">
  54. <div class="col-12 col-md-3 bd-sidebar site-navigation show" id="site-navigation">
  55. <div class="navbar-brand-box">
  56. <a class="navbar-brand text-wrap" href="../../index.html">
  57. <h1 class="site-logo" id="site-title">State Space Models: A Modern Approach</h1>
  58. </a>
  59. </div><form class="bd-search d-flex align-items-center" action="../../search.html" method="get">
  60. <i class="icon fas fa-search"></i>
  61. <input type="search" class="form-control" name="q" id="search-input" placeholder="Search this book..." aria-label="Search this book..." autocomplete="off" >
  62. </form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
  63. <div class="bd-toc-item active">
  64. <ul class="nav bd-sidenav">
  65. <li class="toctree-l1">
  66. <a class="reference internal" href="../../root.html">
  67. State Space Models: A Modern Approach
  68. </a>
  69. </li>
  70. </ul>
  71. <ul class="current nav bd-sidenav">
  72. <li class="toctree-l1">
  73. <a class="reference internal" href="../scratch.html">
  74. Scratchpad
  75. </a>
  76. </li>
  77. <li class="toctree-l1">
  78. <a class="reference internal" href="../ssm/ssm.html">
  79. What are State Space Models?
  80. </a>
  81. </li>
  82. <li class="toctree-l1 current active has-children">
  83. <a class="reference internal" href="hmm_index.html">
  84. Inference in discrete SSMs
  85. </a>
  86. <input checked="" class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  87. <label for="toctree-checkbox-1">
  88. <i class="fas fa-chevron-down">
  89. </i>
  90. </label>
  91. <ul class="current">
  92. <li class="toctree-l2 current active">
  93. <a class="current reference internal" href="#">
  94. Hidden Markov Models
  95. </a>
  96. </li>
  97. <li class="toctree-l2">
  98. <a class="reference internal" href="hmm_filter.html">
  99. HMM filtering (forwards algorithm)
  100. </a>
  101. </li>
  102. <li class="toctree-l2">
  103. <a class="reference internal" href="hmm_smoother.html">
  104. HMM smoothing (forwards-backwards algorithm)
  105. </a>
  106. </li>
  107. <li class="toctree-l2">
  108. <a class="reference internal" href="hmm_viterbi.html">
  109. Viterbi algorithm
  110. </a>
  111. </li>
  112. <li class="toctree-l2">
  113. <a class="reference internal" href="hmm_parallel.html">
  114. Parallel HMM smoothing
  115. </a>
  116. </li>
  117. <li class="toctree-l2">
  118. <a class="reference internal" href="hmm_sampling.html">
  119. Forwards-filtering backwards-sampling algorithm
  120. </a>
  121. </li>
  122. </ul>
  123. </li>
  124. <li class="toctree-l1 has-children">
  125. <a class="reference internal" href="../lgssm/lgssm_index.html">
  126. Inference in linear-Gaussian SSMs
  127. </a>
  128. <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  129. <label for="toctree-checkbox-2">
  130. <i class="fas fa-chevron-down">
  131. </i>
  132. </label>
  133. <ul>
  134. <li class="toctree-l2">
  135. <a class="reference internal" href="../lgssm/kalman_filter.html">
  136. Kalman filtering
  137. </a>
  138. </li>
  139. <li class="toctree-l2">
  140. <a class="reference internal" href="../lgssm/kalman_smoother.html">
  141. Kalman (RTS) smoother
  142. </a>
  143. </li>
  144. <li class="toctree-l2">
  145. <a class="reference internal" href="../lgssm/kalman_parallel.html">
  146. Parallel Kalman Smoother
  147. </a>
  148. </li>
  149. <li class="toctree-l2">
  150. <a class="reference internal" href="../lgssm/kalman_sampling.html">
  151. Forwards-filtering backwards sampling
  152. </a>
  153. </li>
  154. </ul>
  155. </li>
  156. <li class="toctree-l1 has-children">
  157. <a class="reference internal" href="../extended/extended_index.html">
  158. Extended (linearized) methods
  159. </a>
  160. <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  161. <label for="toctree-checkbox-3">
  162. <i class="fas fa-chevron-down">
  163. </i>
  164. </label>
  165. <ul>
  166. <li class="toctree-l2">
  167. <a class="reference internal" href="../extended/extended_filter.html">
  168. Extended Kalman filtering
  169. </a>
  170. </li>
  171. <li class="toctree-l2">
  172. <a class="reference internal" href="../extended/extended_smoother.html">
  173. Extended Kalman smoother
  174. </a>
  175. </li>
  176. <li class="toctree-l2">
  177. <a class="reference internal" href="../extended/extended_parallel.html">
  178. Parallel extended Kalman smoothing
  179. </a>
  180. </li>
  181. </ul>
  182. </li>
  183. <li class="toctree-l1 has-children">
  184. <a class="reference internal" href="../unscented/unscented_index.html">
  185. Unscented methods
  186. </a>
  187. <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  188. <label for="toctree-checkbox-4">
  189. <i class="fas fa-chevron-down">
  190. </i>
  191. </label>
  192. <ul>
  193. <li class="toctree-l2">
  194. <a class="reference internal" href="../unscented/unscented_filter.html">
  195. Unscented filtering
  196. </a>
  197. </li>
  198. <li class="toctree-l2">
  199. <a class="reference internal" href="../unscented/unscented_smoother.html">
  200. Unscented smoothing
  201. </a>
  202. </li>
  203. </ul>
  204. </li>
  205. <li class="toctree-l1">
  206. <a class="reference internal" href="../quadrature/quadrature_index.html">
  207. Quadrature and cubature methods
  208. </a>
  209. </li>
  210. <li class="toctree-l1">
  211. <a class="reference internal" href="../postlin/postlin_index.html">
  212. Posterior linearization
  213. </a>
  214. </li>
  215. <li class="toctree-l1">
  216. <a class="reference internal" href="../adf/adf_index.html">
  217. Assumed Density Filtering
  218. </a>
  219. </li>
  220. <li class="toctree-l1">
  221. <a class="reference internal" href="../vi/vi_index.html">
  222. Variational inference
  223. </a>
  224. </li>
  225. <li class="toctree-l1">
  226. <a class="reference internal" href="../pf/pf_index.html">
  227. Particle filtering
  228. </a>
  229. </li>
  230. <li class="toctree-l1">
  231. <a class="reference internal" href="../smc/smc_index.html">
  232. Sequential Monte Carlo
  233. </a>
  234. </li>
  235. <li class="toctree-l1 has-children">
  236. <a class="reference internal" href="../learning/learning_index.html">
  237. Offline parameter estimation (learning)
  238. </a>
  239. <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  240. <label for="toctree-checkbox-5">
  241. <i class="fas fa-chevron-down">
  242. </i>
  243. </label>
  244. <ul>
  245. <li class="toctree-l2">
  246. <a class="reference internal" href="../learning/em.html">
  247. Expectation Maximization (EM)
  248. </a>
  249. </li>
  250. <li class="toctree-l2">
  251. <a class="reference internal" href="../learning/sgd.html">
  252. Stochastic Gradient Descent (SGD)
  253. </a>
  254. </li>
  255. <li class="toctree-l2">
  256. <a class="reference internal" href="../learning/vb.html">
  257. Variational Bayes (VB)
  258. </a>
  259. </li>
  260. <li class="toctree-l2">
  261. <a class="reference internal" href="../learning/mcmc.html">
  262. Markov Chain Monte Carlo (MCMC)
  263. </a>
  264. </li>
  265. </ul>
  266. </li>
  267. <li class="toctree-l1">
  268. <a class="reference internal" href="../tracking/tracking_index.html">
  269. Multi-target tracking
  270. </a>
  271. </li>
  272. <li class="toctree-l1">
  273. <a class="reference internal" href="../ensemble/ensemble_index.html">
  274. Data assimilation using Ensemble Kalman filter
  275. </a>
  276. </li>
  277. <li class="toctree-l1">
  278. <a class="reference internal" href="../bnp/bnp_index.html">
  279. Bayesian non-parametric SSMs
  280. </a>
  281. </li>
  282. <li class="toctree-l1">
  283. <a class="reference internal" href="../changepoint/changepoint_index.html">
  284. Changepoint detection
  285. </a>
  286. </li>
  287. <li class="toctree-l1">
  288. <a class="reference internal" href="../timeseries/timeseries_index.html">
  289. Timeseries forecasting
  290. </a>
  291. </li>
  292. <li class="toctree-l1">
  293. <a class="reference internal" href="../gp/gp_index.html">
  294. Markovian Gaussian processes
  295. </a>
  296. </li>
  297. <li class="toctree-l1">
  298. <a class="reference internal" href="../ode/ode_index.html">
  299. Differential equations and SSMs
  300. </a>
  301. </li>
  302. <li class="toctree-l1">
  303. <a class="reference internal" href="../control/control_index.html">
  304. Optimal control
  305. </a>
  306. </li>
  307. <li class="toctree-l1">
  308. <a class="reference internal" href="../../bib.html">
  309. Bibliography
  310. </a>
  311. </li>
  312. </ul>
  313. </div>
  314. </nav> <!-- To handle the deprecated key -->
  315. <div class="navbar_extra_footer">
  316. Powered by <a href="https://jupyterbook.org">Jupyter Book</a>
  317. </div>
  318. </div>
  319. <main class="col py-md-3 pl-md-4 bd-content overflow-auto" role="main">
  320. <div class="topbar container-xl fixed-top">
  321. <div class="topbar-contents row">
  322. <div class="col-12 col-md-3 bd-topbar-whitespace site-navigation show"></div>
  323. <div class="col pl-md-4 topbar-main">
  324. <button id="navbar-toggler" class="navbar-toggler ml-0" type="button" data-toggle="collapse"
  325. data-placement="bottom" data-target=".site-navigation"
  326. aria-expanded="true" aria-label="Toggle navigation" aria-controls="site-navigation"
  327. title="Toggle navigation">
  328. <i class="fas fa-bars"></i>
  329. <i class="fas fa-arrow-left"></i>
  330. <i class="fas fa-arrow-up"></i>
  331. </button>
  332. <div class="dropdown-buttons-trigger">
  333. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn" aria-label="Download this page"><i
  334. class="fas fa-download"></i></button>
  335. <div class="dropdown-buttons">
  336. <!-- ipynb file if we had a myst markdown file -->
  337. <!-- Download raw file -->
  338. <a class="dropdown-buttons" href="../../_sources/chapters/hmm/hmm.ipynb"><button type="button"
  339. class="btn btn-secondary topbarbtn" title="Download source file" data-toggle="tooltip"
  340. data-placement="left">.ipynb</button></a>
  341. <!-- Download PDF via print -->
  342. <button type="button" id="download-print" class="btn btn-secondary topbarbtn" title="Print to PDF"
  343. onclick="printPdf(this)" data-toggle="tooltip" data-placement="left">.pdf</button>
  344. </div>
  345. </div>
  346. <!-- Source interaction buttons -->
  347. <div class="dropdown-buttons-trigger">
  348. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  349. aria-label="Connect with source repository"><i class="fab fa-github"></i></button>
  350. <div class="dropdown-buttons sourcebuttons">
  351. <a class="repository-button"
  352. href="https://github.com/probml/ssm-book"><button type="button" class="btn btn-secondary topbarbtn"
  353. data-toggle="tooltip" data-placement="left" title="Source repository"><i
  354. class="fab fa-github"></i>repository</button></a>
  355. <a class="issues-button"
  356. href="https://github.com/probml/ssm-book/issues/new?title=Issue%20on%20page%20%2Fchapters/hmm/hmm.html&body=Your%20issue%20content%20here."><button
  357. type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip" data-placement="left"
  358. title="Open an issue"><i class="fas fa-lightbulb"></i>open issue</button></a>
  359. </div>
  360. </div>
  361. <!-- Full screen button (wrapped in <a> for style consistency with the other top-bar buttons) -->
  362. <a class="full-screen-button"><button type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip"
  363. data-placement="bottom" onclick="toggleFullScreen()" aria-label="Fullscreen mode"
  364. title="Fullscreen mode"><i
  365. class="fas fa-expand"></i></button></a>
  366. <!-- Launch buttons -->
  367. <div class="dropdown-buttons-trigger">
  368. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  369. aria-label="Launch interactive content"><i class="fas fa-rocket"></i></button>
  370. <div class="dropdown-buttons">
  371. <a class="binder-button" href="https://mybinder.org/v2/gh/probml/ssm-book/main?urlpath=tree/chapters/hmm/hmm.ipynb"><button type="button"
  372. class="btn btn-secondary topbarbtn" title="Launch Binder" data-toggle="tooltip"
  373. data-placement="left"><img class="binder-button-logo"
  374. src="../../_static/images/logo_binder.svg"
  375. alt="Interact on binder">Binder</button></a>
  376. <a class="colab-button" href="https://colab.research.google.com/github/probml/ssm-book/blob/main/chapters/hmm/hmm.ipynb"><button type="button" class="btn btn-secondary topbarbtn"
  377. title="Launch Colab" data-toggle="tooltip" data-placement="left"><img class="colab-button-logo"
  378. src="../../_static/images/logo_colab.png"
  379. alt="Interact on Colab">Colab</button></a>
  380. </div>
  381. </div>
  382. </div>
  383. <!-- Table of contents -->
  384. <div class="d-none d-md-block col-md-2 bd-toc show noprint">
  385. <div class="tocsection onthispage pt-5 pb-3">
  386. <i class="fas fa-list"></i> Contents
  387. </div>
  388. <nav id="bd-toc-nav" aria-label="Page">
  389. <ul class="visible nav section-nav flex-column">
  390. <li class="toc-h2 nav-item toc-entry">
  391. <a class="reference internal nav-link" href="#boilerplate">
  392. Boilerplate
  393. </a>
  394. </li>
  395. <li class="toc-h2 nav-item toc-entry">
  396. <a class="reference internal nav-link" href="#utility-code">
  397. Utility code
  398. </a>
  399. </li>
  400. <li class="toc-h2 nav-item toc-entry">
  401. <a class="reference internal nav-link" href="#example-casino-hmm">
  402. Example: Casino HMM
  403. </a>
  404. </li>
  405. <li class="toc-h2 nav-item toc-entry">
  406. <a class="reference internal nav-link" href="#sampling-from-the-joint">
  407. Sampling from the joint
  408. </a>
  409. <ul class="nav section-nav flex-column">
  410. <li class="toc-h3 nav-item toc-entry">
  411. <a class="reference internal nav-link" href="#numpy-version">
  412. Numpy version
  413. </a>
  414. </li>
  415. <li class="toc-h3 nav-item toc-entry">
  416. <a class="reference internal nav-link" href="#jax-version">
  417. JAX version
  418. </a>
  419. </li>
  420. <li class="toc-h3 nav-item toc-entry">
  421. <a class="reference internal nav-link" href="#check-correctness-by-computing-empirical-pairwise-statistics">
  422. Check correctness by computing empirical pairwise statistics
  423. </a>
  424. </li>
  425. </ul>
  426. </li>
  427. </ul>
  428. </nav>
  429. </div>
  430. </div>
  431. </div>
  432. <div id="main-content" class="row">
  433. <div class="col-12 col-md-9 pl-md-3 pr-md-0">
  434. <!-- Table of contents that is only displayed when printing the page -->
  435. <div id="jb-print-docs-body" class="onlyprint">
  436. <h1>Hidden Markov Models</h1>
  437. <!-- Table of contents -->
  438. <div id="print-main-content">
  439. <div id="jb-print-toc">
  440. <div>
  441. <h2> Contents </h2>
  442. </div>
  443. <nav aria-label="Page">
  444. <ul class="visible nav section-nav flex-column">
  445. <li class="toc-h2 nav-item toc-entry">
  446. <a class="reference internal nav-link" href="#boilerplate">
  447. Boilerplate
  448. </a>
  449. </li>
  450. <li class="toc-h2 nav-item toc-entry">
  451. <a class="reference internal nav-link" href="#utility-code">
  452. Utility code
  453. </a>
  454. </li>
  455. <li class="toc-h2 nav-item toc-entry">
  456. <a class="reference internal nav-link" href="#example-casino-hmm">
  457. Example: Casino HMM
  458. </a>
  459. </li>
  460. <li class="toc-h2 nav-item toc-entry">
  461. <a class="reference internal nav-link" href="#sampling-from-the-joint">
  462. Sampling from the joint
  463. </a>
  464. <ul class="nav section-nav flex-column">
  465. <li class="toc-h3 nav-item toc-entry">
  466. <a class="reference internal nav-link" href="#numpy-version">
  467. Numpy version
  468. </a>
  469. </li>
  470. <li class="toc-h3 nav-item toc-entry">
  471. <a class="reference internal nav-link" href="#jax-version">
  472. JAX version
  473. </a>
  474. </li>
  475. <li class="toc-h3 nav-item toc-entry">
  476. <a class="reference internal nav-link" href="#check-correctness-by-computing-empirical-pairwise-statistics">
  477. Check correctness by computing empirical pairwise statistics
  478. </a>
  479. </li>
  480. </ul>
  481. </li>
  482. </ul>
  483. </nav>
  484. </div>
  485. </div>
  486. </div>
  487. <div>
  488. <div class="tex2jax_ignore mathjax_ignore section" id="hidden-markov-models">
  489. <span id="sec-hmm-ex"></span><h1>Hidden Markov Models<a class="headerlink" href="#hidden-markov-models" title="Permalink to this headline">¶</a></h1>
  490. <p>In this section, we introduce Hidden Markov Models (HMMs).</p>
  491. <div class="section" id="boilerplate">
  492. <h2>Boilerplate<a class="headerlink" href="#boilerplate" title="Permalink to this headline">¶</a></h2>
  493. <div class="cell docutils container">
  494. <div class="cell_input docutils container">
  495. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Install necessary libraries</span>
  496. <span class="k">try</span><span class="p">:</span>
  497. <span class="kn">import</span> <span class="nn">jax</span>
  498. <span class="k">except</span><span class="p">:</span>
  499. <span class="c1"># For cuda version, see https://github.com/google/jax#installation</span>
  500. <span class="o">%</span><span class="k">pip</span> install --upgrade &quot;jax[cpu]&quot;
  501. <span class="kn">import</span> <span class="nn">jax</span>
  502. <span class="k">try</span><span class="p">:</span>
  503. <span class="kn">import</span> <span class="nn">jsl</span>
  504. <span class="k">except</span><span class="p">:</span>
  505. <span class="o">%</span><span class="k">pip</span> install git+https://github.com/probml/jsl
  506. <span class="kn">import</span> <span class="nn">jsl</span>
  507. <span class="k">try</span><span class="p">:</span>
  508. <span class="kn">import</span> <span class="nn">rich</span>
  509. <span class="k">except</span><span class="p">:</span>
  510. <span class="o">%</span><span class="k">pip</span> install rich
  511. <span class="kn">import</span> <span class="nn">rich</span>
  512. </pre></div>
  513. </div>
  514. </div>
  515. </div>
  516. <div class="cell docutils container">
  517. <div class="cell_input docutils container">
  518. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Import standard libraries</span>
  519. <span class="kn">import</span> <span class="nn">abc</span>
  520. <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
  521. <span class="kn">import</span> <span class="nn">functools</span>
  522. <span class="kn">import</span> <span class="nn">itertools</span>
  523. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
  524. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
  525. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  526. <span class="kn">import</span> <span class="nn">jax</span>
  527. <span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
  528. <span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">lax</span><span class="p">,</span> <span class="n">vmap</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">grad</span>
  529. <span class="kn">from</span> <span class="nn">jax.scipy.special</span> <span class="kn">import</span> <span class="n">logit</span>
  530. <span class="kn">from</span> <span class="nn">jax.nn</span> <span class="kn">import</span> <span class="n">softmax</span>
  531. <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
  532. <span class="kn">from</span> <span class="nn">jax.random</span> <span class="kn">import</span> <span class="n">PRNGKey</span><span class="p">,</span> <span class="n">split</span>
  533. <span class="kn">import</span> <span class="nn">inspect</span>
  534. <span class="kn">import</span> <span class="nn">inspect</span> <span class="k">as</span> <span class="nn">py_inspect</span>
  535. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="n">inspect</span> <span class="k">as</span> <span class="n">r_inspect</span>
  536. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="nb">print</span> <span class="k">as</span> <span class="n">r_print</span>
  537. <span class="k">def</span> <span class="nf">print_source</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span>
  538. <span class="n">r_print</span><span class="p">(</span><span class="n">py_inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">fname</span><span class="p">))</span>
  539. </pre></div>
  540. </div>
  541. </div>
  542. </div>
  543. </div>
  544. <div class="section" id="utility-code">
  545. <h2>Utility code<a class="headerlink" href="#utility-code" title="Permalink to this headline">¶</a></h2>
  546. <div class="cell docutils container">
  547. <div class="cell_input docutils container">
  548. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-15</span><span class="p">):</span>
  549. <span class="sd">&#39;&#39;&#39;</span>
  550. <span class="sd"> Normalizes the values within the axis in a way that they sum up to 1.</span>
  551. <span class="sd"> Parameters</span>
  552. <span class="sd"> ----------</span>
  553. <span class="sd"> u : array</span>
  554. <span class="sd"> axis : int</span>
  555. <span class="sd"> eps : float</span>
  556. <span class="sd"> Threshold for the alpha values</span>
  557. <span class="sd"> Returns</span>
  558. <span class="sd"> -------</span>
  559. <span class="sd"> * array</span>
  560. <span class="sd"> Normalized version of the given matrix</span>
  561. <span class="sd"> * array(seq_len, n_hidden) :</span>
  562. <span class="sd"> The values of the normalizer</span>
  563. <span class="sd"> &#39;&#39;&#39;</span>
  564. <span class="n">u</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">u</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">u</span> <span class="o">&lt;</span> <span class="n">eps</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">u</span><span class="p">))</span>
  565. <span class="n">c</span> <span class="o">=</span> <span class="n">u</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="n">axis</span><span class="p">)</span>
  566. <span class="n">c</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">c</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
  567. <span class="k">return</span> <span class="n">u</span> <span class="o">/</span> <span class="n">c</span><span class="p">,</span> <span class="n">c</span>
  568. </pre></div>
  569. </div>
  570. </div>
  571. </div>
  572. </div>
  573. <div class="section" id="example-casino-hmm">
  574. <span id="sec-casino-ex"></span><h2>Example: Casino HMM<a class="headerlink" href="#example-casino-hmm" title="Permalink to this headline">¶</a></h2>
  575. <p>We first create the “Occasionally dishonest casino” model from <span id="id1">[<a class="reference internal" href="../../bib.html#id3" title="R. Durbin, S. Eddy, A. Krogh, and G. Mitchison. Biological Sequence Analysis: Probabilistic Models of Proteins and Nucleic Acids. Cambridge University Press, 1998.">DEKM98</a>]</span>.</p>
  576. <div class="figure align-default" id="casino-fig">
  577. <a class="reference internal image-reference" href="../../_images/casino.png"><img alt="../../_images/casino.png" src="../../_images/casino.png" style="width: 208.5px; height: 142.5px;" /></a>
  578. <p class="caption"><span class="caption-number">Fig. 7 </span><span class="caption-text">Illustration of the casino HMM.</span><a class="headerlink" href="#casino-fig" title="Permalink to this image">¶</a></p>
  579. </div>
  580. <p>There are 2 hidden states, each of which emits one of 6 possible observations.</p>
  581. <div class="cell docutils container">
  582. <div class="cell_input docutils container">
  583. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># state transition matrix</span>
  584. <span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  585. <span class="p">[</span><span class="mf">0.95</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">],</span>
  586. <span class="p">[</span><span class="mf">0.10</span><span class="p">,</span> <span class="mf">0.90</span><span class="p">]</span>
  587. <span class="p">])</span>
  588. <span class="c1"># observation matrix</span>
  589. <span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  590. <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">],</span> <span class="c1"># fair die</span>
  591. <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="o">/</span><span class="mi">10</span><span class="p">]</span> <span class="c1"># loaded die</span>
  592. <span class="p">])</span>
  593. <span class="n">pi</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">normalize</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span>
  594. <span class="n">pi</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">pi</span><span class="p">)</span>
  595. <span class="p">(</span><span class="n">nstates</span><span class="p">,</span> <span class="n">nobs</span><span class="p">)</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
  596. </pre></div>
  597. </div>
  598. </div>
  599. <div class="cell_output docutils container">
  600. <div class="output stderr highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  601. </pre></div>
  602. </div>
  603. </div>
  604. </div>
  605. <p>Let’s make a little data structure to store all the parameters.
  606. We use NamedTuple rather than dataclass, since we assume these are immutable.
  607. (Also, the standard Python dataclass does not work well with JAX, which requires parameters to be
  608. pytrees, as discussed in <a class="reference external" href="https://github.com/google/jax/issues/2371">https://github.com/google/jax/issues/2371</a>).</p>
  609. <div class="cell docutils container">
  610. <div class="cell_input docutils container">
  611. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">Array</span> <span class="o">=</span> <span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">]</span>
  612. <span class="k">class</span> <span class="nc">HMM</span><span class="p">(</span><span class="n">NamedTuple</span><span class="p">):</span>
  613. <span class="n">trans_mat</span><span class="p">:</span> <span class="n">Array</span> <span class="c1"># A : (n_states, n_states)</span>
  614. <span class="n">obs_mat</span><span class="p">:</span> <span class="n">Array</span> <span class="c1"># B : (n_states, n_obs)</span>
  615. <span class="n">init_dist</span><span class="p">:</span> <span class="n">Array</span> <span class="c1"># pi : (n_states)</span>
  616. <span class="n">params_np</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">pi</span><span class="p">)</span>
  617. <span class="nb">print</span><span class="p">(</span><span class="n">params_np</span><span class="p">)</span>
  618. <span class="nb">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">params_np</span><span class="o">.</span><span class="n">trans_mat</span><span class="p">))</span>
  619. <span class="n">params</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">params_np</span><span class="p">)</span>
  620. <span class="nb">print</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
  621. <span class="nb">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">params</span><span class="o">.</span><span class="n">trans_mat</span><span class="p">))</span>
  622. </pre></div>
  623. </div>
  624. </div>
  625. <div class="cell_output docutils container">
  626. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>HMM(trans_mat=array([[0.95, 0.05],
  627. [0.1 , 0.9 ]]), obs_mat=array([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,
  628. 0.16666667],
  629. [0.1 , 0.1 , 0.1 , 0.1 , 0.1 ,
  630. 0.5 ]]), init_dist=array([0.5, 0.5], dtype=float32))
  631. &lt;class &#39;numpy.ndarray&#39;&gt;
  632. HMM(trans_mat=DeviceArray([[0.95, 0.05],
  633. [0.1 , 0.9 ]], dtype=float32), obs_mat=DeviceArray([[0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,
  634. 0.16666667],
  635. [0.1 , 0.1 , 0.1 , 0.1 , 0.1 ,
  636. 0.5 ]], dtype=float32), init_dist=DeviceArray([0.5, 0.5], dtype=float32))
  637. &lt;class &#39;jaxlib.xla_extension.DeviceArray&#39;&gt;
  638. </pre></div>
  639. </div>
  640. </div>
  641. </div>
  642. </div>
  643. <div class="section" id="sampling-from-the-joint">
  644. <h2>Sampling from the joint<a class="headerlink" href="#sampling-from-the-joint" title="Permalink to this headline">¶</a></h2>
  645. <p>Let’s write code to sample from this model.</p>
  646. <div class="section" id="numpy-version">
  647. <h3>Numpy version<a class="headerlink" href="#numpy-version" title="Permalink to this headline">¶</a></h3>
  648. <p>First we code it in numpy using a for loop.</p>
  649. <div class="cell docutils container">
  650. <div class="cell_input docutils container">
  651. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">hmm_sample_np</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
  652. <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">random_state</span><span class="p">)</span>
  653. <span class="n">trans_mat</span><span class="p">,</span> <span class="n">obs_mat</span><span class="p">,</span> <span class="n">init_dist</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">trans_mat</span><span class="p">,</span> <span class="n">params</span><span class="o">.</span><span class="n">obs_mat</span><span class="p">,</span> <span class="n">params</span><span class="o">.</span><span class="n">init_dist</span>
  654. <span class="n">n_states</span><span class="p">,</span> <span class="n">n_obs</span> <span class="o">=</span> <span class="n">obs_mat</span><span class="o">.</span><span class="n">shape</span>
  655. <span class="n">state_seq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">int</span><span class="p">)</span>
  656. <span class="n">obs_seq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">int</span><span class="p">)</span>
  657. <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">):</span>
  658. <span class="k">if</span> <span class="n">t</span><span class="o">==</span><span class="mi">0</span><span class="p">:</span>
  659. <span class="n">zt</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">n_states</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">init_dist</span><span class="p">)</span>
  660. <span class="k">else</span><span class="p">:</span>
  661. <span class="n">zt</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">n_states</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">trans_mat</span><span class="p">[</span><span class="n">zt</span><span class="p">])</span>
  662. <span class="n">yt</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">n_obs</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">obs_mat</span><span class="p">[</span><span class="n">zt</span><span class="p">])</span>
  663. <span class="n">state_seq</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">zt</span>
  664. <span class="n">obs_seq</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">=</span> <span class="n">yt</span>
  665. <span class="k">return</span> <span class="n">state_seq</span><span class="p">,</span> <span class="n">obs_seq</span>
  666. </pre></div>
  667. </div>
  668. </div>
  669. </div>
  670. <div class="cell docutils container">
  671. <div class="cell_input docutils container">
  672. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">seq_len</span> <span class="o">=</span> <span class="mi">100</span>
  673. <span class="n">state_seq</span><span class="p">,</span> <span class="n">obs_seq</span> <span class="o">=</span> <span class="n">hmm_sample_np</span><span class="p">(</span><span class="n">params_np</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  674. <span class="nb">print</span><span class="p">(</span><span class="n">state_seq</span><span class="p">)</span>
  675. <span class="nb">print</span><span class="p">(</span><span class="n">obs_seq</span><span class="p">)</span>
  676. </pre></div>
  677. </div>
  678. </div>
  679. <div class="cell_output docutils container">
  680. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0
  681. 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  682. 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
  683. [4 1 0 2 3 4 5 4 3 1 5 4 5 0 5 2 5 3 5 4 5 5 4 2 1 4 1 0 0 4 2 2 3 3 3 0 4
  684. 0 2 4 3 2 5 5 3 5 3 1 3 3 3 2 3 5 5 0 4 4 5 0 0 1 3 5 1 5 0 1 2 4 0 0 0 4
  685. 0 5 1 4 3 5 4 5 0 2 3 5 2 4 1 2 1 0 4 3 5 0 4 5 1 5]
  686. </pre></div>
  687. </div>
  688. </div>
  689. </div>
  690. </div>
  691. <div class="section" id="jax-version">
  692. <h3>JAX version<a class="headerlink" href="#jax-version" title="Permalink to this headline">¶</a></h3>
  693. <p>Now let’s write a JAX version using jax.lax.scan (for the inter-dependent states) and vmap (for the observations).
  694. This is harder to read than the numpy version, but faster.</p>
  695. <div class="cell docutils container">
  696. <div class="cell_input docutils container">
  697. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1">#@partial(jit, static_argnums=(1,))</span>
  698. <span class="k">def</span> <span class="nf">markov_chain_sample</span><span class="p">(</span><span class="n">rng_key</span><span class="p">,</span> <span class="n">init_dist</span><span class="p">,</span> <span class="n">trans_mat</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">):</span>
  699. <span class="n">n_states</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">init_dist</span><span class="p">)</span>
  700. <span class="k">def</span> <span class="nf">draw_state</span><span class="p">(</span><span class="n">prev_state</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
  701. <span class="n">state</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">n_states</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">trans_mat</span><span class="p">[</span><span class="n">prev_state</span><span class="p">])</span>
  702. <span class="k">return</span> <span class="n">state</span><span class="p">,</span> <span class="n">state</span>
  703. <span class="n">rng_key</span><span class="p">,</span> <span class="n">rng_state</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng_key</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
  704. <span class="n">keys</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng_state</span><span class="p">,</span> <span class="n">seq_len</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
  705. <span class="n">initial_state</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">rng_key</span><span class="p">,</span> <span class="n">n_states</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">init_dist</span><span class="p">)</span>
  706. <span class="n">final_state</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">lax</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span><span class="n">draw_state</span><span class="p">,</span> <span class="n">initial_state</span><span class="p">,</span> <span class="n">keys</span><span class="p">)</span>
  707. <span class="n">state_seq</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">initial_state</span><span class="p">]),</span> <span class="n">states</span><span class="p">)</span>
  708. <span class="k">return</span> <span class="n">state_seq</span>
  709. </pre></div>
  710. </div>
  711. </div>
  712. </div>
  713. <div class="cell docutils container">
  714. <div class="cell_input docutils container">
  715. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1">#@partial(jit, static_argnums=(1,))</span>
  716. <span class="k">def</span> <span class="nf">hmm_sample</span><span class="p">(</span><span class="n">rng_key</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">):</span>
  717. <span class="n">trans_mat</span><span class="p">,</span> <span class="n">obs_mat</span><span class="p">,</span> <span class="n">init_dist</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">trans_mat</span><span class="p">,</span> <span class="n">params</span><span class="o">.</span><span class="n">obs_mat</span><span class="p">,</span> <span class="n">params</span><span class="o">.</span><span class="n">init_dist</span>
  718. <span class="n">n_states</span><span class="p">,</span> <span class="n">n_obs</span> <span class="o">=</span> <span class="n">obs_mat</span><span class="o">.</span><span class="n">shape</span>
  719. <span class="n">rng_key</span><span class="p">,</span> <span class="n">rng_obs</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng_key</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
  720. <span class="n">state_seq</span> <span class="o">=</span> <span class="n">markov_chain_sample</span><span class="p">(</span><span class="n">rng_key</span><span class="p">,</span> <span class="n">init_dist</span><span class="p">,</span> <span class="n">trans_mat</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">)</span>
  721. <span class="k">def</span> <span class="nf">draw_obs</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
  722. <span class="n">obs</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">n_obs</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">obs_mat</span><span class="p">[</span><span class="n">z</span><span class="p">])</span>
  723. <span class="k">return</span> <span class="n">obs</span>
  724. <span class="n">keys</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng_obs</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">)</span>
  725. <span class="n">obs_seq</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">draw_obs</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">))(</span><span class="n">state_seq</span><span class="p">,</span> <span class="n">keys</span><span class="p">)</span>
  726. <span class="k">return</span> <span class="n">state_seq</span><span class="p">,</span> <span class="n">obs_seq</span>
  727. </pre></div>
  728. </div>
  729. </div>
  730. </div>
  731. <div class="cell docutils container">
  732. <div class="cell_input docutils container">
  733. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1">#@partial(jit, static_argnums=(1,))</span>
  734. <span class="k">def</span> <span class="nf">hmm_sample2</span><span class="p">(</span><span class="n">rng_key</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">):</span>
  735. <span class="n">trans_mat</span><span class="p">,</span> <span class="n">obs_mat</span><span class="p">,</span> <span class="n">init_dist</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">trans_mat</span><span class="p">,</span> <span class="n">params</span><span class="o">.</span><span class="n">obs_mat</span><span class="p">,</span> <span class="n">params</span><span class="o">.</span><span class="n">init_dist</span>
  736. <span class="n">n_states</span><span class="p">,</span> <span class="n">n_obs</span> <span class="o">=</span> <span class="n">obs_mat</span><span class="o">.</span><span class="n">shape</span>
  737. <span class="k">def</span> <span class="nf">draw_state</span><span class="p">(</span><span class="n">prev_state</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
  738. <span class="n">state</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">n_states</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">trans_mat</span><span class="p">[</span><span class="n">prev_state</span><span class="p">])</span>
  739. <span class="k">return</span> <span class="n">state</span><span class="p">,</span> <span class="n">state</span>
  740. <span class="n">rng_key</span><span class="p">,</span> <span class="n">rng_state</span><span class="p">,</span> <span class="n">rng_obs</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng_key</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
  741. <span class="n">keys</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng_state</span><span class="p">,</span> <span class="n">seq_len</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
  742. <span class="n">initial_state</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">rng_key</span><span class="p">,</span> <span class="n">n_states</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">init_dist</span><span class="p">)</span>
  743. <span class="n">final_state</span><span class="p">,</span> <span class="n">states</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">lax</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span><span class="n">draw_state</span><span class="p">,</span> <span class="n">initial_state</span><span class="p">,</span> <span class="n">keys</span><span class="p">)</span>
  744. <span class="n">state_seq</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">initial_state</span><span class="p">]),</span> <span class="n">states</span><span class="p">)</span>
  745. <span class="k">def</span> <span class="nf">draw_obs</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
  746. <span class="n">obs</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">n_obs</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="n">obs_mat</span><span class="p">[</span><span class="n">z</span><span class="p">])</span>
  747. <span class="k">return</span> <span class="n">obs</span>
  748. <span class="n">keys</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng_obs</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">)</span>
  749. <span class="n">obs_seq</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">draw_obs</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">))(</span><span class="n">state_seq</span><span class="p">,</span> <span class="n">keys</span><span class="p">)</span>
  750. <span class="k">return</span> <span class="n">state_seq</span><span class="p">,</span> <span class="n">obs_seq</span>
  751. </pre></div>
  752. </div>
  753. </div>
  754. </div>
  755. <div class="cell docutils container">
  756. <div class="cell_input docutils container">
  757. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">key</span> <span class="o">=</span> <span class="n">PRNGKey</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
  758. <span class="n">seq_len</span> <span class="o">=</span> <span class="mi">100</span>
  759. <span class="n">state_seq</span><span class="p">,</span> <span class="n">obs_seq</span> <span class="o">=</span> <span class="n">hmm_sample</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">)</span>
  760. <span class="nb">print</span><span class="p">(</span><span class="n">state_seq</span><span class="p">)</span>
  761. <span class="nb">print</span><span class="p">(</span><span class="n">obs_seq</span><span class="p">)</span>
  762. </pre></div>
  763. </div>
  764. </div>
  765. <div class="cell_output docutils container">
  766. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  767. 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1
  768. 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
  769. [5 5 2 2 0 0 0 1 3 3 2 2 5 1 5 1 0 2 2 4 2 5 1 5 5 0 0 4 2 4 3 2 3 4 1 0 5
  770. 2 2 2 1 4 3 2 2 2 4 1 0 3 5 2 5 1 4 2 5 2 5 0 5 4 4 4 2 2 0 4 5 2 2 0 1 5
  771. 1 3 4 5 1 5 0 5 1 5 1 2 4 5 3 4 5 4 0 4 0 2 4 5 3 3]
  772. </pre></div>
  773. </div>
  774. </div>
  775. </div>
  776. </div>
  777. <div class="section" id="check-correctness-by-computing-empirical-pairwise-statistics">
  778. <h3>Check correctness by computing empirical pairwise statistics<a class="headerlink" href="#check-correctness-by-computing-empirical-pairwise-statistics" title="Permalink to this headline">¶</a></h3>
  779. <p>We will count the number of i-&gt;j transitions, normalize the counts, and check that the result is close to the true
  780. transition probabilities A[i,j].</p>
  781. <div class="cell docutils container">
  782. <div class="cell_input docutils container">
  783. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">collections</span>
  784. <span class="k">def</span> <span class="nf">compute_counts</span><span class="p">(</span><span class="n">state_seq</span><span class="p">,</span> <span class="n">nstates</span><span class="p">):</span>
  785. <span class="n">wseq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">state_seq</span><span class="p">)</span>
  786. <span class="n">word_pairs</span> <span class="o">=</span> <span class="p">[</span><span class="n">pair</span> <span class="k">for</span> <span class="n">pair</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">wseq</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">wseq</span><span class="p">[</span><span class="mi">1</span><span class="p">:])]</span>
  787. <span class="n">counter_pairs</span> <span class="o">=</span> <span class="n">collections</span><span class="o">.</span><span class="n">Counter</span><span class="p">(</span><span class="n">word_pairs</span><span class="p">)</span>
  788. <span class="n">counts</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">nstates</span><span class="p">,</span> <span class="n">nstates</span><span class="p">))</span>
  789. <span class="k">for</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span><span class="n">v</span><span class="p">)</span> <span class="ow">in</span> <span class="n">counter_pairs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  790. <span class="n">counts</span><span class="p">[</span><span class="n">k</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">k</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span> <span class="o">=</span> <span class="n">v</span>
  791. <span class="k">return</span> <span class="n">counts</span>
  792. <span class="k">def</span> <span class="nf">normalize_counts</span><span class="p">(</span><span class="n">counts</span><span class="p">):</span>
  793. <span class="n">ncounts</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">v</span><span class="p">:</span> <span class="n">normalize</span><span class="p">(</span><span class="n">v</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> <span class="n">in_axes</span><span class="o">=</span><span class="mi">0</span><span class="p">)(</span><span class="n">counts</span><span class="p">)</span>
  794. <span class="k">return</span> <span class="n">ncounts</span>
  795. <span class="n">init_dist</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">])</span>
  796. <span class="n">trans_mat</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]])</span>
  797. <span class="n">rng_key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  798. <span class="n">seq_len</span> <span class="o">=</span> <span class="mi">500</span>
  799. <span class="n">state_seq</span> <span class="o">=</span> <span class="n">markov_chain_sample</span><span class="p">(</span><span class="n">rng_key</span><span class="p">,</span> <span class="n">init_dist</span><span class="p">,</span> <span class="n">trans_mat</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">)</span>
  800. <span class="nb">print</span><span class="p">(</span><span class="n">state_seq</span><span class="p">)</span>
  801. <span class="n">counts</span> <span class="o">=</span> <span class="n">compute_counts</span><span class="p">(</span><span class="n">state_seq</span><span class="p">,</span> <span class="n">nstates</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
  802. <span class="nb">print</span><span class="p">(</span><span class="n">counts</span><span class="p">)</span>
  803. <span class="n">trans_mat_empirical</span> <span class="o">=</span> <span class="n">normalize_counts</span><span class="p">(</span><span class="n">counts</span><span class="p">)</span>
  804. <span class="nb">print</span><span class="p">(</span><span class="n">trans_mat_empirical</span><span class="p">)</span>
  805. <span class="k">assert</span> <span class="n">jnp</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">trans_mat</span><span class="p">,</span> <span class="n">trans_mat_empirical</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-1</span><span class="p">)</span>
  806. </pre></div>
  807. </div>
  808. </div>
  809. <div class="cell_output docutils container">
  810. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>[0 0 1 1 1 1 0 0 1 1 1 0 1 0 0 1 1 1 1 0 1 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 1
  811. 1 0 0 0 0 1 0 1 0 0 0 0 1 0 0 1 1 0 1 1 0 1 1 0 1 1 1 0 0 1 1 0 1 0 0 1 0
  812. 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 1 1 1 0 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0
  813. 0 0 1 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 1 0 1 1 0 1 1 0 0 0 0 0 0 1 0
  814. 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 1 1 1 0 1 1 0 0 0 0 0 1 0 0 0 0
  815. 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 1 1 1 0 0 0 1 1 0 0 0
  816. 0 0 0 1 1 1 0 0 0 0 1 0 0 1 1 1 0 1 1 1 1 1 0 1 1 0 0 0 1 1 0 1 0 0 1 0 0
  817. 0 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 1 0 0 1 1 1 1
  818. 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 1 0 1 0 0 0 1
  819. 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 1 0 0 1 1 0 1 0 0 0
  820. 0 0 0 0 0 0 0 1 0 0 1 1 1 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 1 0
  821. 1 0 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0
  822. 0 0 0 0 1 0 0 1 1 0 1 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 1 1 0 0 1 1 0 0 1
  823. 1 0 0 0 0 0 0 0 1 0 0 1 1 0 0 0 0 1 1]
  824. [[244. 93.]
  825. [ 92. 70.]]
  826. [[0.7240356 0.27596438]
  827. [0.56790125 0.43209878]]
  828. </pre></div>
  829. </div>
  830. </div>
  831. </div>
  832. </div>
  833. </div>
  834. </div>
  <!--
    Thebe configuration data block. Because the script type is not a
    JavaScript MIME type, the browser does not execute this element; presumably
    the Thebe library reads the object literal when live code cells are
    activated (verify against the Thebe version shipped with this page).
    The object requests a kernel, names a Binder repo and ref, sets the
    CodeMirror theme and mode, selects the python3 kernel with the
    ./chapters/hmm path, and enables predefinedOutput.
    NOTE(review): binderOptions.repo points at a generic Binder example
    repository rather than one named after this book; confirm it provides the
    packages (jax, numpy) that the code cells on this page use.
  -->
  835. <script type="text/x-thebe-config">
  836. {
  837. requestKernel: true,
  838. binderOptions: {
  839. repo: "binder-examples/jupyter-stacks-datascience",
  840. ref: "master",
  841. },
  842. codeMirrorConfig: {
  843. theme: "abcdef",
  844. mode: "python"
  845. },
  846. kernelOptions: {
  847. kernelName: "python3",
  848. path: "./chapters/hmm"
  849. },
  850. predefinedOutput: true
  851. }
  852. </script>
  <!--
    Classic inline script: assigns 'python3' to the undeclared identifier
    kernelName, which creates a global of that name. The code that reads this
    global is not visible in this part of the file.
  -->
  853. <script>kernelName = 'python3'</script>
  854. </div>
  855. <!-- Previous / next page links. The chevron icons are decorative, so they are hidden from assistive technology; each link's accessible text comes from its subtitle and title paragraphs. -->
  856. <div class="prev-next-area">
  857. <a class="left-prev" id="prev-link" href="hmm_index.html" title="previous page">
  858. <i class="fas fa-angle-left" aria-hidden="true"></i>
  859. <div class="prev-next-info">
  860. <p class="prev-next-subtitle">previous</p>
  861. <p class="prev-next-title">Inference in discrete SSMs</p>
  862. </div>
  863. </a>
  864. <a class="right-next" id="next-link" href="hmm_filter.html" title="next page">
  865. <div class="prev-next-info">
  866. <p class="prev-next-subtitle">next</p>
  867. <p class="prev-next-title">HMM filtering (forwards algorithm)</p>
  868. </div>
  869. <i class="fas fa-angle-right" aria-hidden="true"></i>
  870. </a>
  871. </div>
  872. </div>
  873. </div>
  <!-- Page footer: byline and copyright notice. A single line break separates the two; no break is needed before the closing p tag. -->
  874. <footer class="footer">
  875. <p>
  876. By Kevin Murphy, Scott Linderman, et al.<br>
  877. &copy; Copyright 2021.
  878. </p>
  879. </footer>
  880. </main>
  881. </div>
  882. </div>
  883. <script src="../../_static/js/index.be7d3bbb2ef33a8344ce.js"></script>
  884. </body>
  885. </html>