ssm_old.html 124 KB


  1. <!DOCTYPE html>
  2. <html>
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>What are State Space Models? &#8212; State Space Models: A Modern Approach</title>
  7. <link href="../_static/css/theme.css" rel="stylesheet">
  8. <link href="../_static/css/index.ff1ffe594081f20da1ef19478df9384b.css" rel="stylesheet">
  9. <link rel="stylesheet"
  10. href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  11. <link rel="preload" as="font" type="font/woff2" crossorigin
  12. href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  13. <link rel="preload" as="font" type="font/woff2" crossorigin
  14. href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
  15. <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
  16. <link rel="stylesheet" type="text/css" href="../_static/sphinx-book-theme.css?digest=c3fdc42140077d1ad13ad2f1588a4309" />
  17. <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
  18. <link rel="stylesheet" type="text/css" href="../_static/copybutton.css" />
  19. <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
  20. <link rel="stylesheet" type="text/css" href="../_static/sphinx-thebe.css" />
  21. <link rel="stylesheet" type="text/css" href="../_static/panels-main.c949a650a448cc0ae9fd3441c0e17fb0.css" />
  22. <link rel="stylesheet" type="text/css" href="../_static/panels-variables.06eb56fa6e07937060861dad626602ad.css" />
  23. <link rel="preload" as="script" href="../_static/js/index.be7d3bbb2ef33a8344ce.js">
  24. <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
  25. <script src="../_static/jquery.js"></script>
  26. <script src="../_static/underscore.js"></script>
  27. <script src="../_static/doctools.js"></script>
  28. <script src="../_static/clipboard.min.js"></script>
  29. <script src="../_static/copybutton.js"></script>
  30. <script>let toggleHintShow = 'Click to show';</script>
  31. <script>let toggleHintHide = 'Click to hide';</script>
  32. <script>let toggleOpenOnPrint = 'true';</script>
  33. <script src="../_static/togglebutton.js"></script>
  34. <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
  35. <script src="../_static/sphinx-book-theme.d59cb220de22ca1c485ebbdc042f0030.js"></script>
  36. <script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"
  37. const thebe_selector = ".thebe,.cell"
  38. const thebe_selector_input = "pre"
  39. const thebe_selector_output = ".output, .cell_output"
  40. </script>
  41. <script async="async" src="../_static/sphinx-thebe.js"></script>
  42. <script>window.MathJax = {"options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
  43. <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
  44. <link rel="index" title="Index" href="../genindex.html" />
  45. <link rel="search" title="Search" href="../search.html" />
  46. <meta name="viewport" content="width=device-width, initial-scale=1" />
  47. <meta name="docsearch:language" content="None">
  48. <!-- Google Analytics -->
  49. </head>
  50. <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80">
  51. <div class="container-fluid" id="banner"></div>
  52. <div class="container-xl">
  53. <div class="row">
  54. <div class="col-12 col-md-3 bd-sidebar site-navigation show" id="site-navigation">
  55. <div class="navbar-brand-box">
  56. <a class="navbar-brand text-wrap" href="../index.html">
  57. <h1 class="site-logo" id="site-title">State Space Models: A Modern Approach</h1>
  58. </a>
  59. </div><form class="bd-search d-flex align-items-center" action="../search.html" method="get">
  60. <i class="icon fas fa-search"></i>
  61. <input type="search" class="form-control" name="q" id="search-input" placeholder="Search this book..." aria-label="Search this book..." autocomplete="off" >
  62. </form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
  63. <div class="bd-toc-item active">
  64. <ul class="nav bd-sidenav">
  65. <li class="toctree-l1">
  66. <a class="reference internal" href="../root.html">
  67. State Space Models: A Modern Approach
  68. </a>
  69. </li>
  70. </ul>
  71. <ul class="nav bd-sidenav">
  72. <li class="toctree-l1">
  73. <a class="reference internal" href="../chapters/scratch.html">
  74. Scratchpad
  75. </a>
  76. </li>
  77. <li class="toctree-l1 has-children">
  78. <a class="reference internal" href="../chapters/ssm/ssm_index.html">
  79. State Space Models
  80. </a>
  81. <input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  82. <label for="toctree-checkbox-1">
  83. <i class="fas fa-chevron-down">
  84. </i>
  85. </label>
  86. <ul>
  87. <li class="toctree-l2">
  88. <a class="reference internal" href="../chapters/ssm/ssm_intro.html">
  89. What are State Space Models?
  90. </a>
  91. </li>
  92. <li class="toctree-l2">
  93. <a class="reference internal" href="../chapters/ssm/hmm.html">
  94. Hidden Markov Models
  95. </a>
  96. </li>
  97. <li class="toctree-l2">
  98. <a class="reference internal" href="../chapters/ssm/lds.html">
  99. Linear Gaussian SSMs
  100. </a>
  101. </li>
  102. <li class="toctree-l2">
  103. <a class="reference internal" href="../chapters/ssm/nlds.html">
  104. Nonlinear Gaussian SSMs
  105. </a>
  106. </li>
  107. <li class="toctree-l2">
  108. <a class="reference internal" href="../chapters/ssm/inference.html">
  109. Inferential goals
  110. </a>
  111. </li>
  112. </ul>
  113. </li>
  114. <li class="toctree-l1 has-children">
  115. <a class="reference internal" href="../chapters/hmm/hmm_index.html">
  116. Hidden Markov Models
  117. </a>
  118. <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  119. <label for="toctree-checkbox-2">
  120. <i class="fas fa-chevron-down">
  121. </i>
  122. </label>
  123. <ul>
  124. <li class="toctree-l2">
  125. <a class="reference internal" href="../chapters/hmm/hmm.html">
  126. Hidden Markov Models
  127. </a>
  128. </li>
  129. <li class="toctree-l2">
  130. <a class="reference internal" href="../chapters/hmm/hmm_filter.html">
  131. HMM filtering (forwards algorithm)
  132. </a>
  133. </li>
  134. <li class="toctree-l2">
  135. <a class="reference internal" href="../chapters/hmm/hmm_smoother.html">
  136. HMM smoothing (forwards-backwards algorithm)
  137. </a>
  138. </li>
  139. <li class="toctree-l2">
  140. <a class="reference internal" href="../chapters/hmm/hmm_viterbi.html">
  141. Viterbi algorithm
  142. </a>
  143. </li>
  144. <li class="toctree-l2">
  145. <a class="reference internal" href="../chapters/hmm/hmm_parallel.html">
  146. Parallel HMM smoothing
  147. </a>
  148. </li>
  149. <li class="toctree-l2">
  150. <a class="reference internal" href="../chapters/hmm/hmm_sampling.html">
  151. Forwards-filtering backwards-sampling algorithm
  152. </a>
  153. </li>
  154. </ul>
  155. </li>
  156. <li class="toctree-l1 has-children">
  157. <a class="reference internal" href="../chapters/lgssm/lgssm_index.html">
  158. Inference in linear-Gaussian SSMs
  159. </a>
  160. <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  161. <label for="toctree-checkbox-3">
  162. <i class="fas fa-chevron-down">
  163. </i>
  164. </label>
  165. <ul>
  166. <li class="toctree-l2">
  167. <a class="reference internal" href="../chapters/lgssm/kalman_filter.html">
  168. Kalman filtering
  169. </a>
  170. </li>
  171. <li class="toctree-l2">
  172. <a class="reference internal" href="../chapters/lgssm/kalman_smoother.html">
  173. Kalman (RTS) smoother
  174. </a>
  175. </li>
  176. <li class="toctree-l2">
  177. <a class="reference internal" href="../chapters/lgssm/kalman_parallel.html">
  178. Parallel Kalman Smoother
  179. </a>
  180. </li>
  181. <li class="toctree-l2">
  182. <a class="reference internal" href="../chapters/lgssm/kalman_sampling.html">
  183. Forwards-filtering backwards sampling
  184. </a>
  185. </li>
  186. </ul>
  187. </li>
  188. <li class="toctree-l1 has-children">
  189. <a class="reference internal" href="../chapters/extended/extended_index.html">
  190. Extended (linearized) methods
  191. </a>
  192. <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  193. <label for="toctree-checkbox-4">
  194. <i class="fas fa-chevron-down">
  195. </i>
  196. </label>
  197. <ul>
  198. <li class="toctree-l2">
  199. <a class="reference internal" href="../chapters/extended/extended_filter.html">
  200. Extended Kalman filtering
  201. </a>
  202. </li>
  203. <li class="toctree-l2">
  204. <a class="reference internal" href="../chapters/extended/extended_smoother.html">
  205. Extended Kalman smoother
  206. </a>
  207. </li>
  208. <li class="toctree-l2">
  209. <a class="reference internal" href="../chapters/extended/extended_parallel.html">
  210. Parallel extended Kalman smoothing
  211. </a>
  212. </li>
  213. </ul>
  214. </li>
  215. <li class="toctree-l1 has-children">
  216. <a class="reference internal" href="../chapters/unscented/unscented_index.html">
  217. Unscented methods
  218. </a>
  219. <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  220. <label for="toctree-checkbox-5">
  221. <i class="fas fa-chevron-down">
  222. </i>
  223. </label>
  224. <ul>
  225. <li class="toctree-l2">
  226. <a class="reference internal" href="../chapters/unscented/unscented_filter.html">
  227. Unscented filtering
  228. </a>
  229. </li>
  230. <li class="toctree-l2">
  231. <a class="reference internal" href="../chapters/unscented/unscented_smoother.html">
  232. Unscented smoothing
  233. </a>
  234. </li>
  235. </ul>
  236. </li>
  237. <li class="toctree-l1">
  238. <a class="reference internal" href="../chapters/quadrature/quadrature_index.html">
  239. Quadrature and cubature methods
  240. </a>
  241. </li>
  242. <li class="toctree-l1">
  243. <a class="reference internal" href="../chapters/postlin/postlin_index.html">
  244. Posterior linearization
  245. </a>
  246. </li>
  247. <li class="toctree-l1">
  248. <a class="reference internal" href="../chapters/adf/adf_index.html">
  249. Assumed Density Filtering
  250. </a>
  251. </li>
  252. <li class="toctree-l1">
  253. <a class="reference internal" href="../chapters/vi/vi_index.html">
  254. Variational inference
  255. </a>
  256. </li>
  257. <li class="toctree-l1">
  258. <a class="reference internal" href="../chapters/pf/pf_index.html">
  259. Particle filtering
  260. </a>
  261. </li>
  262. <li class="toctree-l1">
  263. <a class="reference internal" href="../chapters/smc/smc_index.html">
  264. Sequential Monte Carlo
  265. </a>
  266. </li>
  267. <li class="toctree-l1 has-children">
  268. <a class="reference internal" href="../chapters/learning/learning_index.html">
  269. Offline parameter estimation (learning)
  270. </a>
  271. <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
  272. <label for="toctree-checkbox-6">
  273. <i class="fas fa-chevron-down">
  274. </i>
  275. </label>
  276. <ul>
  277. <li class="toctree-l2">
  278. <a class="reference internal" href="../chapters/learning/em.html">
  279. Expectation Maximization (EM)
  280. </a>
  281. </li>
  282. <li class="toctree-l2">
  283. <a class="reference internal" href="../chapters/learning/sgd.html">
  284. Stochastic Gradient Descent (SGD)
  285. </a>
  286. </li>
  287. <li class="toctree-l2">
  288. <a class="reference internal" href="../chapters/learning/vb.html">
  289. Variational Bayes (VB)
  290. </a>
  291. </li>
  292. <li class="toctree-l2">
  293. <a class="reference internal" href="../chapters/learning/mcmc.html">
  294. Markov Chain Monte Carlo (MCMC)
  295. </a>
  296. </li>
  297. </ul>
  298. </li>
  299. <li class="toctree-l1">
  300. <a class="reference internal" href="../chapters/tracking/tracking_index.html">
  301. Multi-target tracking
  302. </a>
  303. </li>
  304. <li class="toctree-l1">
  305. <a class="reference internal" href="../chapters/ensemble/ensemble_index.html">
  306. Data assimilation using Ensemble Kalman filter
  307. </a>
  308. </li>
  309. <li class="toctree-l1">
  310. <a class="reference internal" href="../chapters/bnp/bnp_index.html">
  311. Bayesian non-parametric SSMs
  312. </a>
  313. </li>
  314. <li class="toctree-l1">
  315. <a class="reference internal" href="../chapters/changepoint/changepoint_index.html">
  316. Changepoint detection
  317. </a>
  318. </li>
  319. <li class="toctree-l1">
  320. <a class="reference internal" href="../chapters/timeseries/timeseries_index.html">
  321. Timeseries forecasting
  322. </a>
  323. </li>
  324. <li class="toctree-l1">
  325. <a class="reference internal" href="../chapters/gp/gp_index.html">
  326. Markovian Gaussian processes
  327. </a>
  328. </li>
  329. <li class="toctree-l1">
  330. <a class="reference internal" href="../chapters/ode/ode_index.html">
  331. Differential equations and SSMs
  332. </a>
  333. </li>
  334. <li class="toctree-l1">
  335. <a class="reference internal" href="../chapters/control/control_index.html">
  336. Optimal control
  337. </a>
  338. </li>
  339. <li class="toctree-l1">
  340. <a class="reference internal" href="../bib.html">
  341. Bibliography
  342. </a>
  343. </li>
  344. </ul>
  345. </div>
  346. </nav> <!-- To handle the deprecated key -->
  347. <div class="navbar_extra_footer">
  348. Powered by <a href="https://jupyterbook.org">Jupyter Book</a>
  349. </div>
  350. </div>
  351. <main class="col py-md-3 pl-md-4 bd-content overflow-auto" role="main">
  352. <div class="topbar container-xl fixed-top">
  353. <div class="topbar-contents row">
  354. <div class="col-12 col-md-3 bd-topbar-whitespace site-navigation show"></div>
  355. <div class="col pl-md-4 topbar-main">
  356. <button id="navbar-toggler" class="navbar-toggler ml-0" type="button" data-toggle="collapse"
  357. data-toggle="tooltip" data-placement="bottom" data-target=".site-navigation" aria-controls="navbar-menu"
  358. aria-expanded="true" aria-label="Toggle navigation" aria-controls="site-navigation"
  359. title="Toggle navigation" data-toggle="tooltip" data-placement="left">
  360. <i class="fas fa-bars"></i>
  361. <i class="fas fa-arrow-left"></i>
  362. <i class="fas fa-arrow-up"></i>
  363. </button>
  364. <div class="dropdown-buttons-trigger">
  365. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn" aria-label="Download this page"><i
  366. class="fas fa-download"></i></button>
  367. <div class="dropdown-buttons">
  368. <!-- ipynb file if we had a myst markdown file -->
  369. <!-- Download raw file -->
  370. <a class="dropdown-buttons" href="../_sources/old/ssm_old.ipynb"><button type="button"
  371. class="btn btn-secondary topbarbtn" title="Download source file" data-toggle="tooltip"
  372. data-placement="left">.ipynb</button></a>
  373. <!-- Download PDF via print -->
  374. <button type="button" id="download-print" class="btn btn-secondary topbarbtn" title="Print to PDF"
  375. onclick="printPdf(this)" data-toggle="tooltip" data-placement="left">.pdf</button>
  376. </div>
  377. </div>
  378. <!-- Source interaction buttons -->
  379. <div class="dropdown-buttons-trigger">
  380. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  381. aria-label="Connect with source repository"><i class="fab fa-github"></i></button>
  382. <div class="dropdown-buttons sourcebuttons">
  383. <a class="repository-button"
  384. href="https://github.com/probml/ssm-book"><button type="button" class="btn btn-secondary topbarbtn"
  385. data-toggle="tooltip" data-placement="left" title="Source repository"><i
  386. class="fab fa-github"></i>repository</button></a>
  387. <a class="issues-button"
  388. href="https://github.com/probml/ssm-book/issues/new?title=Issue%20on%20page%20%2Fold/ssm_old.html&body=Your%20issue%20content%20here."><button
  389. type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip" data-placement="left"
  390. title="Open an issue"><i class="fas fa-lightbulb"></i>open issue</button></a>
  391. </div>
  392. </div>
  393. <!-- Full screen (wrap in <a> to have style consistency -->
  394. <a class="full-screen-button"><button type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip"
  395. data-placement="bottom" onclick="toggleFullScreen()" aria-label="Fullscreen mode"
  396. title="Fullscreen mode"><i
  397. class="fas fa-expand"></i></button></a>
  398. <!-- Launch buttons -->
  399. <div class="dropdown-buttons-trigger">
  400. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  401. aria-label="Launch interactive content"><i class="fas fa-rocket"></i></button>
  402. <div class="dropdown-buttons">
  403. <a class="binder-button" href="https://mybinder.org/v2/gh/probml/ssm-book/main?urlpath=tree/old/ssm_old.ipynb"><button type="button"
  404. class="btn btn-secondary topbarbtn" title="Launch Binder" data-toggle="tooltip"
  405. data-placement="left"><img class="binder-button-logo"
  406. src="../_static/images/logo_binder.svg"
  407. alt="Interact on binder">Binder</button></a>
  408. <a class="colab-button" href="https://colab.research.google.com/github/probml/ssm-book/blob/main/old/ssm_old.ipynb"><button type="button" class="btn btn-secondary topbarbtn"
  409. title="Launch Colab" data-toggle="tooltip" data-placement="left"><img class="colab-button-logo"
  410. src="../_static/images/logo_colab.png"
  411. alt="Interact on Colab">Colab</button></a>
  412. </div>
  413. </div>
  414. </div>
  415. <!-- Table of contents -->
  416. <div class="d-none d-md-block col-md-2 bd-toc show noprint">
  417. <div class="tocsection onthispage pt-5 pb-3">
  418. <i class="fas fa-list"></i> Contents
  419. </div>
  420. <nav id="bd-toc-nav" aria-label="Page">
  421. <ul class="visible nav section-nav flex-column">
  422. <li class="toc-h1 nav-item toc-entry">
  423. <a class="reference internal nav-link" href="#">
  424. What are State Space Models?
  425. </a>
  426. </li>
  427. <li class="toc-h1 nav-item toc-entry">
  428. <a class="reference internal nav-link" href="#hidden-markov-models">
  429. Hidden Markov Models
  430. </a>
  431. <ul class="visible nav section-nav flex-column">
  432. <li class="toc-h2 nav-item toc-entry">
  433. <a class="reference internal nav-link" href="#example-casino-hmm">
  434. Example: Casino HMM
  435. </a>
  436. </li>
  437. <li class="toc-h2 nav-item toc-entry">
  438. <a class="reference internal nav-link" href="#example-lillypad-hmm">
  439. Example: Lillypad HMM
  440. </a>
  441. </li>
  442. </ul>
  443. </li>
  444. <li class="toc-h1 nav-item toc-entry">
  445. <a class="reference internal nav-link" href="#linear-gaussian-ssms">
  446. Linear Gaussian SSMs
  447. </a>
  448. <ul class="visible nav section-nav flex-column">
  449. <li class="toc-h2 nav-item toc-entry">
  450. <a class="reference internal nav-link" href="#example-tracking-a-2d-point">
  451. Example: tracking a 2d point
  452. </a>
  453. </li>
  454. </ul>
  455. </li>
  456. <li class="toc-h1 nav-item toc-entry">
  457. <a class="reference internal nav-link" href="#nonlinear-gaussian-ssms">
  458. Nonlinear Gaussian SSMs
  459. </a>
  460. <ul class="visible nav section-nav flex-column">
  461. <li class="toc-h2 nav-item toc-entry">
  462. <a class="reference internal nav-link" href="#example-tracking-a-1d-pendulum">
  463. Example: tracking a 1d pendulum
  464. </a>
  465. </li>
  466. </ul>
  467. </li>
  468. <li class="toc-h1 nav-item toc-entry">
  469. <a class="reference internal nav-link" href="#inferential-goals">
  470. Inferential goals
  471. </a>
  472. <ul class="visible nav section-nav flex-column">
  473. <li class="toc-h2 nav-item toc-entry">
  474. <a class="reference internal nav-link" href="#example-inference-in-the-casino-hmm">
  475. Example: inference in the casino HMM
  476. </a>
  477. </li>
  478. <li class="toc-h2 nav-item toc-entry">
  479. <a class="reference internal nav-link" href="#example-inference-in-the-tracking-ssm">
  480. Example: inference in the tracking SSM
  481. </a>
  482. </li>
  483. </ul>
  484. </li>
  485. </ul>
  486. </nav>
  487. </div>
  488. </div>
  489. </div>
  490. <div id="main-content" class="row">
  491. <div class="col-12 col-md-9 pl-md-3 pr-md-0">
  492. <!-- Table of contents that is only displayed when printing the page -->
  493. <div id="jb-print-docs-body" class="onlyprint">
  494. <h1>What are State Space Models?</h1>
  495. <!-- Table of contents -->
  496. <div id="print-main-content">
  497. <div id="jb-print-toc">
  498. <div>
  499. <h2> Contents </h2>
  500. </div>
  501. <nav aria-label="Page">
  502. <ul class="visible nav section-nav flex-column">
  503. <li class="toc-h1 nav-item toc-entry">
  504. <a class="reference internal nav-link" href="#">
  505. What are State Space Models?
  506. </a>
  507. </li>
  508. <li class="toc-h1 nav-item toc-entry">
  509. <a class="reference internal nav-link" href="#hidden-markov-models">
  510. Hidden Markov Models
  511. </a>
  512. <ul class="visible nav section-nav flex-column">
  513. <li class="toc-h2 nav-item toc-entry">
  514. <a class="reference internal nav-link" href="#example-casino-hmm">
  515. Example: Casino HMM
  516. </a>
  517. </li>
  518. <li class="toc-h2 nav-item toc-entry">
  519. <a class="reference internal nav-link" href="#example-lillypad-hmm">
  520. Example: Lillypad HMM
  521. </a>
  522. </li>
  523. </ul>
  524. </li>
  525. <li class="toc-h1 nav-item toc-entry">
  526. <a class="reference internal nav-link" href="#linear-gaussian-ssms">
  527. Linear Gaussian SSMs
  528. </a>
  529. <ul class="visible nav section-nav flex-column">
  530. <li class="toc-h2 nav-item toc-entry">
  531. <a class="reference internal nav-link" href="#example-tracking-a-2d-point">
  532. Example: tracking a 2d point
  533. </a>
  534. </li>
  535. </ul>
  536. </li>
  537. <li class="toc-h1 nav-item toc-entry">
  538. <a class="reference internal nav-link" href="#nonlinear-gaussian-ssms">
  539. Nonlinear Gaussian SSMs
  540. </a>
  541. <ul class="visible nav section-nav flex-column">
  542. <li class="toc-h2 nav-item toc-entry">
  543. <a class="reference internal nav-link" href="#example-tracking-a-1d-pendulum">
  544. Example: tracking a 1d pendulum
  545. </a>
  546. </li>
  547. </ul>
  548. </li>
  549. <li class="toc-h1 nav-item toc-entry">
  550. <a class="reference internal nav-link" href="#inferential-goals">
  551. Inferential goals
  552. </a>
  553. <ul class="visible nav section-nav flex-column">
  554. <li class="toc-h2 nav-item toc-entry">
  555. <a class="reference internal nav-link" href="#example-inference-in-the-casino-hmm">
  556. Example: inference in the casino HMM
  557. </a>
  558. </li>
  559. <li class="toc-h2 nav-item toc-entry">
  560. <a class="reference internal nav-link" href="#example-inference-in-the-tracking-ssm">
  561. Example: inference in the tracking SSM
  562. </a>
  563. </li>
  564. </ul>
  565. </li>
  566. </ul>
  567. </nav>
  568. </div>
  569. </div>
  570. </div>
  571. <div>
  572. <div class="cell docutils container">
  573. <div class="cell_input docutils container">
  574. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># meta-data does not work yet in VScode</span>
  575. <span class="c1"># https://github.com/microsoft/vscode-jupyter/issues/1121</span>
  576. <span class="p">{</span>
  577. <span class="s2">&quot;tags&quot;</span><span class="p">:</span> <span class="p">[</span>
  578. <span class="s2">&quot;hide-cell&quot;</span>
  579. <span class="p">]</span>
  580. <span class="p">}</span>
  581. <span class="c1">### Install necessary libraries</span>
  582. <span class="k">try</span><span class="p">:</span>
  583. <span class="kn">import</span> <span class="nn">jax</span>
  584. <span class="k">except</span><span class="p">:</span>
  585. <span class="c1"># For cuda version, see https://github.com/google/jax#installation</span>
  586. <span class="o">%</span><span class="k">pip</span> install --upgrade &quot;jax[cpu]&quot;
  587. <span class="kn">import</span> <span class="nn">jax</span>
  588. <span class="k">try</span><span class="p">:</span>
  589. <span class="kn">import</span> <span class="nn">distrax</span>
  590. <span class="k">except</span><span class="p">:</span>
  591. <span class="o">%</span><span class="k">pip</span> install --upgrade distrax
  592. <span class="kn">import</span> <span class="nn">distrax</span>
  593. <span class="k">try</span><span class="p">:</span>
  594. <span class="kn">import</span> <span class="nn">jsl</span>
  595. <span class="k">except</span><span class="p">:</span>
  596. <span class="o">%</span><span class="k">pip</span> install git+https://github.com/probml/jsl
  597. <span class="kn">import</span> <span class="nn">jsl</span>
  598. <span class="k">try</span><span class="p">:</span>
  599. <span class="kn">import</span> <span class="nn">rich</span>
  600. <span class="k">except</span><span class="p">:</span>
  601. <span class="o">%</span><span class="k">pip</span> install rich
  602. <span class="kn">import</span> <span class="nn">rich</span>
  603. </pre></div>
  604. </div>
  605. </div>
  606. </div>
  607. <div class="cell docutils container">
  608. <div class="cell_input docutils container">
  609. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="p">{</span>
  610. <span class="s2">&quot;tags&quot;</span><span class="p">:</span> <span class="p">[</span>
  611. <span class="s2">&quot;hide-cell&quot;</span>
  612. <span class="p">]</span>
  613. <span class="p">}</span>
  614. <span class="c1">### Import standard libraries</span>
  615. <span class="kn">import</span> <span class="nn">abc</span>
  616. <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
  617. <span class="kn">import</span> <span class="nn">functools</span>
  618. <span class="kn">import</span> <span class="nn">itertools</span>
  619. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
  620. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
  621. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  622. <span class="kn">import</span> <span class="nn">jax</span>
  623. <span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
  624. <span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">lax</span><span class="p">,</span> <span class="n">vmap</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">grad</span>
  625. <span class="kn">from</span> <span class="nn">jax.scipy.special</span> <span class="kn">import</span> <span class="n">logit</span>
  626. <span class="kn">from</span> <span class="nn">jax.nn</span> <span class="kn">import</span> <span class="n">softmax</span>
  627. <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
  628. <span class="kn">from</span> <span class="nn">jax.random</span> <span class="kn">import</span> <span class="n">PRNGKey</span><span class="p">,</span> <span class="n">split</span>
  629. <span class="kn">import</span> <span class="nn">inspect</span>
  630. <span class="kn">import</span> <span class="nn">inspect</span> <span class="k">as</span> <span class="nn">py_inspect</span>
  631. <span class="kn">import</span> <span class="nn">rich</span>
  632. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="n">inspect</span> <span class="k">as</span> <span class="n">r_inspect</span>
  633. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="nb">print</span> <span class="k">as</span> <span class="n">r_print</span>
  634. <span class="k">def</span> <span class="nf">print_source</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span>
  635. <span class="n">r_print</span><span class="p">(</span><span class="n">py_inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">fname</span><span class="p">))</span>
  636. </pre></div>
  637. </div>
  638. </div>
  639. </div>
  640. <div class="math notranslate nohighlight">
  641. \[ \begin{align}\begin{aligned}\newcommand\floor[1]{\lfloor#1\rfloor}\\\newcommand{\real}{\mathbb{R}}\\% Numbers
  642. \newcommand{\vzero}{\boldsymbol{0}}
  643. \newcommand{\vone}{\boldsymbol{1}}\\% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/
  644. \newcommand{\valpha}{\boldsymbol{\alpha}}
  645. \newcommand{\vbeta}{\boldsymbol{\beta}}
  646. \newcommand{\vchi}{\boldsymbol{\chi}}
  647. \newcommand{\vdelta}{\boldsymbol{\delta}}
  648. \newcommand{\vDelta}{\boldsymbol{\Delta}}
  649. \newcommand{\vepsilon}{\boldsymbol{\epsilon}}
  650. \newcommand{\vzeta}{\boldsymbol{\zeta}}
  651. \newcommand{\vXi}{\boldsymbol{\Xi}}
  652. \newcommand{\vell}{\boldsymbol{\ell}}
  653. \newcommand{\veta}{\boldsymbol{\eta}}
  654. %\newcommand{\vEta}{\boldsymbol{\Eta}}
  655. \newcommand{\vgamma}{\boldsymbol{\gamma}}
  656. \newcommand{\vGamma}{\boldsymbol{\Gamma}}
  657. \newcommand{\vmu}{\boldsymbol{\mu}}
  658. \newcommand{\vmut}{\boldsymbol{\tilde{\mu}}}
  659. \newcommand{\vnu}{\boldsymbol{\nu}}
  660. \newcommand{\vkappa}{\boldsymbol{\kappa}}
  661. \newcommand{\vlambda}{\boldsymbol{\lambda}}
  662. \newcommand{\vLambda}{\boldsymbol{\Lambda}}
  663. \newcommand{\vLambdaBar}{\overline{\vLambda}}
  664. %\newcommand{\vnu}{\boldsymbol{\nu}}
  665. \newcommand{\vomega}{\boldsymbol{\omega}}
  666. \newcommand{\vOmega}{\boldsymbol{\Omega}}
  667. \newcommand{\vphi}{\boldsymbol{\phi}}
  668. \newcommand{\vvarphi}{\boldsymbol{\varphi}}
  669. \newcommand{\vPhi}{\boldsymbol{\Phi}}
  670. \newcommand{\vpi}{\boldsymbol{\pi}}
  671. \newcommand{\vPi}{\boldsymbol{\Pi}}
  672. \newcommand{\vpsi}{\boldsymbol{\psi}}
  673. \newcommand{\vPsi}{\boldsymbol{\Psi}}
  674. \newcommand{\vrho}{\boldsymbol{\rho}}
  675. \newcommand{\vtheta}{\boldsymbol{\theta}}
  676. \newcommand{\vthetat}{\boldsymbol{\tilde{\theta}}}
  677. \newcommand{\vTheta}{\boldsymbol{\Theta}}
  678. \newcommand{\vsigma}{\boldsymbol{\sigma}}
  679. \newcommand{\vSigma}{\boldsymbol{\Sigma}}
  680. \newcommand{\vSigmat}{\boldsymbol{\tilde{\Sigma}}}
  681. \newcommand{\vsigmoid}{\vsigma}
  682. \newcommand{\vtau}{\boldsymbol{\tau}}
  683. \newcommand{\vxi}{\boldsymbol{\xi}}\\
  684. % Lower Roman (Vectors)
  685. \newcommand{\va}{\mathbf{a}}
  686. \newcommand{\vb}{\mathbf{b}}
  687. \newcommand{\vBt}{\mathbf{\tilde{B}}}
  688. \newcommand{\vc}{\mathbf{c}}
  689. \newcommand{\vct}{\mathbf{\tilde{c}}}
  690. \newcommand{\vd}{\mathbf{d}}
  691. \newcommand{\ve}{\mathbf{e}}
  692. \newcommand{\vf}{\mathbf{f}}
  693. \newcommand{\vg}{\mathbf{g}}
  694. \newcommand{\vh}{\mathbf{h}}
  695. %\newcommand{\myvh}{\mathbf{h}}
  696. \newcommand{\vi}{\mathbf{i}}
  697. \newcommand{\vj}{\mathbf{j}}
  698. \newcommand{\vk}{\mathbf{k}}
  699. \newcommand{\vl}{\mathbf{l}}
  700. \newcommand{\vm}{\mathbf{m}}
  701. \newcommand{\vn}{\mathbf{n}}
  702. \newcommand{\vo}{\mathbf{o}}
  703. \newcommand{\vp}{\mathbf{p}}
  704. \newcommand{\vq}{\mathbf{q}}
  705. \newcommand{\vr}{\mathbf{r}}
  706. \newcommand{\vs}{\mathbf{s}}
  707. \newcommand{\vt}{\mathbf{t}}
  708. \newcommand{\vu}{\mathbf{u}}
  709. \newcommand{\vv}{\mathbf{v}}
  710. \newcommand{\vw}{\mathbf{w}}
  711. \newcommand{\vws}{\vw_s}
  712. \newcommand{\vwt}{\mathbf{\tilde{w}}}
  713. \newcommand{\vWt}{\mathbf{\tilde{W}}}
  714. \newcommand{\vwh}{\hat{\vw}}
  715. \newcommand{\vx}{\mathbf{x}}
  716. %\newcommand{\vx}{\mathbf{x}}
  717. \newcommand{\vxt}{\mathbf{\tilde{x}}}
  718. \newcommand{\vy}{\mathbf{y}}
  719. \newcommand{\vyt}{\mathbf{\tilde{y}}}
  720. \newcommand{\vz}{\mathbf{z}}
  721. %\newcommand{\vzt}{\mathbf{\tilde{z}}}\\
  722. % Upper Roman (Matrices)
  723. \newcommand{\vA}{\mathbf{A}}
  724. \newcommand{\vB}{\mathbf{B}}
  725. \newcommand{\vC}{\mathbf{C}}
  726. \newcommand{\vD}{\mathbf{D}}
  727. \newcommand{\vE}{\mathbf{E}}
  728. \newcommand{\vF}{\mathbf{F}}
  729. \newcommand{\vG}{\mathbf{G}}
  730. \newcommand{\vH}{\mathbf{H}}
  731. \newcommand{\vI}{\mathbf{I}}
  732. \newcommand{\vJ}{\mathbf{J}}
  733. \newcommand{\vK}{\mathbf{K}}
  734. \newcommand{\vL}{\mathbf{L}}
  735. \newcommand{\vM}{\mathbf{M}}
  736. \newcommand{\vMt}{\mathbf{\tilde{M}}}
  737. \newcommand{\vN}{\mathbf{N}}
  738. \newcommand{\vO}{\mathbf{O}}
  739. \newcommand{\vP}{\mathbf{P}}
  740. \newcommand{\vQ}{\mathbf{Q}}
  741. \newcommand{\vR}{\mathbf{R}}
  742. \newcommand{\vS}{\mathbf{S}}
  743. \newcommand{\vT}{\mathbf{T}}
  744. \newcommand{\vU}{\mathbf{U}}
  745. \newcommand{\vV}{\mathbf{V}}
  746. \newcommand{\vW}{\mathbf{W}}
  747. \newcommand{\vX}{\mathbf{X}}
  748. %\newcommand{\vXs}{\vX_{\vs}}
  749. \newcommand{\vXs}{\vX_{s}}
  750. \newcommand{\vXt}{\mathbf{\tilde{X}}}
  751. \newcommand{\vY}{\mathbf{Y}}
  752. \newcommand{\vZ}{\mathbf{Z}}
  753. \newcommand{\vZt}{\mathbf{\tilde{Z}}}
  754. \newcommand{\vzt}{\mathbf{\tilde{z}}}\\
  755. %%%%
  756. \newcommand{\hidden}{\vz}
  757. \newcommand{\hid}{\hidden}
  758. \newcommand{\observed}{\vy}
  759. \newcommand{\obs}{\observed}
  760. \newcommand{\inputs}{\vu}
  761. \newcommand{\input}{\inputs}\\\newcommand{\hmmTrans}{\vA}
  762. \newcommand{\hmmObs}{\vB}
  763. \newcommand{\hmmInit}{\vpi}
  764. \newcommand{\hmmhid}{\hidden}
  765. \newcommand{\hmmobs}{\obs}\\\newcommand{\ldsDyn}{\vA}
  766. \newcommand{\ldsObs}{\vC}
  767. \newcommand{\ldsDynIn}{\vB}
  768. \newcommand{\ldsObsIn}{\vD}
  769. \newcommand{\ldsDynNoise}{\vQ}
  770. \newcommand{\ldsObsNoise}{\vR}\\\newcommand{\ssmDynFn}{f}
  771. \newcommand{\ssmObsFn}{h}\\
  772. %%%
  773. \newcommand{\gauss}{\mathcal{N}}\\\newcommand{\diag}{\mathrm{diag}}\end{aligned}\end{align} \]</div>
  774. <div class="tex2jax_ignore mathjax_ignore section" id="what-are-state-space-models">
  775. <span id="sec-ssm-intro"></span><h1>What are State Space Models?<a class="headerlink" href="#what-are-state-space-models" title="Permalink to this headline">¶</a></h1>
  776. <p>A state space model or SSM
  777. is a partially observed Markov model,
  778. in which the hidden state, <span class="math notranslate nohighlight">\(\hidden_t\)</span>,
  779. evolves over time according to a Markov process,
  780. possibly conditional on external inputs or controls <span class="math notranslate nohighlight">\(\input_t\)</span>,
  781. and each hidden state generates some
  782. observations <span class="math notranslate nohighlight">\(\obs_t\)</span> at each time step.
(In this book, we mostly focus on discrete-time systems,
although we also consider the continuous-time case in a later section.)
  785. We get to see the observations, but not the hidden state.
  786. Our main goal is to infer the hidden state given the observations.
  787. However, we can also use the model to predict future observations,
  788. by first predicting future hidden states, and then predicting
  789. what observations they might generate.
  790. By using a hidden state <span class="math notranslate nohighlight">\(\hidden_t\)</span>
  791. to represent the past observations, <span class="math notranslate nohighlight">\(\obs_{1:t-1}\)</span>,
the model can have “infinite” memory,
  793. unlike a standard Markov model.</p>
  794. <p>Formally we can define an SSM
  795. as the following joint distribution:</p>
  796. <div class="math notranslate nohighlight" id="equation-eq-ssm-ar">
  797. <span class="eqno">()<a class="headerlink" href="#equation-eq-ssm-ar" title="Permalink to this equation">¶</a></span>\[p(\hmmobs_{1:T},\hmmhid_{1:T}|\inputs_{1:T})
  798. = \left[ p(\hmmhid_1|\inputs_1) \prod_{t=2}^{T}
  799. p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) \right]
  800. \left[ \prod_{t=1}^T p(\hmmobs_t|\hmmhid_t, \inputs_t, \hmmobs_{t-1}) \right]\]</div>
  801. <p>where <span class="math notranslate nohighlight">\(p(\hmmhid_t|\hmmhid_{t-1},\inputs_t)\)</span> is the
  802. transition model,
  803. <span class="math notranslate nohighlight">\(p(\hmmobs_t|\hmmhid_t, \inputs_t, \hmmobs_{t-1})\)</span> is the
  804. observation model,
  805. and <span class="math notranslate nohighlight">\(\inputs_{t}\)</span> is an optional input or action.
See <a class="reference internal" href="#ssm-ar"><span class="std std-ref">the figure below</span></a>
for an illustration of the corresponding graphical model.</p>
  808. <div class="figure align-default" id="ssm-ar">
  809. <a class="reference internal image-reference" href="../_images/SSM-AR-inputs.png"><img alt="../_images/SSM-AR-inputs.png" src="../_images/SSM-AR-inputs.png" style="width: 152.0px; height: 165.0px;" /></a>
  810. <p class="caption"><span class="caption-text">Illustration of an SSM as a graphical model.</span><a class="headerlink" href="#ssm-ar" title="Permalink to this image">¶</a></p>
  811. </div>
  812. <p>We often consider a simpler setting in which the
observations are conditionally independent of each other given the hidden states,
rather than having Markovian dependencies on the previous observation.
  815. In this case the joint simplifies to</p>
  816. <div class="math notranslate nohighlight" id="equation-eq-ssm-input">
  817. <span class="eqno">()<a class="headerlink" href="#equation-eq-ssm-input" title="Permalink to this equation">¶</a></span>\[p(\hmmobs_{1:T},\hmmhid_{1:T}|\inputs_{1:T})
  818. = \left[ p(\hmmhid_1|\inputs_1) \prod_{t=2}^{T}
  819. p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) \right]
  820. \left[ \prod_{t=1}^T p(\hmmobs_t|\hmmhid_t, \inputs_t) \right]\]</div>
  821. <p>Sometimes there are no external inputs, so the model further
  822. simplifies to the following unconditional generative model:</p>
  823. <div class="math notranslate nohighlight" id="equation-eq-ssm-no-input">
  824. <span class="eqno">()<a class="headerlink" href="#equation-eq-ssm-no-input" title="Permalink to this equation">¶</a></span>\[p(\hmmobs_{1:T},\hmmhid_{1:T})
  825. = \left[ p(\hmmhid_1) \prod_{t=2}^{T}
  826. p(\hmmhid_t|\hmmhid_{t-1}) \right]
  827. \left[ \prod_{t=1}^T p(\hmmobs_t|\hmmhid_t) \right]\]</div>
<p>See <a class="reference internal" href="#ssm-simplified"><span class="std std-ref">the figure below</span></a>
for an illustration of the corresponding graphical model.</p>
  830. <div class="figure align-default" id="ssm-simplified">
  831. <a class="reference internal image-reference" href="../_images/SSM-simplified.png"><img alt="../_images/SSM-simplified.png" src="../_images/SSM-simplified.png" style="width: 136.0px; height: 98.0px;" /></a>
  832. <p class="caption"><span class="caption-text">Illustration of a simplified SSM.</span><a class="headerlink" href="#ssm-simplified" title="Permalink to this image">¶</a></p>
  833. </div>
  834. </div>
  835. <div class="tex2jax_ignore mathjax_ignore section" id="hidden-markov-models">
  836. <span id="sec-hmm-intro"></span><h1>Hidden Markov Models<a class="headerlink" href="#hidden-markov-models" title="Permalink to this headline">¶</a></h1>
  837. <p>In this section, we discuss the
  838. hidden Markov model or HMM,
  839. which is a state space model in which the hidden states
  840. are discrete, so <span class="math notranslate nohighlight">\(\hmmhid_t \in \{1,\ldots, K\}\)</span>.
  841. The observations may be discrete,
  842. <span class="math notranslate nohighlight">\(\hmmobs_t \in \{1,\ldots, C\}\)</span>,
  843. or continuous,
  844. <span class="math notranslate nohighlight">\(\hmmobs_t \in \real^D\)</span>,
  845. or some combination,
  846. as we illustrate below.
More details can be found in, e.g.,
  848. <span id="id1">[<a class="reference internal" href="../bib.html#id34" title="O. Cappe, E. Moulines, and T. Ryden. Inference in Hidden Markov Models. Springer, 2005.">CMR05</a>, <a class="reference internal" href="../bib.html#id33" title="A. Fraser. Hidden Markov Models and Dynamical Systems. SIAM Press, 2008.">Fra08</a>, <a class="reference internal" href="../bib.html#id32" title="L. R. Rabiner. A tutorial on Hidden Markov Models and selected applications in speech recognition. Proc. of the IEEE, 77(2):257–286, 1989.">Rab89</a>]</span>.
  849. For an interactive introduction,
  850. see <a class="reference external" href="https://nipunbatra.github.io/hmm/">https://nipunbatra.github.io/hmm/</a>.</p>
  851. <div class="section" id="example-casino-hmm">
  852. <h2>Example: Casino HMM<a class="headerlink" href="#example-casino-hmm" title="Permalink to this headline">¶</a></h2>
<p>To illustrate HMMs with a categorical observation model,
we consider the “Occasionally dishonest casino” model from <span id="id2">[<a class="reference internal" href="../bib.html#id3" title="R. Durbin, S. Eddy, A. Krogh, and G. Mitchison. Biological Sequence Analysis: Probabilistic Models of Proteins and Nucleic Acids. Cambridge University Press, 1998.">DEKM98</a>]</span>.
There are 2 hidden states, representing whether the die being used in the casino is fair or loaded.
  856. Each state defines a distribution over the 6 possible observations.</p>
  857. <p>The transition model is denoted by</p>
  858. <div class="math notranslate nohighlight">
  859. \[p(z_t=j|z_{t-1}=i) = \hmmTrans_{ij}\]</div>
  860. <p>Here the <span class="math notranslate nohighlight">\(i\)</span>’th row of <span class="math notranslate nohighlight">\(\vA\)</span> corresponds to the outgoing distribution from state <span class="math notranslate nohighlight">\(i\)</span>.
  861. This is a row stochastic matrix,
  862. meaning each row sums to one.
  863. We can visualize
  864. the non-zero entries in the transition matrix by creating a state transition diagram,
as shown in
<a class="reference internal" href="#casino-fig"><span class="std std-ref">the figure below</span></a>.</p>
  867. <div class="figure align-default" id="casino-fig">
  868. <a class="reference internal image-reference" href="../_images/casino.png"><img alt="../_images/casino.png" src="../_images/casino.png" style="width: 208.5px; height: 142.5px;" /></a>
  869. <p class="caption"><span class="caption-text">Illustration of the casino HMM.</span><a class="headerlink" href="#casino-fig" title="Permalink to this image">¶</a></p>
  870. </div>
  871. <p>The observation model
  872. <span class="math notranslate nohighlight">\(p(\obs_t|\hidden_t=j)\)</span> has the form</p>
  873. <div class="math notranslate nohighlight">
  874. \[p(\obs_t=k|\hidden_t=j) = \hmmObs_{jk} \]</div>
  875. <p>This is represented by the histograms associated with each
state in <a class="reference internal" href="#casino-fig"><span class="std std-ref">the figure above</span></a>.</p>
  877. <p>Finally,
  878. the initial state distribution is denoted by</p>
  879. <div class="math notranslate nohighlight">
  880. \[p(z_1=j) = \hmmInit_j\]</div>
  881. <p>Collectively we denote all the parameters by <span class="math notranslate nohighlight">\(\vtheta=(\hmmTrans, \hmmObs, \hmmInit)\)</span>.</p>
  882. <p>Now let us implement this model in code.</p>
  883. <div class="cell docutils container">
  884. <div class="cell_input docutils container">
  885. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># state transition matrix</span>
  886. <span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  887. <span class="p">[</span><span class="mf">0.95</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">],</span>
  888. <span class="p">[</span><span class="mf">0.10</span><span class="p">,</span> <span class="mf">0.90</span><span class="p">]</span>
  889. <span class="p">])</span>
  890. <span class="c1"># observation matrix</span>
  891. <span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  892. <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">],</span> <span class="c1"># fair die</span>
  893. <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="o">/</span><span class="mi">10</span><span class="p">]</span> <span class="c1"># loaded die</span>
  894. <span class="p">])</span>
  895. <span class="n">pi</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">])</span>
  896. <span class="p">(</span><span class="n">nstates</span><span class="p">,</span> <span class="n">nobs</span><span class="p">)</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
  897. </pre></div>
  898. </div>
  899. </div>
  900. </div>
  901. <div class="cell docutils container">
  902. <div class="cell_input docutils container">
  903. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">distrax</span>
  904. <span class="kn">from</span> <span class="nn">distrax</span> <span class="kn">import</span> <span class="n">HMM</span>
  905. <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
  906. <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">pi</span><span class="p">),</span>
  907. <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">B</span><span class="p">))</span>
  908. <span class="nb">print</span><span class="p">(</span><span class="n">hmm</span><span class="p">)</span>
  909. </pre></div>
  910. </div>
  911. </div>
  912. <div class="cell_output docutils container">
  913. <div class="output stderr highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  914. </pre></div>
  915. </div>
  916. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>&lt;distrax._src.utils.hmm.HMM object at 0x7feaf905e250&gt;
  917. </pre></div>
  918. </div>
  919. </div>
  920. </div>
  921. <p>Let’s sample from the model. We will generate a sequence of latent states, <span class="math notranslate nohighlight">\(\hid_{1:T}\)</span>,
  922. which we then convert to a sequence of observations, <span class="math notranslate nohighlight">\(\obs_{1:T}\)</span>.</p>
  923. <div class="cell docutils container">
  924. <div class="cell_input docutils container">
  925. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">seed</span> <span class="o">=</span> <span class="mi">314</span>
  926. <span class="n">n_samples</span> <span class="o">=</span> <span class="mi">300</span>
  927. <span class="n">z_hist</span><span class="p">,</span> <span class="n">x_hist</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">n_samples</span><span class="p">)</span>
  928. <span class="n">z_hist_str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">z_hist</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">str</span><span class="p">))[:</span><span class="mi">60</span><span class="p">]</span>
  929. <span class="n">x_hist_str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">str</span><span class="p">))[:</span><span class="mi">60</span><span class="p">]</span>
  930. <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Printing sample observed/latent...&quot;</span><span class="p">)</span>
  931. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;x: </span><span class="si">{</span><span class="n">x_hist_str</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  932. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;z: </span><span class="si">{</span><span class="n">z_hist_str</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  933. </pre></div>
  934. </div>
  935. </div>
  936. <div class="cell_output docutils container">
  937. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Printing sample observed/latent...
  938. x: 633665342652353616444236412331351246651613325161656366246242
  939. z: 222222211111111111111111111111111111111222111111112222211111
  940. </pre></div>
  941. </div>
  942. </div>
  943. </div>
  944. <div class="cell docutils container">
  945. <div class="cell_input docutils container">
  946. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Here is the source code for the sampling algorithm.</span>
  947. <span class="n">print_source</span><span class="p">(</span><span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">)</span>
  948. </pre></div>
  949. </div>
  950. </div>
  951. <div class="cell_output docutils container">
  952. <div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"> def sample<span style="font-weight: bold">(</span>self,
  953. *,
  954. seed: chex.PRNGKey,
  955. seq_len: chex.Array<span style="font-weight: bold">)</span> -&gt; Tuple:
  956. <span style="color: #008000; text-decoration-color: #008000">""</span>"Sample from this HMM.
  957. Samples an observation of given length according to this
  958. Hidden Markov Model and gives the sequence of the hidden states
  959. as well as the observation.
  960. Args:
  961. seed: Random key of shape <span style="font-weight: bold">(</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">2</span>,<span style="font-weight: bold">)</span> and dtype uint32.
  962. seq_len: The length of the observation sequence.
  963. Returns:
  964. Tuple of hidden state sequence, and observation sequence.
  965. <span style="color: #008000; text-decoration-color: #008000">""</span>"
  966. rng_key, rng_init = jax.random.split<span style="font-weight: bold">(</span>seed<span style="font-weight: bold">)</span>
  967. initial_state = self._init_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">rng_init</span><span style="font-weight: bold">)</span>
  968. def draw_state<span style="font-weight: bold">(</span>prev_state, key<span style="font-weight: bold">)</span>:
  969. state = self._trans_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key</span><span style="font-weight: bold">)</span>
  970. return state, state
  971. rng_state, rng_obs = jax.random.split<span style="font-weight: bold">(</span>rng_key<span style="font-weight: bold">)</span>
  972. keys = jax.random.split<span style="font-weight: bold">(</span>rng_state, seq_len - <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">1</span><span style="font-weight: bold">)</span>
  973. _, states = jax.lax.scan<span style="font-weight: bold">(</span>draw_state, initial_state, keys<span style="font-weight: bold">)</span>
  974. states = jnp.append<span style="font-weight: bold">(</span>initial_state, states<span style="font-weight: bold">)</span>
  975. def draw_obs<span style="font-weight: bold">(</span>state, key<span style="font-weight: bold">)</span>:
  976. return self._obs_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key</span><span style="font-weight: bold">)</span>
  977. keys = jax.random.split<span style="font-weight: bold">(</span>rng_obs, seq_len<span style="font-weight: bold">)</span>
  978. obs_seq = jax.vmap<span style="font-weight: bold">(</span>draw_obs, <span style="color: #808000; text-decoration-color: #808000">in_axes</span>=<span style="font-weight: bold">(</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">0</span>, <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">0</span><span style="font-weight: bold">))(</span>states, keys<span style="font-weight: bold">)</span>
  979. return states, obs_seq
  980. </pre>
  981. </div></div>
  982. </div>
  983. <p>Our primary goal will be to infer the latent state from the observations,
  984. so we can detect if the casino is being dishonest or not. This will
  985. affect how we choose to gamble our money.
  986. We discuss various ways to perform this inference below.</p>
  987. </div>
  988. <div class="section" id="example-lillypad-hmm">
  989. <span id="sec-lillypad"></span><h2>Example: Lillypad HMM<a class="headerlink" href="#example-lillypad-hmm" title="Permalink to this headline">¶</a></h2>
  990. <p>If <span class="math notranslate nohighlight">\(\obs_t\)</span> is continuous, it is common to use a Gaussian
  991. observation model:</p>
  992. <div class="math notranslate nohighlight">
  993. \[p(\obs_t|\hidden_t=j) = \gauss(\obs_t|\vmu_j,\vSigma_j)\]</div>
  994. <p>This is sometimes called a Gaussian HMM.</p>
  995. <p>As a simple example, suppose we have an HMM with 3 hidden states,
  996. each of which generates a 2d Gaussian.
We can represent these Gaussian distributions as 2d ellipses,
as we show below.
We call these “lily pads”, because of their shape.
We can imagine a frog hopping from one lily pad to another.
  1001. (This analogy is due to the late Sam Roweis.)
  1002. The frog will stay on a pad for a while (corresponding to remaining in the same
  1003. discrete state <span class="math notranslate nohighlight">\(\hidden_t\)</span>), and then jump to a new pad
  1004. (corresponding to a transition to a new state).
  1005. The data we see are just the 2d points (e.g., water droplets)
  1006. coming from near the pad that the frog is currently on.
  1007. Thus this model is like a Gaussian mixture model,
  1008. in that it generates clusters of observations,
  1009. except now there is temporal correlation between the data points.</p>
  1010. <p>Let us now illustrate this model in code.</p>
  1011. <div class="cell docutils container">
  1012. <div class="cell_input docutils container">
  1013. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let us create the model</span>
  1014. <span class="n">initial_probs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">])</span>
  1015. <span class="c1"># transition matrix</span>
  1016. <span class="n">A</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  1017. <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
  1018. <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
  1019. <span class="p">[</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]</span>
  1020. <span class="p">])</span>
  1021. <span class="c1"># Observation model</span>
  1022. <span class="n">mu_collection</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  1023. <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
  1024. <span class="p">[</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">],</span>
  1025. <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">]</span>
  1026. <span class="p">])</span>
  1027. <span class="n">S1</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">1.1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">]])</span>
  1028. <span class="n">S2</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">1.3</span><span class="p">]])</span>
  1029. <span class="n">S3</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]])</span>
  1030. <span class="n">cov_collection</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">S1</span><span class="p">,</span> <span class="n">S2</span><span class="p">,</span> <span class="n">S3</span><span class="p">])</span> <span class="o">/</span> <span class="mi">60</span>
  1031. <span class="kn">import</span> <span class="nn">tensorflow_probability</span> <span class="k">as</span> <span class="nn">tfp</span>
  1032. <span class="k">if</span> <span class="kc">False</span><span class="p">:</span>
  1033. <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
  1034. <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">initial_probs</span><span class="p">),</span>
  1035. <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">MultivariateNormalFullCovariance</span><span class="p">(</span>
  1036. <span class="n">loc</span><span class="o">=</span><span class="n">mu_collection</span><span class="p">,</span> <span class="n">covariance_matrix</span><span class="o">=</span><span class="n">cov_collection</span><span class="p">))</span>
  1037. <span class="k">else</span><span class="p">:</span>
  1038. <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
  1039. <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">initial_probs</span><span class="p">),</span>
  1040. <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">as_distribution</span><span class="p">(</span>
  1041. <span class="n">tfp</span><span class="o">.</span><span class="n">substrates</span><span class="o">.</span><span class="n">jax</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">MultivariateNormalFullCovariance</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">mu_collection</span><span class="p">,</span>
  1042. <span class="n">covariance_matrix</span><span class="o">=</span><span class="n">cov_collection</span><span class="p">)))</span>
  1043. <span class="nb">print</span><span class="p">(</span><span class="n">hmm</span><span class="p">)</span>
  1044. </pre></div>
  1045. </div>
  1046. </div>
  1047. <div class="cell_output docutils container">
  1048. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>&lt;distrax._src.utils.hmm.HMM object at 0x7feb19733220&gt;
  1049. </pre></div>
  1050. </div>
  1051. </div>
  1052. </div>
  1053. <div class="cell docutils container">
  1054. <div class="cell_input docutils container">
  1055. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">n_samples</span><span class="p">,</span> <span class="n">seed</span> <span class="o">=</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">10</span>
  1056. <span class="n">samples_state</span><span class="p">,</span> <span class="n">samples_obs</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">n_samples</span><span class="p">)</span>
  1057. <span class="nb">print</span><span class="p">(</span><span class="n">samples_state</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
  1058. <span class="nb">print</span><span class="p">(</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
  1059. </pre></div>
  1060. </div>
  1061. </div>
  1062. <div class="cell_output docutils container">
  1063. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>(50,)
  1064. (50, 2)
  1065. </pre></div>
  1066. </div>
  1067. </div>
  1068. </div>
  1069. <div class="cell docutils container">
  1070. <div class="cell_input docutils container">
  1071. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let&#39;s plot the observed data in 2d</span>
  1072. <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span>
  1073. <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mf">1.2</span>
  1074. <span class="n">colors</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;tab:green&quot;</span><span class="p">,</span> <span class="s2">&quot;tab:blue&quot;</span><span class="p">,</span> <span class="s2">&quot;tab:red&quot;</span><span class="p">]</span>
  1075. <span class="k">def</span> <span class="nf">plot_2dhmm</span><span class="p">(</span><span class="n">hmm</span><span class="p">,</span> <span class="n">samples_obs</span><span class="p">,</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">colors</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">):</span>
  1076. <span class="n">obs_dist</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">obs_dist</span>
  1077. <span class="n">color_sample</span> <span class="o">=</span> <span class="p">[</span><span class="n">colors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">samples_state</span><span class="p">]</span>
  1078. <span class="n">xs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
  1079. <span class="n">ys</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
  1080. <span class="n">v_prob</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">obs_dist</span><span class="o">.</span><span class="n">prob</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">])),</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
  1081. <span class="n">z</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="n">v_prob</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="kc">None</span><span class="p">))(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">)</span>
  1082. <span class="n">grid</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mgrid</span><span class="p">[</span><span class="n">xmin</span><span class="p">:</span><span class="n">xmax</span><span class="p">:</span><span class="n">step</span><span class="p">,</span> <span class="n">ymin</span><span class="p">:</span><span class="n">ymax</span><span class="p">:</span><span class="n">step</span><span class="p">]</span>
  1083. <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">color</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">colors</span><span class="p">):</span>
  1084. <span class="n">ax</span><span class="o">.</span><span class="n">contour</span><span class="p">(</span><span class="o">*</span><span class="n">grid</span><span class="p">,</span> <span class="n">z</span><span class="p">[:,</span> <span class="p">:,</span> <span class="n">k</span><span class="p">],</span> <span class="n">levels</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">colors</span><span class="o">=</span><span class="n">color</span><span class="p">,</span> <span class="n">linewidths</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
  1085. <span class="n">ax</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="o">*</span><span class="p">(</span><span class="n">obs_dist</span><span class="o">.</span><span class="n">mean</span><span class="p">()[</span><span class="n">k</span><span class="p">]</span> <span class="o">+</span> <span class="mf">0.13</span><span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;$k$=</span><span class="si">{</span><span class="n">k</span> <span class="o">+</span> <span class="mi">1</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">13</span><span class="p">,</span> <span class="n">horizontalalignment</span><span class="o">=</span><span class="s2">&quot;right&quot;</span><span class="p">)</span>
  1086. <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="o">*</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">&quot;black&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  1087. <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="o">*</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">color_sample</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="mi">30</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
  1088. <span class="k">return</span> <span class="n">ax</span><span class="p">,</span> <span class="n">color_sample</span>
  1089. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  1090. <span class="n">_</span><span class="p">,</span> <span class="n">color_sample</span> <span class="o">=</span> <span class="n">plot_2dhmm</span><span class="p">(</span><span class="n">hmm</span><span class="p">,</span> <span class="n">samples_obs</span><span class="p">,</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">colors</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">)</span>
  1091. </pre></div>
  1092. </div>
  1093. </div>
  1094. <div class="cell_output docutils container">
  1095. <img alt="../_images/ssm_old_15_0.png" src="../_images/ssm_old_15_0.png" />
  1096. </div>
  1097. </div>
  1098. <div class="cell docutils container">
  1099. <div class="cell_input docutils container">
  1100. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let&#39;s plot the hidden state sequence</span>
  1101. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  1102. <span class="n">ax</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">),</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">&quot;black&quot;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span>
  1103. <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">),</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">color_sample</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
  1104. </pre></div>
  1105. </div>
  1106. </div>
  1107. <div class="cell_output docutils container">
  1108. <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>&lt;matplotlib.collections.PathCollection at 0x7feb1957c0a0&gt;
  1109. </pre></div>
  1110. </div>
  1111. <img alt="../_images/ssm_old_16_1.png" src="../_images/ssm_old_16_1.png" />
  1112. </div>
  1113. </div>
  1114. </div>
  1115. </div>
  1116. <div class="tex2jax_ignore mathjax_ignore section" id="linear-gaussian-ssms">
  1117. <span id="sec-lds-intro"></span><h1>Linear Gaussian SSMs<a class="headerlink" href="#linear-gaussian-ssms" title="Permalink to this headline">¶</a></h1>
<p>Consider the state space model in
<a class="reference internal" href="#equation-eq-ssm-ar">the autoregressive SSM equation</a>,
where we now assume the observations are conditionally iid given the
hidden states and inputs (i.e. there are no auto-regressive dependencies
between the observables).
We can rewrite this model as
a stochastic nonlinear dynamical system (NLDS)
by writing the next hidden state
as a deterministic function of the previous state, the current input,
and random process noise <span class="math notranslate nohighlight">\(\vepsilon_t\)</span>:</p>
  1128. <div class="amsmath math notranslate nohighlight" id="equation-3536df60-5800-49ba-99b0-cfe3995291c6">
  1129. <span class="eqno">()<a class="headerlink" href="#equation-3536df60-5800-49ba-99b0-cfe3995291c6" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1130. \hmmhid_t &amp;= \ssmDynFn(\hmmhid_{t-1}, \inputs_t, \vepsilon_t)
  1131. \end{align}\]</div>
  1132. <p>where <span class="math notranslate nohighlight">\(\vepsilon_t\)</span> is drawn from the distribution such
  1133. that the induced distribution
  1134. on <span class="math notranslate nohighlight">\(\hmmhid_t\)</span> matches <span class="math notranslate nohighlight">\(p(\hmmhid_t|\hmmhid_{t-1}, \inputs_t)\)</span>.
  1135. Similarly we can rewrite the observation distributions
  1136. as a deterministic function of the hidden state
  1137. plus observation noise <span class="math notranslate nohighlight">\(\veta_t\)</span>:</p>
  1138. <div class="amsmath math notranslate nohighlight" id="equation-a6cb71e2-fb89-4cd0-b11f-43d39e5b68fa">
  1139. <span class="eqno">()<a class="headerlink" href="#equation-a6cb71e2-fb89-4cd0-b11f-43d39e5b68fa" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1140. \hmmobs_t &amp;= \ssmObsFn(\hmmhid_{t}, \inputs_t, \veta_t)
  1141. \end{align}\]</div>
  1142. <p>If we assume additive Gaussian noise,
  1143. the model becomes</p>
  1144. <div class="amsmath math notranslate nohighlight" id="equation-5dba5e0d-f668-43c5-8d30-ccde8cf1d8b5">
  1145. <span class="eqno">()<a class="headerlink" href="#equation-5dba5e0d-f668-43c5-8d30-ccde8cf1d8b5" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1146. \hmmhid_t &amp;= \ssmDynFn(\hmmhid_{t-1}, \inputs_t) + \vepsilon_t \\
  1147. \hmmobs_t &amp;= \ssmObsFn(\hmmhid_{t}, \inputs_t) + \veta_t
  1148. \end{align}\]</div>
  1149. <p>where <span class="math notranslate nohighlight">\(\vepsilon_t \sim \gauss(\vzero,\vQ_t)\)</span>
  1150. and <span class="math notranslate nohighlight">\(\veta_t \sim \gauss(\vzero,\vR_t)\)</span>.
  1151. We will call these Gaussian SSMs.</p>
  1152. <p>If we additionally assume
  1153. the transition function <span class="math notranslate nohighlight">\(\ssmDynFn\)</span>
  1154. and the observation function <span class="math notranslate nohighlight">\(\ssmObsFn\)</span> are both linear,
  1155. then we can rewrite the model as follows:</p>
  1156. <div class="amsmath math notranslate nohighlight" id="equation-ae1eaadf-e90f-49d3-b9f3-0520cc68759c">
  1157. <span class="eqno">()<a class="headerlink" href="#equation-ae1eaadf-e90f-49d3-b9f3-0520cc68759c" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1158. p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) &amp;= \gauss(\hmmhid_t|\ldsDyn_t \hmmhid_{t-1}
  1159. + \ldsDynIn_t \inputs_t, \vQ_t)
  1160. \\
  1161. p(\hmmobs_t|\hmmhid_t,\inputs_t) &amp;= \gauss(\hmmobs_t|\ldsObs_t \hmmhid_{t}
  1162. + \ldsObsIn_t \inputs_t, \vR_t)
  1163. \end{align}\]</div>
  1164. <p>This is called a
  1165. linear-Gaussian state space model
  1166. (LG-SSM),
  1167. or a
  1168. linear dynamical system (LDS).
  1169. We usually assume the parameters are independent of time, in which case
  1170. the model is said to be time-invariant or homogeneous.</p>
  1171. <div class="section" id="example-tracking-a-2d-point">
  1172. <span id="sec-kalman-tracking"></span><span id="sec-tracking-lds"></span><h2>Example: tracking a 2d point<a class="headerlink" href="#example-tracking-a-2d-point" title="Permalink to this headline">¶</a></h2>
  1173. <p>Consider an object moving in <span class="math notranslate nohighlight">\(\real^2\)</span>.
  1174. Let the state be
  1175. the position and velocity of the object,
  1176. $<span class="math notranslate nohighlight">\(\vz_t =\begin{pmatrix} u_t &amp; \dot{u}_t &amp; v_t &amp; \dot{v}_t \end{pmatrix}\)</span><span class="math notranslate nohighlight">\(.
  1177. (We use \)</span>u<span class="math notranslate nohighlight">\( and \)</span>v$ for the two coordinates,
  1178. to avoid confusion with the state and observation variables.)
  1179. If we use Euler discretization,
  1180. the dynamics become</p>
  1181. <div class="amsmath math notranslate nohighlight" id="equation-567370ff-6316-4f30-a817-15050de99ec9">
  1182. <span class="eqno">()<a class="headerlink" href="#equation-567370ff-6316-4f30-a817-15050de99ec9" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1183. \underbrace{\begin{pmatrix} u_t\\ \dot{u}_t \\ v_t \\ \dot{v}_t \end{pmatrix}}_{\vz_t}
  1184. =
  1185. \underbrace{
  1186. \begin{pmatrix}
  1187. 1 &amp; 0 &amp; \Delta &amp; 0 \\
  1188. 0 &amp; 1 &amp; 0 &amp; \Delta\\
  1189. 0 &amp; 0 &amp; 1 &amp; 0 \\
  1190. 0 &amp; 0 &amp; 0 &amp; 1
  1191. \end{pmatrix}
  1192. }_{\ldsDyn}
  1193. \
  1194. \underbrace{\begin{pmatrix} u_{t-1} \\ \dot{u}_{t-1} \\ v_{t-1} \\ \dot{v}_{t-1} \end{pmatrix}}_{\vz_{t-1}}
  1195. + \vepsilon_t
  1196. \end{align}\]</div>
  1197. <p>where <span class="math notranslate nohighlight">\(\vepsilon_t \sim \gauss(\vzero,\vQ)\)</span> is
  1198. the process noise.</p>
  1199. <p>Let us assume
  1200. that the process noise is
  1201. a white noise process added to the velocity components
  1202. of the state, but not to the location.
  1203. (This is known as a random accelerations model.)
  1204. We can approximate the resulting process in discrete time by assuming
  1205. <span class="math notranslate nohighlight">\(\vQ = \diag(0, q, 0, q)\)</span>.
  1206. (See <span id="id3">[<a class="reference internal" href="../bib.html#id18" title="Simo Sarkka. Bayesian Filtering and Smoothing. Cambridge University Press, 2013. URL: https://users.aalto.fi/~ssarkka/pub/cup_book_online_20131111.pdf.">Sar13</a>]</span> p60 for a more accurate way
  1207. to convert the continuous time process to discrete time.)</p>
  1208. <p>Now suppose that at each discrete time point we
  1209. observe the location,
  1210. corrupted by Gaussian noise.
  1211. Thus the observation model becomes</p>
  1212. <div class="amsmath math notranslate nohighlight" id="equation-bb5d3cb2-5e61-472e-a04a-66081f35b5e8">
  1213. <span class="eqno">()<a class="headerlink" href="#equation-bb5d3cb2-5e61-472e-a04a-66081f35b5e8" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1214. \underbrace{\begin{pmatrix} y_{1,t} \\ y_{2,t} \end{pmatrix}}_{\vy_t}
  1215. &amp;=
  1216. \underbrace{
  1217. \begin{pmatrix}
  1218. 1 &amp; 0 &amp; 0 &amp; 0 \\
  1219. 0 &amp; 0 &amp; 1 &amp; 0
  1220. \end{pmatrix}
  1221. }_{\ldsObs}
  1222. \
  1223. \underbrace{\begin{pmatrix} u_t\\ \dot{u}_t \\ v_t \\ \dot{v}_t \end{pmatrix}}_{\vz_t}
  1224. + \veta_t
  1225. \end{align}\]</div>
  1226. <p>where <span class="math notranslate nohighlight">\(\veta_t \sim \gauss(\vzero,\vR)\)</span> is the \keywordDef{observation noise}.
  1227. We see that the observation matrix <span class="math notranslate nohighlight">\(\ldsObs\)</span> simply ``extracts’’ the
  1228. relevant parts of the state vector.</p>
  1229. <p>Suppose we sample a trajectory and corresponding set
  1230. of noisy observations from this model,
  1231. <span class="math notranslate nohighlight">\((\vz_{1:T}, \vy_{1:T}) \sim p(\vz,\vy|\vtheta)\)</span>.
  1232. (We use diagonal observation noise,
  1233. <span class="math notranslate nohighlight">\(\vR = \diag(\sigma_1^2, \sigma_2^2)\)</span>.)
  1234. The results are shown below.</p>
  1235. <div class="cell docutils container">
  1236. <div class="cell_input docutils container">
  1237. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">314</span><span class="p">)</span>
  1238. <span class="n">timesteps</span> <span class="o">=</span> <span class="mi">15</span>
  1239. <span class="n">delta</span> <span class="o">=</span> <span class="mf">1.0</span>
  1240. <span class="n">A</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  1241. <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">delta</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
  1242. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">delta</span><span class="p">],</span>
  1243. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
  1244. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>
  1245. <span class="p">])</span>
  1246. <span class="n">C</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  1247. <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
  1248. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>
  1249. <span class="p">])</span>
  1250. <span class="n">state_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">shape</span>
  1251. <span class="n">observation_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">C</span><span class="o">.</span><span class="n">shape</span>
  1252. <span class="n">Q</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">state_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.001</span>
  1253. <span class="n">R</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">observation_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.0</span>
  1254. <span class="c1"># Prior parameter distribution</span>
  1255. <span class="n">mu0</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">8</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">float</span><span class="p">)</span>
  1256. <span class="n">Sigma0</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">state_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.0</span>
  1257. <span class="kn">from</span> <span class="nn">jsl.lds.kalman_filter</span> <span class="kn">import</span> <span class="n">LDS</span><span class="p">,</span> <span class="n">smooth</span><span class="p">,</span> <span class="nb">filter</span>
  1258. <span class="n">lds</span> <span class="o">=</span> <span class="n">LDS</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">Q</span><span class="p">,</span> <span class="n">R</span><span class="p">,</span> <span class="n">mu0</span><span class="p">,</span> <span class="n">Sigma0</span><span class="p">)</span>
  1259. <span class="nb">print</span><span class="p">(</span><span class="n">lds</span><span class="p">)</span>
  1260. </pre></div>
  1261. </div>
  1262. </div>
  1263. <div class="cell_output docutils container">
  1264. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>LDS(A=DeviceArray([[1., 0., 1., 0.],
  1265. [0., 1., 0., 1.],
  1266. [0., 0., 1., 0.],
  1267. [0., 0., 0., 1.]], dtype=float32), C=DeviceArray([[1, 0, 0, 0],
  1268. [0, 1, 0, 0]], dtype=int32), Q=DeviceArray([[0.001, 0. , 0. , 0. ],
  1269. [0. , 0.001, 0. , 0. ],
  1270. [0. , 0. , 0.001, 0. ],
  1271. [0. , 0. , 0. , 0.001]], dtype=float32), R=DeviceArray([[1., 0.],
  1272. [0., 1.]], dtype=float32), mu=DeviceArray([ 8., 10., 1., 0.], dtype=float32), Sigma=DeviceArray([[1., 0., 0., 0.],
  1273. [0., 1., 0., 0.],
  1274. [0., 0., 1., 0.],
  1275. [0., 0., 0., 1.]], dtype=float32), state_offset=None, obs_offset=None, nstates=4, nobs=2)
  1276. </pre></div>
  1277. </div>
  1278. </div>
  1279. </div>
  1280. <div class="cell docutils container">
  1281. <div class="cell_input docutils container">
  1282. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">jsl.demos.plot_utils</span> <span class="kn">import</span> <span class="n">plot_ellipse</span>
  1283. <span class="k">def</span> <span class="nf">plot_tracking_values</span><span class="p">(</span><span class="n">observed</span><span class="p">,</span> <span class="n">filtered</span><span class="p">,</span> <span class="n">cov_hist</span><span class="p">,</span> <span class="n">signal_label</span><span class="p">,</span> <span class="n">ax</span><span class="p">):</span>
  1284. <span class="n">timesteps</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">observed</span><span class="o">.</span><span class="n">shape</span>
  1285. <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">observed</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">observed</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">marker</span><span class="o">=</span><span class="s2">&quot;o&quot;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
  1286. <span class="n">markerfacecolor</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span> <span class="n">markeredgewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">markersize</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;observed&quot;</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">&quot;tab:green&quot;</span><span class="p">)</span>
  1287. <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="o">*</span><span class="n">filtered</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">signal_label</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">&quot;tab:red&quot;</span><span class="p">,</span> <span class="n">marker</span><span class="o">=</span><span class="s2">&quot;x&quot;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
  1288. <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">timesteps</span><span class="p">,</span> <span class="mi">1</span><span class="p">):</span>
  1289. <span class="n">covn</span> <span class="o">=</span> <span class="n">cov_hist</span><span class="p">[</span><span class="n">t</span><span class="p">][:</span><span class="mi">2</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span>
  1290. <span class="n">plot_ellipse</span><span class="p">(</span><span class="n">covn</span><span class="p">,</span> <span class="n">filtered</span><span class="p">[</span><span class="n">t</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">ax</span><span class="p">,</span> <span class="n">n_std</span><span class="o">=</span><span class="mf">2.0</span><span class="p">,</span> <span class="n">plot_center</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  1291. <span class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;equal&quot;</span><span class="p">)</span>
  1292. <span class="n">ax</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
  1293. </pre></div>
  1294. </div>
  1295. </div>
  1296. </div>
  1297. <div class="cell docutils container">
  1298. <div class="cell_input docutils container">
  1299. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">z_hist</span><span class="p">,</span> <span class="n">x_hist</span> <span class="o">=</span> <span class="n">lds</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">timesteps</span><span class="p">)</span>
  1300. <span class="n">fig_truth</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  1301. <span class="n">axs</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x_hist</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">x_hist</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span>
  1302. <span class="n">marker</span><span class="o">=</span><span class="s2">&quot;o&quot;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">markerfacecolor</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span>
  1303. <span class="n">markeredgewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">markersize</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
  1304. <span class="n">label</span><span class="o">=</span><span class="s2">&quot;observed&quot;</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">&quot;tab:green&quot;</span><span class="p">)</span>
  1305. <span class="n">axs</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">z_hist</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">z_hist</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span>
  1306. <span class="n">linewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;truth&quot;</span><span class="p">,</span>
  1307. <span class="n">marker</span><span class="o">=</span><span class="s2">&quot;s&quot;</span><span class="p">,</span> <span class="n">markersize</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
  1308. <span class="n">axs</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
  1309. <span class="n">axs</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;equal&quot;</span><span class="p">)</span>
  1310. </pre></div>
  1311. </div>
  1312. </div>
  1313. <div class="cell_output docutils container">
  1314. <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>(7.24486608505249, 23.857812213897706, 8.042076778411865, 11.636079120635987)
  1315. </pre></div>
  1316. </div>
  1317. <img alt="../_images/ssm_old_21_1.png" src="../_images/ssm_old_21_1.png" />
  1318. </div>
  1319. </div>
  1320. <p>The main task is to infer the hidden states given the noisy
  1321. observations, i.e., <span class="math notranslate nohighlight">\(p(\vz|\vy,\vtheta)\)</span>. We discuss the topic of inference in <a class="reference internal" href="#sec-inference"><span class="std std-ref">Inferential goals</span></a>.</p>
  1322. </div>
  1323. </div>
  1324. <div class="tex2jax_ignore mathjax_ignore section" id="nonlinear-gaussian-ssms">
  1325. <span id="sec-nlds-intro"></span><h1>Nonlinear Gaussian SSMs<a class="headerlink" href="#nonlinear-gaussian-ssms" title="Permalink to this headline">¶</a></h1>
  1326. <p>In this section, we consider SSMs in which the dynamics and/or observation models are nonlinear,
  1327. but the process noise and observation noise are Gaussian.
  1328. That is,</p>
  1329. <div class="amsmath math notranslate nohighlight" id="equation-9b2f4563-7156-4411-8ff1-8b8c16adf162">
  1330. <span class="eqno">()<a class="headerlink" href="#equation-9b2f4563-7156-4411-8ff1-8b8c16adf162" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1331. \hmmhid_t &amp;= \ssmDynFn(\hmmhid_{t-1}, \inputs_t) + \vepsilon_t \\
  1332. \hmmobs_t &amp;= \ssmObsFn(\hmmhid_{t}, \inputs_t) + \veta_t
  1333. \end{align}\]</div>
  1334. <p>where <span class="math notranslate nohighlight">\(\vepsilon_t \sim \gauss(\vzero,\vQ_t)\)</span>
  1335. and <span class="math notranslate nohighlight">\(\veta_t \sim \gauss(\vzero,\vR_t)\)</span>.
  1336. This is a very widely used model class. We give some examples below.</p>
  1337. <div class="section" id="example-tracking-a-1d-pendulum">
  1338. <span id="sec-pendulum"></span><h2>Example: tracking a 1d pendulum<a class="headerlink" href="#example-tracking-a-1d-pendulum" title="Permalink to this headline">¶</a></h2>
  1339. <div class="figure align-default" id="fig-pendulum">
  1340. <a class="reference internal image-reference" href="../_images/pendulum.png"><img alt="../_images/pendulum.png" src="../_images/pendulum.png" style="width: 265.0px; height: 295.0px;" /></a>
  1341. <p class="caption"><span class="caption-text">Illustration of a pendulum swinging.
<span class="math notranslate nohighlight">\(g\)</span> is the gravitational acceleration,
<span class="math notranslate nohighlight">\(w(t)\)</span> is a random external force,
and <span class="math notranslate nohighlight">\(\alpha\)</span> is the angle with respect to the vertical.
  1345. Based on <span id="id4">[<a class="reference internal" href="../bib.html#id18" title="Simo Sarkka. Bayesian Filtering and Smoothing. Cambridge University Press, 2013. URL: https://users.aalto.fi/~ssarkka/pub/cup_book_online_20131111.pdf.">Sar13</a>]</span> fig 3.10.</span><a class="headerlink" href="#fig-pendulum" title="Permalink to this image">¶</a></p>
  1346. </div>
<p>Consider a simple pendulum of unit mass and length swinging from
a fixed attachment, as illustrated in <a class="reference internal" href="#fig-pendulum"><span class="std std-ref">the pendulum figure</span></a>.
  1353. Such an object is in principle entirely deterministic in its behavior.
  1354. However, in the real world, there are often unknown forces at work
  1355. (e.g., air turbulence, friction).
  1356. We will model these by a continuous time random Gaussian noise process <span class="math notranslate nohighlight">\(w(t)\)</span>.
  1357. This gives rise to the following differential equation:</p>
  1358. <div class="amsmath math notranslate nohighlight" id="equation-dcdda512-ae3f-4ffb-bd90-db678338bb27">
  1359. <span class="eqno">()<a class="headerlink" href="#equation-dcdda512-ae3f-4ffb-bd90-db678338bb27" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1360. \frac{d^2 \alpha}{d t^2}
  1361. = -g \sin(\alpha) + w(t)
  1362. \end{align}\]</div>
  1363. <p>We can write this as a nonlinear SSM by defining the state to be
  1364. <span class="math notranslate nohighlight">\(z_1(t) = \alpha(t)\)</span> and <span class="math notranslate nohighlight">\(z_2(t) = d\alpha(t)/dt\)</span>.
  1365. Thus</p>
  1366. <div class="amsmath math notranslate nohighlight" id="equation-27ccd595-dcb3-4feb-8f28-e9470220fd48">
  1367. <span class="eqno">()<a class="headerlink" href="#equation-27ccd595-dcb3-4feb-8f28-e9470220fd48" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1368. \frac{d \vz}{dt}
  1369. = \begin{pmatrix} z_2 \\ -g \sin(z_1) \end{pmatrix}
  1370. + \begin{pmatrix} 0 \\ 1 \end{pmatrix} w(t)
  1371. \end{align}\]</div>
<p>If we discretize this equation using a step size <span class="math notranslate nohighlight">\(\Delta\)</span>,
  1373. we get the following
  1374. formulation <span id="id5">[<a class="reference internal" href="../bib.html#id18" title="Simo Sarkka. Bayesian Filtering and Smoothing. Cambridge University Press, 2013. URL: https://users.aalto.fi/~ssarkka/pub/cup_book_online_20131111.pdf.">Sar13</a>]</span> p74:</p>
  1375. <div class="amsmath math notranslate nohighlight" id="equation-37e2b498-70d7-45cf-a98f-e957e1dd0963">
  1376. <span class="eqno">()<a class="headerlink" href="#equation-37e2b498-70d7-45cf-a98f-e957e1dd0963" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1377. \underbrace{
  1378. \begin{pmatrix} z_{1,t} \\ z_{2,t} \end{pmatrix}
  1379. }_{\hmmhid_t}
  1380. =
  1381. \underbrace{
  1382. \begin{pmatrix} z_{1,t-1} + z_{2,t-1} \Delta \\
  1383. z_{2,t-1} -g \sin(z_{1,t-1}) \Delta \end{pmatrix}
  1384. }_{\vf(\hmmhid_{t-1})}
  1385. +\vq_{t-1}
  1386. \end{align}\]</div>
  1387. <p>where <span class="math notranslate nohighlight">\(\vq_{t-1} \sim \gauss(\vzero,\vQ)\)</span> with</p>
  1388. <div class="amsmath math notranslate nohighlight" id="equation-65c3fab7-9891-4368-865b-935fc7a297fd">
  1389. <span class="eqno">()<a class="headerlink" href="#equation-65c3fab7-9891-4368-865b-935fc7a297fd" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1390. \vQ = q^c \begin{pmatrix}
  1391. \frac{\Delta^3}{3} &amp; \frac{\Delta^2}{2} \\
  1392. \frac{\Delta^2}{2} &amp; \Delta
  1393. \end{pmatrix}
  1394. \end{align}\]</div>
  1395. <p>where <span class="math notranslate nohighlight">\(q^c\)</span> is the spectral density (continuous time variance)
  1396. of the continuous-time noise process.</p>
  1397. <p>If we observe the angular position, we
  1398. get the linear observation model</p>
  1399. <div class="amsmath math notranslate nohighlight" id="equation-714de64b-7aa6-4c3f-a17c-0bf21bbc0ec5">
  1400. <span class="eqno">()<a class="headerlink" href="#equation-714de64b-7aa6-4c3f-a17c-0bf21bbc0ec5" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1401. y_t = \alpha_t + r_t = h(\hmmhid_t) + r_t
  1402. \end{align}\]</div>
  1403. <p>where <span class="math notranslate nohighlight">\(h(\hmmhid_t) = z_{1,t}\)</span>
  1404. and <span class="math notranslate nohighlight">\(r_t\)</span> is the observation noise.
  1405. If we only observe the horizontal position,
  1406. we get the nonlinear observation model</p>
  1407. <div class="amsmath math notranslate nohighlight" id="equation-ab4fa705-661d-4ef6-8f1d-a394e212e500">
  1408. <span class="eqno">()<a class="headerlink" href="#equation-ab4fa705-661d-4ef6-8f1d-a394e212e500" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1409. y_t = \sin(\alpha_t) + r_t = h(\hmmhid_t) + r_t
  1410. \end{align}\]</div>
  1411. <p>where <span class="math notranslate nohighlight">\(h(\hmmhid_t) = \sin(z_{1,t})\)</span>.</p>
  1412. </div>
  1413. </div>
  1414. <div class="tex2jax_ignore mathjax_ignore section" id="inferential-goals">
  1415. <span id="sec-inference"></span><h1>Inferential goals<a class="headerlink" href="#inferential-goals" title="Permalink to this headline">¶</a></h1>
  1416. <div class="figure align-default" id="fig-dbn-inference">
  1417. <a class="reference internal image-reference" href="../_images/inference-problems-tikz.png"><img alt="../_images/inference-problems-tikz.png" src="../_images/inference-problems-tikz.png" style="width: 1174.0px; height: 898.0px;" /></a>
  1418. <p class="caption"><span class="caption-text">Illustration of the different kinds of inference in an SSM.
  1419. The main kinds of inference for state-space models.
  1420. The shaded region is the interval for which we have data.
  1421. The arrow represents the time step at which we want to perform inference.
  1422. <span class="math notranslate nohighlight">\(t\)</span> is the current time, <span class="math notranslate nohighlight">\(T\)</span> is the sequence length,
  1423. <span class="math notranslate nohighlight">\(\ell\)</span> is the lag and <span class="math notranslate nohighlight">\(h\)</span> is the prediction horizon.</span><a class="headerlink" href="#fig-dbn-inference" title="Permalink to this image">¶</a></p>
  1424. </div>
  1425. <p>Given the sequence of observations, and a known model,
  1426. one of the main tasks with SSMs
  1427. to perform posterior inference,
  1428. about the hidden states; this is also called
  1429. state estimation.
  1430. At each time step <span class="math notranslate nohighlight">\(t\)</span>,
  1431. there are multiple forms of posterior we may be interested in computing,
  1432. including the following:</p>
  1433. <ul class="simple">
  1434. <li><p>the filtering distribution
  1435. <span class="math notranslate nohighlight">\(p(\hmmhid_t|\hmmobs_{1:t})\)</span></p></li>
  1436. <li><p>the smoothing distribution
  1437. <span class="math notranslate nohighlight">\(p(\hmmhid_t|\hmmobs_{1:T})\)</span> (note that this conditions on future data <span class="math notranslate nohighlight">\(T&gt;t\)</span>)</p></li>
  1438. <li><p>the fixed-lag smoothing distribution
  1439. <span class="math notranslate nohighlight">\(p(\hmmhid_{t-\ell}|\hmmobs_{1:t})\)</span> (note that this
  1440. infers <span class="math notranslate nohighlight">\(\ell\)</span> steps in the past given data up to the present).</p></li>
  1441. </ul>
  1442. <p>We may also want to compute the
  1443. predictive distribution <span class="math notranslate nohighlight">\(h\)</span> steps into the future:</p>
  1444. <div class="amsmath math notranslate nohighlight" id="equation-47f5e405-e42b-41c8-9bea-867b94bfd570">
  1445. <span class="eqno">()<a class="headerlink" href="#equation-47f5e405-e42b-41c8-9bea-867b94bfd570" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1446. p(\hmmobs_{t+h}|\hmmobs_{1:t})
  1447. &amp;= \sum_{\hmmhid_{t+h}} p(\hmmobs_{t+h}|\hmmhid_{t+h}) p(\hmmhid_{t+h}|\hmmobs_{1:t})
  1448. \end{align}\]</div>
  1449. <p>where the hidden state predictive distribution is</p>
  1450. <div class="amsmath math notranslate nohighlight" id="equation-9fb3fbe3-19d8-4734-96f6-00ee598c2510">
  1451. <span class="eqno">()<a class="headerlink" href="#equation-9fb3fbe3-19d8-4734-96f6-00ee598c2510" title="Permalink to this equation">¶</a></span>\[\begin{align}
  1452. p(\hmmhid_{t+h}|\hmmobs_{1:t})
  1453. &amp;= \sum_{\hmmhid_{t:t+h-1}}
  1454. p(\hmmhid_t|\hmmobs_{1:t})
  1455. p(\hmmhid_{t+1}|\hmmhid_{t})
  1456. p(\hmmhid_{t+2}|\hmmhid_{t+1})
  1457. \cdots
  1458. p(\hmmhid_{t+h}|\hmmhid_{t+h-1})
  1459. \end{align}\]</div>
  1460. <p>See <a class="reference internal" href="#fig-dbn-inference"><span class="std std-ref">Illustration of the different kinds of inference in an SSM.
  1461. The main kinds of inference for state-space models.
  1462. The shaded region is the interval for which we have data.
  1463. The arrow represents the time step at which we want to perform inference.
  1464. t is the current time, T is the sequence length,
  1465. \ell is the lag and h is the prediction horizon.</span></a> for a summary of these distributions.</p>
  1466. <p>In addition to comuting posterior marginals,
  1467. we may want to compute the most probable hidden sequence,
  1468. i.e., the joint MAP estimate</p>
  1469. <div class="math notranslate nohighlight">
  1470. \[\arg \max_{\hmmhid_{1:T}} p(\hmmhid_{1:T}|\hmmobs_{1:T})\]</div>
  1471. <p>or sample sequences from the posterior</p>
  1472. <div class="math notranslate nohighlight">
  1473. \[\hmmhid_{1:T} \sim p(\hmmhid_{1:T}|\hmmobs_{1:T})\]</div>
  1474. <p>Algorithms for all these task are discussed in the following chapters,
  1475. since the details depend on the form of the SSM.</p>
  1476. <div class="section" id="example-inference-in-the-casino-hmm">
  1477. <h2>Example: inference in the casino HMM<a class="headerlink" href="#example-inference-in-the-casino-hmm" title="Permalink to this headline">¶</a></h2>
  1478. <p>We now illustrate filtering, smoothing and MAP decoding applied
  1479. to the casino HMM from <span class="xref std std-ref">sec:casino</span>.</p>
  1480. <div class="cell docutils container">
  1481. <div class="cell_input docutils container">
  1482. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Call inference engine</span>
  1483. <span class="n">filtered_dist</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">smoothed_dist</span><span class="p">,</span> <span class="n">loglik</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">forward_backward</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span>
  1484. <span class="n">map_path</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">viterbi</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span>
  1485. </pre></div>
  1486. </div>
  1487. </div>
  1488. <div class="cell_output docutils container">
  1489. <div class="output stderr highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>/opt/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:5154: UserWarning: Explicitly requested dtype &lt;class &#39;jax._src.numpy.lax_numpy.int64&#39;&gt; requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  1490. lax_internal._check_user_dtype_supported(dtype, &quot;astype&quot;)
  1491. </pre></div>
  1492. </div>
  1493. </div>
  1494. </div>
  1495. <div class="cell docutils container">
  1496. <div class="cell_input docutils container">
  1497. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Find the span of timesteps that the simulated systems turns to be in state 1</span>
  1498. <span class="k">def</span> <span class="nf">find_dishonest_intervals</span><span class="p">(</span><span class="n">z_hist</span><span class="p">):</span>
  1499. <span class="n">spans</span> <span class="o">=</span> <span class="p">[]</span>
  1500. <span class="n">x_init</span> <span class="o">=</span> <span class="mi">0</span>
  1501. <span class="k">for</span> <span class="n">t</span><span class="p">,</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">z_hist</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]):</span>
  1502. <span class="k">if</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
  1503. <span class="n">x_end</span> <span class="o">=</span> <span class="n">t</span>
  1504. <span class="n">spans</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">x_init</span><span class="p">,</span> <span class="n">x_end</span><span class="p">))</span>
  1505. <span class="k">elif</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  1506. <span class="n">x_init</span> <span class="o">=</span> <span class="n">t</span> <span class="o">+</span> <span class="mi">1</span>
  1507. <span class="k">return</span> <span class="n">spans</span>
  1508. </pre></div>
  1509. </div>
  1510. </div>
  1511. </div>
  1512. <div class="cell docutils container">
  1513. <div class="cell_input docutils container">
  1514. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Plot posterior</span>
  1515. <span class="k">def</span> <span class="nf">plot_inference</span><span class="p">(</span><span class="n">inference_values</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">map_estimate</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  1516. <span class="n">n_samples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">inference_values</span><span class="p">)</span>
  1517. <span class="n">xspan</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_samples</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
  1518. <span class="n">spans</span> <span class="o">=</span> <span class="n">find_dishonest_intervals</span><span class="p">(</span><span class="n">z_hist</span><span class="p">)</span>
  1519. <span class="k">if</span> <span class="n">map_estimate</span><span class="p">:</span>
  1520. <span class="n">ax</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">xspan</span><span class="p">,</span> <span class="n">inference_values</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">)</span>
  1521. <span class="k">else</span><span class="p">:</span>
  1522. <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">xspan</span><span class="p">,</span> <span class="n">inference_values</span><span class="p">[:,</span> <span class="n">state</span><span class="p">])</span>
  1523. <span class="k">for</span> <span class="n">span</span> <span class="ow">in</span> <span class="n">spans</span><span class="p">:</span>
  1524. <span class="n">ax</span><span class="o">.</span><span class="n">axvspan</span><span class="p">(</span><span class="o">*</span><span class="n">span</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">facecolor</span><span class="o">=</span><span class="s2">&quot;tab:gray&quot;</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span>
  1525. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_samples</span><span class="p">)</span>
  1526. <span class="c1"># ax.set_ylim(0, 1)</span>
  1527. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylim</span><span class="p">(</span><span class="o">-</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">1.1</span><span class="p">)</span>
  1528. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s2">&quot;Observation number&quot;</span><span class="p">)</span>
  1529. </pre></div>
  1530. </div>
  1531. </div>
  1532. </div>
  1533. <div class="cell docutils container">
  1534. <div class="cell_input docutils container">
  1535. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span> <span class="c1"># Filtering</span>
  1536. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  1537. <span class="n">plot_inference</span><span class="p">(</span><span class="n">filtered_dist</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">)</span>
  1538. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;p(loaded)&quot;</span><span class="p">)</span>
  1539. <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Filtered&quot;</span><span class="p">)</span>
  1540. </pre></div>
  1541. </div>
  1542. </div>
  1543. <div class="cell_output docutils container">
  1544. <div class="output traceback highlight-ipythontb notranslate"><div class="highlight"><pre><span></span><span class="gt">---------------------------------------------------------------------------</span>
  1545. <span class="ne">ValueError</span><span class="g g-Whitespace"> </span>Traceback (most recent call last)
  1546. <span class="o">&lt;</span><span class="n">ipython</span><span class="o">-</span><span class="nb">input</span><span class="o">-</span><span class="mi">17</span><span class="o">-</span><span class="n">a3a9c11b46bd</span><span class="o">&gt;</span> <span class="ow">in</span> <span class="o">&lt;</span><span class="n">module</span><span class="o">&gt;</span>
  1547. <span class="g g-Whitespace"> </span><span class="mi">1</span> <span class="c1"># Filtering</span>
  1548. <span class="g g-Whitespace"> </span><span class="mi">2</span> <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  1549. <span class="ne">----&gt; </span><span class="mi">3</span> <span class="n">plot_inference</span><span class="p">(</span><span class="n">filtered_dist</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">)</span>
  1550. <span class="g g-Whitespace"> </span><span class="mi">4</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;p(loaded)&quot;</span><span class="p">)</span>
  1551. <span class="g g-Whitespace"> </span><span class="mi">5</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Filtered&quot;</span><span class="p">)</span>
  1552. <span class="nn">&lt;ipython-input-16-9c8ddad3f57c&gt;</span> in <span class="ni">plot_inference</span><span class="nt">(inference_values, z_hist, ax, state, map_estimate)</span>
  1553. <span class="g g-Whitespace"> </span><span class="mi">3</span> <span class="n">n_samples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">inference_values</span><span class="p">)</span>
  1554. <span class="g g-Whitespace"> </span><span class="mi">4</span> <span class="n">xspan</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_samples</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
  1555. <span class="ne">----&gt; </span><span class="mi">5</span> <span class="n">spans</span> <span class="o">=</span> <span class="n">find_dishonest_intervals</span><span class="p">(</span><span class="n">z_hist</span><span class="p">)</span>
  1556. <span class="g g-Whitespace"> </span><span class="mi">6</span> <span class="k">if</span> <span class="n">map_estimate</span><span class="p">:</span>
  1557. <span class="g g-Whitespace"> </span><span class="mi">7</span> <span class="n">ax</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">xspan</span><span class="p">,</span> <span class="n">inference_values</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">)</span>
  1558. <span class="nn">&lt;ipython-input-15-4606c615e17a&gt;</span> in <span class="ni">find_dishonest_intervals</span><span class="nt">(z_hist)</span>
  1559. <span class="g g-Whitespace"> </span><span class="mi">4</span> <span class="n">x_init</span> <span class="o">=</span> <span class="mi">0</span>
  1560. <span class="g g-Whitespace"> </span><span class="mi">5</span> <span class="k">for</span> <span class="n">t</span><span class="p">,</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">z_hist</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]):</span>
  1561. <span class="ne">----&gt; </span><span class="mi">6</span> <span class="k">if</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
  1562. <span class="g g-Whitespace"> </span><span class="mi">7</span> <span class="n">x_end</span> <span class="o">=</span> <span class="n">t</span>
  1563. <span class="g g-Whitespace"> </span><span class="mi">8</span> <span class="n">spans</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">x_init</span><span class="p">,</span> <span class="n">x_end</span><span class="p">))</span>
  1564. <span class="nn">/opt/anaconda3/lib/python3.8/functools.py</span> in <span class="ni">_method</span><span class="nt">(cls_or_self, *args, **keywords)</span>
  1565. <span class="g g-Whitespace"> </span><span class="mi">397</span> <span class="k">def</span> <span class="nf">_method</span><span class="p">(</span><span class="n">cls_or_self</span><span class="p">,</span> <span class="o">/</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">keywords</span><span class="p">):</span>
  1566. <span class="g g-Whitespace"> </span><span class="mi">398</span> <span class="n">keywords</span> <span class="o">=</span> <span class="p">{</span><span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">keywords</span><span class="p">,</span> <span class="o">**</span><span class="n">keywords</span><span class="p">}</span>
  1567. <span class="ne">--&gt; </span><span class="mi">399</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">func</span><span class="p">(</span><span class="n">cls_or_self</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">keywords</span><span class="p">)</span>
  1568. <span class="g g-Whitespace"> </span><span class="mi">400</span> <span class="n">_method</span><span class="o">.</span><span class="n">__isabstractmethod__</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__isabstractmethod__</span>
  1569. <span class="g g-Whitespace"> </span><span class="mi">401</span> <span class="n">_method</span><span class="o">.</span><span class="n">_partialmethod</span> <span class="o">=</span> <span class="bp">self</span>
  1570. <span class="nn">/opt/anaconda3/lib/python3.8/site-packages/jax/_src/device_array.py</span> in <span class="ni">_forward_method</span><span class="nt">(attrname, self, fun, *args)</span>
  1571. <span class="g g-Whitespace"> </span><span class="mi">39</span>
  1572. <span class="g g-Whitespace"> </span><span class="mi">40</span> <span class="k">def</span> <span class="nf">_forward_method</span><span class="p">(</span><span class="n">attrname</span><span class="p">,</span> <span class="bp">self</span><span class="p">,</span> <span class="n">fun</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">):</span>
  1573. <span class="ne">---&gt; </span><span class="mi">41</span> <span class="k">return</span> <span class="n">fun</span><span class="p">(</span><span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attrname</span><span class="p">),</span> <span class="o">*</span><span class="n">args</span><span class="p">)</span>
  1574. <span class="g g-Whitespace"> </span><span class="mi">42</span> <span class="n">_forward_to_value</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">_forward_method</span><span class="p">,</span> <span class="s2">&quot;_value&quot;</span><span class="p">)</span>
  1575. <span class="g g-Whitespace"> </span><span class="mi">43</span>
  1576. <span class="ne">ValueError</span>: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
  1577. </pre></div>
  1578. </div>
  1579. <img alt="../_images/ssm_old_30_1.png" src="../_images/ssm_old_30_1.png" />
  1580. </div>
  1581. </div>
  1582. <div class="cell docutils container">
  1583. <div class="cell_input docutils container">
  1584. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Smoothing</span>
  1585. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  1586. <span class="n">plot_inference</span><span class="p">(</span><span class="n">smoothed_dist</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">)</span>
  1587. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;p(loaded)&quot;</span><span class="p">)</span>
  1588. <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Smoothed&quot;</span><span class="p">)</span>
  1589. </pre></div>
  1590. </div>
  1591. </div>
  1592. <div class="cell_output docutils container">
  1593. <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Text(0.5, 1.0, &#39;Smoothed&#39;)
  1594. </pre></div>
  1595. </div>
<img alt="Plot titled Smoothed showing the smoothed posterior p(loaded)" src="../_images/ssm_old_31_1.png" />
  1597. </div>
  1598. </div>
  1599. <div class="cell docutils container">
  1600. <div class="cell_input docutils container">
  1601. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># MAP estimation</span>
  1602. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  1603. <span class="n">plot_inference</span><span class="p">(</span><span class="n">map_path</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">map_estimate</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  1604. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;MAP state&quot;</span><span class="p">)</span>
  1605. <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Viterbi&quot;</span><span class="p">)</span>
  1606. </pre></div>
  1607. </div>
  1608. </div>
  1609. </div>
  1610. <div class="cell docutils container">
  1611. <div class="cell_input docutils container">
  1612. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># TODO: posterior samples</span>
  1613. </pre></div>
  1614. </div>
  1615. </div>
  1616. </div>
  1617. </div>
  1618. <div class="section" id="example-inference-in-the-tracking-ssm">
  1619. <h2>Example: inference in the tracking SSM<a class="headerlink" href="#example-inference-in-the-tracking-ssm" title="Permalink to this headline">¶</a></h2>
<p>We now illustrate filtering, smoothing, and MAP decoding applied
to the 2d tracking SSM (a linear dynamical system) from <a class="reference internal" href="#sec-tracking-lds"><span class="std std-ref">Example: tracking a 2d point</span></a>.</p>
  1622. </div>
  1623. </div>
  1624. <script type="text/x-thebe-config">
  1625. {
  1626. requestKernel: true,
  1627. binderOptions: {
  1628. repo: "binder-examples/jupyter-stacks-datascience",
  1629. ref: "master",
  1630. },
  1631. codeMirrorConfig: {
  1632. theme: "abcdef",
  1633. mode: "python"
  1634. },
  1635. kernelOptions: {
  1636. kernelName: "python3",
  1637. path: "./old"
  1638. },
  1639. predefinedOutput: true
  1640. }
  1641. </script>
  1642. <script>kernelName = 'python3'</script>
  1643. </div>
  1644. <!-- Previous / next buttons -->
  1645. <div class='prev-next-area'>
  1646. </div>
  1647. </div>
  1648. </div>
  1649. <footer class="footer">
  1650. <p>
  1651. By Kevin Murphy, Scott Linderman, et al.<br/>
  1652. &copy; Copyright 2021.<br/>
  1653. </p>
  1654. </footer>
  1655. </main>
  1656. </div>
  1657. </div>
  1658. <script src="../_static/js/index.be7d3bbb2ef33a8344ce.js"></script>
  1659. </body>
  1660. </html>