inference.html 64 KB


  1. <!DOCTYPE html>
  2. <html lang="en">
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>Inferential goals &#8212; State Space Models: A Modern Approach</title>
  7. <link href="../../_static/css/theme.css" rel="stylesheet">
  8. <link href="../../_static/css/index.ff1ffe594081f20da1ef19478df9384b.css" rel="stylesheet">
  9. <link rel="stylesheet"
  10. href="../../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  11. <link rel="preload" as="font" type="font/woff2" crossorigin
  12. href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  13. <link rel="preload" as="font" type="font/woff2" crossorigin
  14. href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
  15. <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" />
  16. <link rel="stylesheet" type="text/css" href="../../_static/sphinx-book-theme.css?digest=c3fdc42140077d1ad13ad2f1588a4309" />
  17. <link rel="stylesheet" type="text/css" href="../../_static/togglebutton.css" />
  18. <link rel="stylesheet" type="text/css" href="../../_static/copybutton.css" />
  19. <link rel="stylesheet" type="text/css" href="../../_static/mystnb.css" />
  20. <link rel="stylesheet" type="text/css" href="../../_static/sphinx-thebe.css" />
  21. <link rel="stylesheet" type="text/css" href="../../_static/panels-main.c949a650a448cc0ae9fd3441c0e17fb0.css" />
  22. <link rel="stylesheet" type="text/css" href="../../_static/panels-variables.06eb56fa6e07937060861dad626602ad.css" />
  23. <link rel="preload" as="script" href="../../_static/js/index.be7d3bbb2ef33a8344ce.js">
  24. <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
  25. <script src="../../_static/jquery.js"></script>
  26. <script src="../../_static/underscore.js"></script>
  27. <script src="../../_static/doctools.js"></script>
  28. <script src="../../_static/clipboard.min.js"></script>
  29. <script src="../../_static/copybutton.js"></script>
  30. <script>let toggleHintShow = 'Click to show';</script>
  31. <script>let toggleHintHide = 'Click to hide';</script>
  32. <script>let toggleOpenOnPrint = 'true';</script>
  33. <script src="../../_static/togglebutton.js"></script>
  34. <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
  35. <script src="../../_static/sphinx-book-theme.d59cb220de22ca1c485ebbdc042f0030.js"></script>
  36. <script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"
  37. const thebe_selector = ".thebe,.cell"
  38. const thebe_selector_input = "pre"
  39. const thebe_selector_output = ".output, .cell_output"
  40. </script>
  41. <script async="async" src="../../_static/sphinx-thebe.js"></script>
  42. <script>window.MathJax = {"TeX": {"Macros": {"N": "\\mathbb{N}", "floor": ["\\lfloor#1\\rfloor", 1], "bmat": ["\\left[\\begin{array}"], "emat": ["\\end{array}\\right]"]}}, "options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
  43. <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
  44. <link rel="index" title="Index" href="../../genindex.html" />
  45. <link rel="search" title="Search" href="../../search.html" />
  46. <link rel="next" title="Hidden Markov Models" href="../hmm/hmm_index.html" />
  47. <link rel="prev" title="Nonlinear Gaussian SSMs" href="nlds.html" />
  48. <meta name="viewport" content="width=device-width, initial-scale=1" />
  49. <meta name="docsearch:language" content="None">
  50. <!-- Google Analytics -->
  51. </head>
  52. <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80">
  53. <div class="container-fluid" id="banner"></div>
  54. <div class="container-xl">
  55. <div class="row">
  56. <div class="col-12 col-md-3 bd-sidebar site-navigation show" id="site-navigation">
  57. <div class="navbar-brand-box">
  58. <a class="navbar-brand text-wrap" href="../../index.html">
  59. <h1 class="site-logo" id="site-title">State Space Models: A Modern Approach</h1>
  60. </a>
  61. </div><form class="bd-search d-flex align-items-center" action="../../search.html" method="get">
  62. <i class="icon fas fa-search"></i>
  63. <input type="search" class="form-control" name="q" id="search-input" placeholder="Search this book..." aria-label="Search this book..." autocomplete="off" >
  64. </form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
  65. <div class="bd-toc-item active">
  66. <ul class="nav bd-sidenav">
  67. <li class="toctree-l1">
  68. <a class="reference internal" href="../../root.html">
  69. State Space Models: A Modern Approach
  70. </a>
  71. </li>
  72. </ul>
  73. <ul class="current nav bd-sidenav">
  74. <li class="toctree-l1">
  75. <a class="reference internal" href="../scratch.html">
  76. Scratchpad
  77. </a>
  78. </li>
  79. <li class="toctree-l1 current active has-children">
  80. <a class="reference internal" href="ssm_index.html">
  81. State Space Models
  82. </a>
  83. <input checked="" class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  84. <label for="toctree-checkbox-1">
  85. <i class="fas fa-chevron-down">
  86. </i>
  87. </label>
  88. <ul class="current">
  89. <li class="toctree-l2">
  90. <a class="reference internal" href="ssm_intro.html">
  91. What are State Space Models?
  92. </a>
  93. </li>
  94. <li class="toctree-l2">
  95. <a class="reference internal" href="hmm.html">
  96. Hidden Markov Models
  97. </a>
  98. </li>
  99. <li class="toctree-l2">
  100. <a class="reference internal" href="lds.html">
  101. Linear Gaussian SSMs
  102. </a>
  103. </li>
  104. <li class="toctree-l2">
  105. <a class="reference internal" href="nlds.html">
  106. Nonlinear Gaussian SSMs
  107. </a>
  108. </li>
  109. <li class="toctree-l2 current active">
  110. <a class="current reference internal" href="#">
  111. Inferential goals
  112. </a>
  113. </li>
  114. </ul>
  115. </li>
  116. <li class="toctree-l1 has-children">
  117. <a class="reference internal" href="../hmm/hmm_index.html">
  118. Hidden Markov Models
  119. </a>
  120. <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  121. <label for="toctree-checkbox-2">
  122. <i class="fas fa-chevron-down">
  123. </i>
  124. </label>
  125. <ul>
  126. <li class="toctree-l2">
  127. <a class="reference internal" href="../hmm/hmm_filter.html">
  128. HMM filtering (forwards algorithm)
  129. </a>
  130. </li>
  131. <li class="toctree-l2">
  132. <a class="reference internal" href="../hmm/hmm_smoother.html">
  133. HMM smoothing (forwards-backwards algorithm)
  134. </a>
  135. </li>
  136. <li class="toctree-l2">
  137. <a class="reference internal" href="../hmm/hmm_viterbi.html">
  138. Viterbi algorithm
  139. </a>
  140. </li>
  141. <li class="toctree-l2">
  142. <a class="reference internal" href="../hmm/hmm_parallel.html">
  143. Parallel HMM smoothing
  144. </a>
  145. </li>
  146. <li class="toctree-l2">
  147. <a class="reference internal" href="../hmm/hmm_sampling.html">
  148. Forwards-filtering backwards-sampling algorithm
  149. </a>
  150. </li>
  151. </ul>
  152. </li>
  153. <li class="toctree-l1 has-children">
  154. <a class="reference internal" href="../lgssm/lgssm_index.html">
  155. Linear-Gaussian SSMs
  156. </a>
  157. <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  158. <label for="toctree-checkbox-3">
  159. <i class="fas fa-chevron-down">
  160. </i>
  161. </label>
  162. <ul>
  163. <li class="toctree-l2">
  164. <a class="reference internal" href="../lgssm/kalman_filter.html">
  165. Kalman filtering
  166. </a>
  167. </li>
  168. <li class="toctree-l2">
  169. <a class="reference internal" href="../lgssm/kalman_smoother.html">
  170. Kalman (RTS) smoother
  171. </a>
  172. </li>
  173. <li class="toctree-l2">
  174. <a class="reference internal" href="../lgssm/kalman_parallel.html">
  175. Parallel Kalman Smoother
  176. </a>
  177. </li>
  178. <li class="toctree-l2">
  179. <a class="reference internal" href="../lgssm/kalman_sampling.html">
  180. Forwards-filtering backwards sampling
  181. </a>
  182. </li>
  183. </ul>
  184. </li>
  185. <li class="toctree-l1 has-children">
  186. <a class="reference internal" href="../extended/extended_index.html">
  187. Extended (linearized) methods
  188. </a>
  189. <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  190. <label for="toctree-checkbox-4">
  191. <i class="fas fa-chevron-down">
  192. </i>
  193. </label>
  194. <ul>
  195. <li class="toctree-l2">
  196. <a class="reference internal" href="../extended/extended_filter.html">
  197. Extended Kalman filtering
  198. </a>
  199. </li>
  200. <li class="toctree-l2">
  201. <a class="reference internal" href="../extended/extended_smoother.html">
  202. Extended Kalman smoother
  203. </a>
  204. </li>
  205. <li class="toctree-l2">
  206. <a class="reference internal" href="../extended/extended_parallel.html">
  207. Parallel extended Kalman smoothing
  208. </a>
  209. </li>
  210. </ul>
  211. </li>
  212. <li class="toctree-l1 has-children">
  213. <a class="reference internal" href="../unscented/unscented_index.html">
  214. Unscented methods
  215. </a>
  216. <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  217. <label for="toctree-checkbox-5">
  218. <i class="fas fa-chevron-down">
  219. </i>
  220. </label>
  221. <ul>
  222. <li class="toctree-l2">
  223. <a class="reference internal" href="../unscented/unscented_filter.html">
  224. Unscented filtering
  225. </a>
  226. </li>
  227. <li class="toctree-l2">
  228. <a class="reference internal" href="../unscented/unscented_smoother.html">
  229. Unscented smoothing
  230. </a>
  231. </li>
  232. </ul>
  233. </li>
  234. <li class="toctree-l1">
  235. <a class="reference internal" href="../quadrature/quadrature_index.html">
  236. Quadrature and cubature methods
  237. </a>
  238. </li>
  239. <li class="toctree-l1">
  240. <a class="reference internal" href="../postlin/postlin_index.html">
  241. Posterior linearization
  242. </a>
  243. </li>
  244. <li class="toctree-l1">
  245. <a class="reference internal" href="../adf/adf_index.html">
  246. Assumed Density Filtering
  247. </a>
  248. </li>
  249. <li class="toctree-l1">
  250. <a class="reference internal" href="../vi/vi_index.html">
  251. Variational inference
  252. </a>
  253. </li>
  254. <li class="toctree-l1">
  255. <a class="reference internal" href="../pf/pf_index.html">
  256. Particle filtering
  257. </a>
  258. </li>
  259. <li class="toctree-l1">
  260. <a class="reference internal" href="../smc/smc_index.html">
  261. Sequential Monte Carlo
  262. </a>
  263. </li>
  264. <li class="toctree-l1">
  265. <a class="reference internal" href="../learning/learning_index.html">
  266. Offline parameter estimation (learning)
  267. </a>
  268. </li>
  269. <li class="toctree-l1">
  270. <a class="reference internal" href="../tracking/tracking_index.html">
  271. Multi-target tracking
  272. </a>
  273. </li>
  274. <li class="toctree-l1">
  275. <a class="reference internal" href="../ensemble/ensemble_index.html">
  276. Data assimilation using Ensemble Kalman filter
  277. </a>
  278. </li>
  279. <li class="toctree-l1">
  280. <a class="reference internal" href="../bnp/bnp_index.html">
  281. Bayesian non-parametric SSMs
  282. </a>
  283. </li>
  284. <li class="toctree-l1">
  285. <a class="reference internal" href="../changepoint/changepoint_index.html">
  286. Changepoint detection
  287. </a>
  288. </li>
  289. <li class="toctree-l1">
  290. <a class="reference internal" href="../timeseries/timeseries_index.html">
  291. Timeseries forecasting
  292. </a>
  293. </li>
  294. <li class="toctree-l1">
  295. <a class="reference internal" href="../gp/gp_index.html">
  296. Markovian Gaussian processes
  297. </a>
  298. </li>
  299. <li class="toctree-l1">
  300. <a class="reference internal" href="../ode/ode_index.html">
  301. Differential equations and SSMs
  302. </a>
  303. </li>
  304. <li class="toctree-l1">
  305. <a class="reference internal" href="../control/control_index.html">
  306. Optimal control
  307. </a>
  308. </li>
  309. <li class="toctree-l1">
  310. <a class="reference internal" href="../../bib.html">
  311. Bibliography
  312. </a>
  313. </li>
  314. </ul>
  315. </div>
  316. </nav> <!-- To handle the deprecated key -->
  317. <div class="navbar_extra_footer">
  318. Powered by <a href="https://jupyterbook.org">Jupyter Book</a>
  319. </div>
  320. </div>
  321. <main class="col py-md-3 pl-md-4 bd-content overflow-auto" role="main">
  322. <div class="topbar container-xl fixed-top">
  323. <div class="topbar-contents row">
  324. <div class="col-12 col-md-3 bd-topbar-whitespace site-navigation show"></div>
  325. <div class="col pl-md-4 topbar-main">
  326. <button id="navbar-toggler" class="navbar-toggler ml-0" type="button" data-toggle="collapse"
  327. data-placement="bottom" data-target=".site-navigation" aria-controls="site-navigation"
  328. aria-expanded="true" aria-label="Toggle navigation"
  329. title="Toggle navigation">
  330. <i class="fas fa-bars"></i>
  331. <i class="fas fa-arrow-left"></i>
  332. <i class="fas fa-arrow-up"></i>
  333. </button>
  334. <div class="dropdown-buttons-trigger">
  335. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn" aria-label="Download this page"><i
  336. class="fas fa-download"></i></button>
  337. <div class="dropdown-buttons">
  338. <!-- ipynb file if we had a myst markdown file -->
  339. <!-- Download raw file -->
  340. <a class="dropdown-buttons" href="../../_sources/chapters/ssm/inference.ipynb"><button type="button"
  341. class="btn btn-secondary topbarbtn" title="Download source file" data-toggle="tooltip"
  342. data-placement="left">.ipynb</button></a>
  343. <!-- Download PDF via print -->
  344. <button type="button" id="download-print" class="btn btn-secondary topbarbtn" title="Print to PDF"
  345. onclick="printPdf(this)" data-toggle="tooltip" data-placement="left">.pdf</button>
  346. </div>
  347. </div>
  348. <!-- Source interaction buttons -->
  349. <div class="dropdown-buttons-trigger">
  350. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  351. aria-label="Connect with source repository"><i class="fab fa-github"></i></button>
  352. <div class="dropdown-buttons sourcebuttons">
  353. <a class="repository-button"
  354. href="https://github.com/probml/ssm-book"><button type="button" class="btn btn-secondary topbarbtn"
  355. data-toggle="tooltip" data-placement="left" title="Source repository"><i
  356. class="fab fa-github"></i>repository</button></a>
  357. <a class="issues-button"
  358. href="https://github.com/probml/ssm-book/issues/new?title=Issue%20on%20page%20%2Fchapters/ssm/inference.html&body=Your%20issue%20content%20here."><button
  359. type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip" data-placement="left"
  360. title="Open an issue"><i class="fas fa-lightbulb"></i>open issue</button></a>
  361. </div>
  362. </div>
  363. <!-- Full screen (wrap in <a> to have style consistency -->
  364. <a class="full-screen-button"><button type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip"
  365. data-placement="bottom" onclick="toggleFullScreen()" aria-label="Fullscreen mode"
  366. title="Fullscreen mode"><i
  367. class="fas fa-expand"></i></button></a>
  368. <!-- Launch buttons -->
  369. <div class="dropdown-buttons-trigger">
  370. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  371. aria-label="Launch interactive content"><i class="fas fa-rocket"></i></button>
  372. <div class="dropdown-buttons">
  373. <a class="binder-button" href="https://mybinder.org/v2/gh/probml/ssm-book/main?urlpath=tree/chapters/ssm/inference.ipynb"><button type="button"
  374. class="btn btn-secondary topbarbtn" title="Launch Binder" data-toggle="tooltip"
  375. data-placement="left"><img class="binder-button-logo"
  376. src="../../_static/images/logo_binder.svg"
  377. alt="Interact on binder">Binder</button></a>
  378. <a class="colab-button" href="https://colab.research.google.com/github/probml/ssm-book/blob/main/chapters/ssm/inference.ipynb"><button type="button" class="btn btn-secondary topbarbtn"
  379. title="Launch Colab" data-toggle="tooltip" data-placement="left"><img class="colab-button-logo"
  380. src="../../_static/images/logo_colab.png"
  381. alt="Interact on Colab">Colab</button></a>
  382. </div>
  383. </div>
  384. </div>
  385. <!-- Table of contents -->
  386. <div class="d-none d-md-block col-md-2 bd-toc show noprint">
  387. <div class="tocsection onthispage pt-5 pb-3">
  388. <i class="fas fa-list"></i> Contents
  389. </div>
  390. <nav id="bd-toc-nav" aria-label="Page">
  391. <ul class="visible nav section-nav flex-column">
  392. <li class="toc-h2 nav-item toc-entry">
  393. <a class="reference internal nav-link" href="#example-inference-in-the-casino-hmm">
  394. Example: inference in the casino HMM
  395. </a>
  396. </li>
  397. <li class="toc-h2 nav-item toc-entry">
  398. <a class="reference internal nav-link" href="#example-inference-in-the-tracking-lg-ssm">
  399. Example: inference in the tracking LG-SSM
  400. </a>
  401. </li>
  402. </ul>
  403. </nav>
  404. </div>
  405. </div>
  406. </div>
  407. <div id="main-content" class="row">
  408. <div class="col-12 col-md-9 pl-md-3 pr-md-0">
  409. <!-- Table of contents that is only displayed when printing the page -->
  410. <div id="jb-print-docs-body" class="onlyprint">
  411. <h1>Inferential goals</h1>
  412. <!-- Table of contents -->
  413. <div id="print-main-content">
  414. <div id="jb-print-toc">
  415. <div>
  416. <h2> Contents </h2>
  417. </div>
  418. <nav aria-label="Page">
  419. <ul class="visible nav section-nav flex-column">
  420. <li class="toc-h2 nav-item toc-entry">
  421. <a class="reference internal nav-link" href="#example-inference-in-the-casino-hmm">
  422. Example: inference in the casino HMM
  423. </a>
  424. </li>
  425. <li class="toc-h2 nav-item toc-entry">
  426. <a class="reference internal nav-link" href="#example-inference-in-the-tracking-lg-ssm">
  427. Example: inference in the tracking LG-SSM
  428. </a>
  429. </li>
  430. </ul>
  431. </nav>
  432. </div>
  433. </div>
  434. </div>
  435. <div>
  436. <div class="cell docutils container">
  437. <div class="cell_input docutils container">
  438. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># meta-data does not work yet in VScode</span>
  439. <span class="c1"># https://github.com/microsoft/vscode-jupyter/issues/1121</span>
  440. <span class="p">{</span>
  441. <span class="s2">&quot;tags&quot;</span><span class="p">:</span> <span class="p">[</span>
  442. <span class="s2">&quot;hide-cell&quot;</span>
  443. <span class="p">]</span>
  444. <span class="p">}</span>
  445. <span class="c1">### Install necessary libraries</span>
  446. <span class="k">try</span><span class="p">:</span>
  447. <span class="kn">import</span> <span class="nn">jax</span>
  448. <span class="k">except</span><span class="p">:</span>
  449. <span class="c1"># For cuda version, see https://github.com/google/jax#installation</span>
  450. <span class="o">%</span><span class="k">pip</span> install --upgrade &quot;jax[cpu]&quot;
  451. <span class="kn">import</span> <span class="nn">jax</span>
  452. <span class="k">try</span><span class="p">:</span>
  453. <span class="kn">import</span> <span class="nn">distrax</span>
  454. <span class="k">except</span><span class="p">:</span>
  455. <span class="o">%</span><span class="k">pip</span> install --upgrade distrax
  456. <span class="kn">import</span> <span class="nn">distrax</span>
  457. <span class="k">try</span><span class="p">:</span>
  458. <span class="kn">import</span> <span class="nn">jsl</span>
  459. <span class="k">except</span><span class="p">:</span>
  460. <span class="o">%</span><span class="k">pip</span> install git+https://github.com/probml/jsl
  461. <span class="kn">import</span> <span class="nn">jsl</span>
  462. <span class="k">try</span><span class="p">:</span>
  463. <span class="kn">import</span> <span class="nn">rich</span>
  464. <span class="k">except</span><span class="p">:</span>
  465. <span class="o">%</span><span class="k">pip</span> install rich
  466. <span class="kn">import</span> <span class="nn">rich</span>
  467. </pre></div>
  468. </div>
  469. </div>
  470. </div>
  471. <div class="cell docutils container">
  472. <div class="cell_input docutils container">
  473. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="p">{</span>
  474. <span class="s2">&quot;tags&quot;</span><span class="p">:</span> <span class="p">[</span>
  475. <span class="s2">&quot;hide-cell&quot;</span>
  476. <span class="p">]</span>
  477. <span class="p">}</span>
  478. <span class="c1">### Import standard libraries</span>
  479. <span class="kn">import</span> <span class="nn">abc</span>
  480. <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
  481. <span class="kn">import</span> <span class="nn">functools</span>
  482. <span class="kn">import</span> <span class="nn">itertools</span>
  483. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
  484. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
  485. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  486. <span class="kn">import</span> <span class="nn">jax</span>
  487. <span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
  488. <span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">lax</span><span class="p">,</span> <span class="n">vmap</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">grad</span>
  489. <span class="kn">from</span> <span class="nn">jax.scipy.special</span> <span class="kn">import</span> <span class="n">logit</span>
  490. <span class="kn">from</span> <span class="nn">jax.nn</span> <span class="kn">import</span> <span class="n">softmax</span>
  491. <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
  492. <span class="kn">from</span> <span class="nn">jax.random</span> <span class="kn">import</span> <span class="n">PRNGKey</span><span class="p">,</span> <span class="n">split</span>
  493. <span class="kn">import</span> <span class="nn">inspect</span>
  494. <span class="kn">import</span> <span class="nn">inspect</span> <span class="k">as</span> <span class="nn">py_inspect</span>
  495. <span class="kn">import</span> <span class="nn">rich</span>
  496. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="n">inspect</span> <span class="k">as</span> <span class="n">r_inspect</span>
  497. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="nb">print</span> <span class="k">as</span> <span class="n">r_print</span>
  498. <span class="k">def</span> <span class="nf">print_source</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span>
  499. <span class="n">r_print</span><span class="p">(</span><span class="n">py_inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">fname</span><span class="p">))</span>
  500. </pre></div>
  501. </div>
  502. </div>
  503. </div>
  504. <div class="math notranslate nohighlight">
  505. \[ \begin{align}\begin{aligned}\newcommand\floor[1]{\lfloor#1\rfloor}\\\newcommand{\real}{\mathbb{R}}\\% Numbers
  506. \newcommand{\vzero}{\boldsymbol{0}}
  507. \newcommand{\vone}{\boldsymbol{1}}\\% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/
  508. \newcommand{\valpha}{\boldsymbol{\alpha}}
  509. \newcommand{\vbeta}{\boldsymbol{\beta}}
  510. \newcommand{\vchi}{\boldsymbol{\chi}}
  511. \newcommand{\vdelta}{\boldsymbol{\delta}}
  512. \newcommand{\vDelta}{\boldsymbol{\Delta}}
  513. \newcommand{\vepsilon}{\boldsymbol{\epsilon}}
  514. \newcommand{\vzeta}{\boldsymbol{\zeta}}
  515. \newcommand{\vXi}{\boldsymbol{\Xi}}
  516. \newcommand{\vell}{\boldsymbol{\ell}}
  517. \newcommand{\veta}{\boldsymbol{\eta}}
  518. %\newcommand{\vEta}{\boldsymbol{\Eta}}
  519. \newcommand{\vgamma}{\boldsymbol{\gamma}}
  520. \newcommand{\vGamma}{\boldsymbol{\Gamma}}
  521. \newcommand{\vmu}{\boldsymbol{\mu}}
  522. \newcommand{\vmut}{\boldsymbol{\tilde{\mu}}}
  523. \newcommand{\vnu}{\boldsymbol{\nu}}
  524. \newcommand{\vkappa}{\boldsymbol{\kappa}}
  525. \newcommand{\vlambda}{\boldsymbol{\lambda}}
  526. \newcommand{\vLambda}{\boldsymbol{\Lambda}}
  527. \newcommand{\vLambdaBar}{\overline{\vLambda}}
  528. %\newcommand{\vnu}{\boldsymbol{\nu}}
  529. \newcommand{\vomega}{\boldsymbol{\omega}}
  530. \newcommand{\vOmega}{\boldsymbol{\Omega}}
  531. \newcommand{\vphi}{\boldsymbol{\phi}}
  532. \newcommand{\vvarphi}{\boldsymbol{\varphi}}
  533. \newcommand{\vPhi}{\boldsymbol{\Phi}}
  534. \newcommand{\vpi}{\boldsymbol{\pi}}
  535. \newcommand{\vPi}{\boldsymbol{\Pi}}
  536. \newcommand{\vpsi}{\boldsymbol{\psi}}
  537. \newcommand{\vPsi}{\boldsymbol{\Psi}}
  538. \newcommand{\vrho}{\boldsymbol{\rho}}
  539. \newcommand{\vtheta}{\boldsymbol{\theta}}
  540. \newcommand{\vthetat}{\boldsymbol{\tilde{\theta}}}
  541. \newcommand{\vTheta}{\boldsymbol{\Theta}}
  542. \newcommand{\vsigma}{\boldsymbol{\sigma}}
  543. \newcommand{\vSigma}{\boldsymbol{\Sigma}}
  544. \newcommand{\vSigmat}{\boldsymbol{\tilde{\Sigma}}}
  545. \newcommand{\vsigmoid}{\vsigma}
  546. \newcommand{\vtau}{\boldsymbol{\tau}}
  547. \newcommand{\vxi}{\boldsymbol{\xi}}\\
  548. % Lower Roman (Vectors)
  549. \newcommand{\va}{\mathbf{a}}
  550. \newcommand{\vb}{\mathbf{b}}
  551. \newcommand{\vBt}{\mathbf{\tilde{B}}}
  552. \newcommand{\vc}{\mathbf{c}}
  553. \newcommand{\vct}{\mathbf{\tilde{c}}}
  554. \newcommand{\vd}{\mathbf{d}}
  555. \newcommand{\ve}{\mathbf{e}}
  556. \newcommand{\vf}{\mathbf{f}}
  557. \newcommand{\vg}{\mathbf{g}}
  558. \newcommand{\vh}{\mathbf{h}}
  559. %\newcommand{\myvh}{\mathbf{h}}
  560. \newcommand{\vi}{\mathbf{i}}
  561. \newcommand{\vj}{\mathbf{j}}
  562. \newcommand{\vk}{\mathbf{k}}
  563. \newcommand{\vl}{\mathbf{l}}
  564. \newcommand{\vm}{\mathbf{m}}
  565. \newcommand{\vn}{\mathbf{n}}
  566. \newcommand{\vo}{\mathbf{o}}
  567. \newcommand{\vp}{\mathbf{p}}
  568. \newcommand{\vq}{\mathbf{q}}
  569. \newcommand{\vr}{\mathbf{r}}
  570. \newcommand{\vs}{\mathbf{s}}
  571. \newcommand{\vt}{\mathbf{t}}
  572. \newcommand{\vu}{\mathbf{u}}
  573. \newcommand{\vv}{\mathbf{v}}
  574. \newcommand{\vw}{\mathbf{w}}
  575. \newcommand{\vws}{\vw_s}
  576. \newcommand{\vwt}{\mathbf{\tilde{w}}}
  577. \newcommand{\vWt}{\mathbf{\tilde{W}}}
  578. \newcommand{\vwh}{\hat{\vw}}
  579. \newcommand{\vx}{\mathbf{x}}
  580. %\newcommand{\vx}{\mathbf{x}}
  581. \newcommand{\vxt}{\mathbf{\tilde{x}}}
  582. \newcommand{\vy}{\mathbf{y}}
  583. \newcommand{\vyt}{\mathbf{\tilde{y}}}
  584. \newcommand{\vz}{\mathbf{z}}
  585. %\newcommand{\vzt}{\mathbf{\tilde{z}}}\\
  586. % Upper Roman (Matrices)
  587. \newcommand{\vA}{\mathbf{A}}
  588. \newcommand{\vB}{\mathbf{B}}
  589. \newcommand{\vC}{\mathbf{C}}
  590. \newcommand{\vD}{\mathbf{D}}
  591. \newcommand{\vE}{\mathbf{E}}
  592. \newcommand{\vF}{\mathbf{F}}
  593. \newcommand{\vG}{\mathbf{G}}
  594. \newcommand{\vH}{\mathbf{H}}
  595. \newcommand{\vI}{\mathbf{I}}
  596. \newcommand{\vJ}{\mathbf{J}}
  597. \newcommand{\vK}{\mathbf{K}}
  598. \newcommand{\vL}{\mathbf{L}}
  599. \newcommand{\vM}{\mathbf{M}}
  600. \newcommand{\vMt}{\mathbf{\tilde{M}}}
  601. \newcommand{\vN}{\mathbf{N}}
  602. \newcommand{\vO}{\mathbf{O}}
  603. \newcommand{\vP}{\mathbf{P}}
  604. \newcommand{\vQ}{\mathbf{Q}}
  605. \newcommand{\vR}{\mathbf{R}}
  606. \newcommand{\vS}{\mathbf{S}}
  607. \newcommand{\vT}{\mathbf{T}}
  608. \newcommand{\vU}{\mathbf{U}}
  609. \newcommand{\vV}{\mathbf{V}}
  610. \newcommand{\vW}{\mathbf{W}}
  611. \newcommand{\vX}{\mathbf{X}}
  612. %\newcommand{\vXs}{\vX_{\vs}}
  613. \newcommand{\vXs}{\vX_{s}}
  614. \newcommand{\vXt}{\mathbf{\tilde{X}}}
  615. \newcommand{\vY}{\mathbf{Y}}
  616. \newcommand{\vZ}{\mathbf{Z}}
  617. \newcommand{\vZt}{\mathbf{\tilde{Z}}}
  618. \newcommand{\vzt}{\mathbf{\tilde{z}}}\\
  619. %%%%
  620. \newcommand{\hidden}{\vz}
  621. \newcommand{\hid}{\hidden}
  622. \newcommand{\observed}{\vy}
  623. \newcommand{\obs}{\observed}
  624. \newcommand{\inputs}{\vu}
  625. \newcommand{\input}{\inputs}\\\newcommand{\hmmTrans}{\vA}
  626. \newcommand{\hmmObs}{\vB}
  627. \newcommand{\hmmInit}{\vpi}
  628. \newcommand{\hmmhid}{\hidden}
  629. \newcommand{\hmmobs}{\obs}\\\newcommand{\ldsDyn}{\vA}
  630. \newcommand{\ldsObs}{\vC}
  631. \newcommand{\ldsDynIn}{\vB}
  632. \newcommand{\ldsObsIn}{\vD}
  633. \newcommand{\ldsDynNoise}{\vQ}
  634. \newcommand{\ldsObsNoise}{\vR}\\\newcommand{\ssmDynFn}{f}
  635. \newcommand{\ssmObsFn}{h}\\
  636. %%%
  637. \newcommand{\gauss}{\mathcal{N}}\\\newcommand{\diag}{\mathrm{diag}}\end{aligned}\end{align} \]</div>
  638. <div class="tex2jax_ignore mathjax_ignore section" id="inferential-goals">
  639. <span id="sec-inference"></span><h1>Inferential goals<a class="headerlink" href="#inferential-goals" title="Permalink to this headline">¶</a></h1>
  640. <div class="figure align-default" id="fig-dbn-inference">
  641. <a class="reference internal image-reference" href="../../_images/inference-problems-tikz.png"><img alt="Illustration of the different kinds of inference in an SSM" src="../../_images/inference-problems-tikz.png" style="width: 352.2px; height: 269.4px;" /></a>
  642. <p class="caption"><span class="caption-number">Fig. 7 </span><span class="caption-text">Illustration of the different kinds of inference in an SSM.
  643. The main kinds of inference for state-space models.
  644. The shaded region is the interval for which we have data.
  645. The arrow represents the time step at which we want to perform inference.
  646. <span class="math notranslate nohighlight">\(t\)</span> is the current time, <span class="math notranslate nohighlight">\(T\)</span> is the sequence length,
  647. <span class="math notranslate nohighlight">\(\ell\)</span> is the lag and <span class="math notranslate nohighlight">\(h\)</span> is the prediction horizon.</span><a class="headerlink" href="#fig-dbn-inference" title="Permalink to this image">¶</a></p>
  648. </div>
  649. <p>Given the sequence of observations, and a known model,
  650. one of the main tasks with SSMs
  651. is to perform posterior inference
  652. about the hidden states; this is also called
  653. state estimation.
  654. At each time step <span class="math notranslate nohighlight">\(t\)</span>,
  655. there are multiple forms of posterior we may be interested in computing,
  656. including the following:</p>
  657. <ul class="simple">
  658. <li><p>the filtering distribution
  659. <span class="math notranslate nohighlight">\(p(\hmmhid_t|\hmmobs_{1:t})\)</span></p></li>
  660. <li><p>the smoothing distribution
  661. <span class="math notranslate nohighlight">\(p(\hmmhid_t|\hmmobs_{1:T})\)</span> (note that this conditions on future data <span class="math notranslate nohighlight">\(T&gt;t\)</span>)</p></li>
  662. <li><p>the fixed-lag smoothing distribution
  663. <span class="math notranslate nohighlight">\(p(\hmmhid_{t-\ell}|\hmmobs_{1:t})\)</span> (note that this
  664. infers <span class="math notranslate nohighlight">\(\ell\)</span> steps in the past given data up to the present).</p></li>
  665. </ul>
  666. <p>We may also want to compute the
  667. predictive distribution <span class="math notranslate nohighlight">\(h\)</span> steps into the future:</p>
  668. <div class="math notranslate nohighlight">
  669. \[p(\hmmobs_{t+h}|\hmmobs_{1:t})
  670. = \sum_{\hmmhid_{t+h}} p(\hmmobs_{t+h}|\hmmhid_{t+h}) p(\hmmhid_{t+h}|\hmmobs_{1:t})\]</div>
  671. <p>where the hidden state predictive distribution is</p>
  672. <div class="amsmath math notranslate nohighlight" id="equation-37bb812e-8d87-45c3-b352-11b694274a9e">
  673. <span class="eqno">(18)<a class="headerlink" href="#equation-37bb812e-8d87-45c3-b352-11b694274a9e" title="Permalink to this equation">¶</a></span>\[\begin{align}
  674. p(\hmmhid_{t+h}|\hmmobs_{1:t})
  675. &amp;= \sum_{\hmmhid_{t:t+h-1}}
  676. p(\hmmhid_t|\hmmobs_{1:t})
  677. p(\hmmhid_{t+1}|\hmmhid_{t})
  678. p(\hmmhid_{t+2}|\hmmhid_{t+1})
  679. \cdots
  680. p(\hmmhid_{t+h}|\hmmhid_{t+h-1})
  681. \end{align}\]</div>
  682. <p>See
  683. <a class="reference internal" href="#fig-dbn-inference"><span class="std std-numref">Fig. 7</span></a> for a summary of these distributions.</p>
  684. <p>In addition to computing posterior marginals,
  685. we may want to compute the most probable hidden sequence,
  686. i.e., the joint MAP estimate</p>
  687. <div class="math notranslate nohighlight">
  688. \[\arg \max_{\hmmhid_{1:T}} p(\hmmhid_{1:T}|\hmmobs_{1:T})\]</div>
  689. <p>or sample sequences from the posterior</p>
  690. <div class="math notranslate nohighlight">
  691. \[\hmmhid_{1:T} \sim p(\hmmhid_{1:T}|\hmmobs_{1:T})\]</div>
  692. <p>Algorithms for all these tasks are discussed in the following chapters,
  693. since the details depend on the form of the SSM.</p>
  694. <div class="section" id="example-inference-in-the-casino-hmm">
  695. <span id="sec-casino-inference"></span><h2>Example: inference in the casino HMM<a class="headerlink" href="#example-inference-in-the-casino-hmm" title="Permalink to this headline">¶</a></h2>
  696. <p>We now illustrate filtering, smoothing and MAP decoding applied
  697. to the casino HMM from <span class="xref std std-ref">sec:casino</span>.</p>
  698. <div class="cell docutils container">
  699. <div class="cell_input docutils container">
  700. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># state transition matrix</span>
  701. <span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  702. <span class="p">[</span><span class="mf">0.95</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">],</span>
  703. <span class="p">[</span><span class="mf">0.10</span><span class="p">,</span> <span class="mf">0.90</span><span class="p">]</span>
  704. <span class="p">])</span>
  705. <span class="c1"># observation matrix</span>
  706. <span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  707. <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">],</span> <span class="c1"># fair die</span>
  708. <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="o">/</span><span class="mi">10</span><span class="p">]</span> <span class="c1"># loaded die</span>
  709. <span class="p">])</span>
  710. <span class="n">pi</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">])</span>
  711. <span class="p">(</span><span class="n">nstates</span><span class="p">,</span> <span class="n">nobs</span><span class="p">)</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
  712. <span class="kn">import</span> <span class="nn">distrax</span>
  713. <span class="kn">from</span> <span class="nn">distrax</span> <span class="kn">import</span> <span class="n">HMM</span>
  714. <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
  715. <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">pi</span><span class="p">),</span>
  716. <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">B</span><span class="p">))</span>
  717. <span class="n">seed</span> <span class="o">=</span> <span class="mi">314</span>
  718. <span class="n">n_samples</span> <span class="o">=</span> <span class="mi">300</span>
  719. <span class="n">z_hist</span><span class="p">,</span> <span class="n">x_hist</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">n_samples</span><span class="p">)</span>
  720. </pre></div>
  721. </div>
  722. </div>
  723. <div class="cell_output docutils container">
  724. <div class="output stderr highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  725. </pre></div>
  726. </div>
  727. </div>
  728. </div>
  729. <div class="cell docutils container">
  730. <div class="cell_input docutils container">
  731. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Call inference engine</span>
  732. <span class="n">filtered_dist</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">smoothed_dist</span><span class="p">,</span> <span class="n">loglik</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">forward_backward</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span>
  733. <span class="n">map_path</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">viterbi</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span>
  734. </pre></div>
  735. </div>
  736. </div>
  737. <div class="cell_output docutils container">
  738. <div class="output stderr highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>/opt/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:5154: UserWarning: Explicitly requested dtype &lt;class &#39;jax._src.numpy.lax_numpy.int64&#39;&gt; requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  739. lax_internal._check_user_dtype_supported(dtype, &quot;astype&quot;)
  740. </pre></div>
  741. </div>
  742. </div>
  743. </div>
  744. <div class="cell docutils container">
  745. <div class="cell_input docutils container">
  746. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Find the spans of timesteps during which the simulated system is in state 1</span>
  747. <span class="k">def</span> <span class="nf">find_dishonest_intervals</span><span class="p">(</span><span class="n">z_hist</span><span class="p">):</span>
  748. <span class="n">spans</span> <span class="o">=</span> <span class="p">[]</span>
  749. <span class="n">x_init</span> <span class="o">=</span> <span class="mi">0</span>
  750. <span class="k">for</span> <span class="n">t</span><span class="p">,</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">z_hist</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]):</span>
  751. <span class="k">if</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
  752. <span class="n">x_end</span> <span class="o">=</span> <span class="n">t</span>
  753. <span class="n">spans</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">x_init</span><span class="p">,</span> <span class="n">x_end</span><span class="p">))</span>
  754. <span class="k">elif</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  755. <span class="n">x_init</span> <span class="o">=</span> <span class="n">t</span> <span class="o">+</span> <span class="mi">1</span>
  756. <span class="k">return</span> <span class="n">spans</span>
  757. </pre></div>
  758. </div>
  759. </div>
  760. </div>
  761. <div class="cell docutils container">
  762. <div class="cell_input docutils container">
  763. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Plot posterior</span>
  764. <span class="k">def</span> <span class="nf">plot_inference</span><span class="p">(</span><span class="n">inference_values</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">map_estimate</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  765. <span class="n">n_samples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">inference_values</span><span class="p">)</span>
  766. <span class="n">xspan</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_samples</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
  767. <span class="n">spans</span> <span class="o">=</span> <span class="n">find_dishonest_intervals</span><span class="p">(</span><span class="n">z_hist</span><span class="p">)</span>
  768. <span class="k">if</span> <span class="n">map_estimate</span><span class="p">:</span>
  769. <span class="n">ax</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">xspan</span><span class="p">,</span> <span class="n">inference_values</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">)</span>
  770. <span class="k">else</span><span class="p">:</span>
  771. <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">xspan</span><span class="p">,</span> <span class="n">inference_values</span><span class="p">[:,</span> <span class="n">state</span><span class="p">])</span>
  772. <span class="k">for</span> <span class="n">span</span> <span class="ow">in</span> <span class="n">spans</span><span class="p">:</span>
  773. <span class="n">ax</span><span class="o">.</span><span class="n">axvspan</span><span class="p">(</span><span class="o">*</span><span class="n">span</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">facecolor</span><span class="o">=</span><span class="s2">&quot;tab:gray&quot;</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span>
  774. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_samples</span><span class="p">)</span>
  775. <span class="c1"># ax.set_ylim(0, 1)</span>
  776. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylim</span><span class="p">(</span><span class="o">-</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">1.1</span><span class="p">)</span>
  777. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s2">&quot;Observation number&quot;</span><span class="p">)</span>
  778. </pre></div>
  779. </div>
  780. </div>
  781. </div>
  782. <div class="cell docutils container">
  783. <div class="cell_input docutils container">
  784. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span> <span class="c1"># Filtering</span>
  785. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  786. <span class="n">plot_inference</span><span class="p">(</span><span class="n">filtered_dist</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">)</span>
  787. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;p(loaded)&quot;</span><span class="p">)</span>
  788. <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Filtered&quot;</span><span class="p">)</span>
  789. </pre></div>
  790. </div>
  791. </div>
  792. <div class="cell_output docutils container">
  793. <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Text(0.5, 1.0, &#39;Filtered&#39;)
  794. </pre></div>
  795. </div>
  796. <img alt="Plot of the filtered probability p(loaded) against observation number" src="../../_images/inference_9_1.png" />
  797. </div>
  798. </div>
  799. <div class="cell docutils container">
  800. <div class="cell_input docutils container">
  801. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Smoothing</span>
  802. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  803. <span class="n">plot_inference</span><span class="p">(</span><span class="n">smoothed_dist</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">)</span>
  804. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;p(loaded)&quot;</span><span class="p">)</span>
  805. <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Smoothed&quot;</span><span class="p">)</span>
  806. </pre></div>
  807. </div>
  808. </div>
  809. <div class="cell_output docutils container">
  810. <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Text(0.5, 1.0, &#39;Smoothed&#39;)
  811. </pre></div>
  812. </div>
  813. <img alt="Plot of the smoothed probability p(loaded) against observation number" src="../../_images/inference_10_1.png" />
  814. </div>
  815. </div>
  816. <div class="cell docutils container">
  817. <div class="cell_input docutils container">
  818. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># MAP estimation</span>
  819. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  820. <span class="n">plot_inference</span><span class="p">(</span><span class="n">map_path</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">map_estimate</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  821. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;MAP state&quot;</span><span class="p">)</span>
  822. <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Viterbi&quot;</span><span class="p">)</span>
  823. </pre></div>
  824. </div>
  825. </div>
  826. <div class="cell_output docutils container">
  827. <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Text(0.5, 1.0, &#39;Viterbi&#39;)
  828. </pre></div>
  829. </div>
  830. <img alt="Plot of the Viterbi MAP state against observation number" src="../../_images/inference_11_1.png" />
  831. </div>
  832. </div>
  833. <div class="cell docutils container">
  834. <div class="cell_input docutils container">
  835. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># TODO: posterior samples</span>
  836. </pre></div>
  837. </div>
  838. </div>
  839. </div>
  840. </div>
  841. <div class="section" id="example-inference-in-the-tracking-lg-ssm">
  842. <h2>Example: inference in the tracking LG-SSM<a class="headerlink" href="#example-inference-in-the-tracking-lg-ssm" title="Permalink to this headline">¶</a></h2>
  843. <p>We now illustrate filtering, smoothing and MAP decoding applied
  844. to the 2d tracking LG-SSM from <a class="reference internal" href="lds.html#sec-tracking-lds"><span class="std std-ref">Example: tracking a 2d point</span></a>.</p>
  845. <div class="cell docutils container">
  846. <div class="cell_input docutils container">
  847. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">314</span><span class="p">)</span>
  848. <span class="n">timesteps</span> <span class="o">=</span> <span class="mi">15</span>
  849. <span class="n">delta</span> <span class="o">=</span> <span class="mf">1.0</span>
  850. <span class="n">A</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  851. <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">delta</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
  852. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">delta</span><span class="p">],</span>
  853. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
  854. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>
  855. <span class="p">])</span>
  856. <span class="n">C</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  857. <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
  858. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>
  859. <span class="p">])</span>
  860. <span class="n">state_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">shape</span>
  861. <span class="n">observation_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">C</span><span class="o">.</span><span class="n">shape</span>
  862. <span class="n">Q</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">state_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.001</span>
  863. <span class="n">R</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">observation_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.0</span>
  864. <span class="n">mu0</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">8</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">float</span><span class="p">)</span>
  865. <span class="n">Sigma0</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">state_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.0</span>
  866. <span class="kn">from</span> <span class="nn">jsl.lds.kalman_filter</span> <span class="kn">import</span> <span class="n">LDS</span><span class="p">,</span> <span class="n">smooth</span><span class="p">,</span> <span class="nb">filter</span>
  867. <span class="n">lds</span> <span class="o">=</span> <span class="n">LDS</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">Q</span><span class="p">,</span> <span class="n">R</span><span class="p">,</span> <span class="n">mu0</span><span class="p">,</span> <span class="n">Sigma0</span><span class="p">)</span>
  868. <span class="n">z_hist</span><span class="p">,</span> <span class="n">x_hist</span> <span class="o">=</span> <span class="n">lds</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">timesteps</span><span class="p">)</span>
  869. </pre></div>
  870. </div>
  871. </div>
  872. </div>
  873. <div class="cell docutils container">
  874. <div class="cell_input docutils container">
  875. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">jsl.demos.plot_utils</span> <span class="kn">import</span> <span class="n">plot_ellipse</span>
  876. <span class="k">def</span> <span class="nf">plot_tracking_values</span><span class="p">(</span><span class="n">observed</span><span class="p">,</span> <span class="n">filtered</span><span class="p">,</span> <span class="n">cov_hist</span><span class="p">,</span> <span class="n">signal_label</span><span class="p">,</span> <span class="n">ax</span><span class="p">):</span>
  877. <span class="n">timesteps</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">observed</span><span class="o">.</span><span class="n">shape</span>
  878. <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">observed</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">observed</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">marker</span><span class="o">=</span><span class="s2">&quot;o&quot;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
  879. <span class="n">markerfacecolor</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span> <span class="n">markeredgewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">markersize</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;observed&quot;</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">&quot;tab:green&quot;</span><span class="p">)</span>
  880. <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="o">*</span><span class="n">filtered</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">signal_label</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">&quot;tab:red&quot;</span><span class="p">,</span> <span class="n">marker</span><span class="o">=</span><span class="s2">&quot;x&quot;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
  881. <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">timesteps</span><span class="p">,</span> <span class="mi">1</span><span class="p">):</span>
  882. <span class="n">covn</span> <span class="o">=</span> <span class="n">cov_hist</span><span class="p">[</span><span class="n">t</span><span class="p">][:</span><span class="mi">2</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span>
  883. <span class="n">plot_ellipse</span><span class="p">(</span><span class="n">covn</span><span class="p">,</span> <span class="n">filtered</span><span class="p">[</span><span class="n">t</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">ax</span><span class="p">,</span> <span class="n">n_std</span><span class="o">=</span><span class="mf">2.0</span><span class="p">,</span> <span class="n">plot_center</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  884. <span class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;equal&quot;</span><span class="p">)</span>
  885. <span class="n">ax</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
  886. </pre></div>
  887. </div>
  888. </div>
  889. </div>
  890. <!-- Notebook cell (filtering): calls filter(lds, x_hist), prints a norm of z_hist[:, :2] - mu_hist[:, :2], then calls plot_tracking_values. NOTE(review): with ord=2 on a 2-D array, linalg.norm returns the matrix 2-norm (largest singular value), not the elementwise L2 norm; confirm that is what "L2-filter" is meant to report. --><div class="cell docutils container">
  891. <div class="cell_input docutils container">
  892. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Filtering</span>
  893. <span class="n">mu_hist</span><span class="p">,</span> <span class="n">Sigma_hist</span><span class="p">,</span> <span class="n">mu_cond_hist</span><span class="p">,</span> <span class="n">Sigma_cond_hist</span> <span class="o">=</span> <span class="nb">filter</span><span class="p">(</span><span class="n">lds</span><span class="p">,</span> <span class="n">x_hist</span><span class="p">)</span>
  894. <span class="n">l2_filter</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">z_hist</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">mu_hist</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="mi">2</span><span class="p">)</span>
  895. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;L2-filter: </span><span class="si">{</span><span class="n">l2_filter</span><span class="si">:</span><span class="s2">0.4f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  896. <span class="n">fig_filtered</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  897. <span class="n">plot_tracking_values</span><span class="p">(</span><span class="n">x_hist</span><span class="p">,</span> <span class="n">mu_hist</span><span class="p">,</span> <span class="n">Sigma_hist</span><span class="p">,</span> <span class="s2">&quot;filtered&quot;</span><span class="p">,</span> <span class="n">axs</span><span class="p">)</span>
  898. </pre></div>
  899. </div>
  900. </div>
  901. <div class="cell_output docutils container">
  902. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>L2-filter: 3.2481
  903. </pre></div>
  904. </div>
  905. <img alt="Tracking plot of x_hist with the filtered estimates (mu_hist, Sigma_hist)" src="../../_images/inference_16_1.png" />
  906. </div>
  907. </div>
  908. <!-- Notebook cell (smoothing): calls smooth(...) on the filter outputs, prints a norm of z_hist[:, :2] - mu_hist_smooth[:, :2], then calls plot_tracking_values. NOTE(review): with ord=2 on a 2-D array, linalg.norm returns the matrix 2-norm (largest singular value), not the elementwise L2 norm; confirm that is what "L2-smooth" is meant to report. --><div class="cell docutils container">
  909. <div class="cell_input docutils container">
  910. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Smoothing</span>
  911. <span class="n">mu_hist_smooth</span><span class="p">,</span> <span class="n">Sigma_hist_smooth</span> <span class="o">=</span> <span class="n">smooth</span><span class="p">(</span><span class="n">lds</span><span class="p">,</span> <span class="n">mu_hist</span><span class="p">,</span> <span class="n">Sigma_hist</span><span class="p">,</span> <span class="n">mu_cond_hist</span><span class="p">,</span> <span class="n">Sigma_cond_hist</span><span class="p">)</span>
  912. <span class="n">l2_smooth</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">z_hist</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">mu_hist_smooth</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="mi">2</span><span class="p">)</span>
  913. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;L2-smooth: </span><span class="si">{</span><span class="n">l2_smooth</span><span class="si">:</span><span class="s2">0.4f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  914. <span class="n">fig_smoothed</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  915. <span class="n">plot_tracking_values</span><span class="p">(</span><span class="n">x_hist</span><span class="p">,</span> <span class="n">mu_hist_smooth</span><span class="p">,</span> <span class="n">Sigma_hist_smooth</span><span class="p">,</span> <span class="s2">&quot;smoothed&quot;</span><span class="p">,</span> <span class="n">axs</span><span class="p">)</span>
  916. </pre></div>
  917. </div>
  918. </div>
  919. <div class="cell_output docutils container">
  920. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>L2-smooth: 2.0450
  921. </pre></div>
  922. </div>
  923. <img alt="Tracking plot of x_hist with the smoothed estimates (mu_hist_smooth, Sigma_hist_smooth)" src="../../_images/inference_17_1.png" />
  924. </div>
  925. </div>
  926. </div>
  927. </div>
  928. <!-- Thebe configuration block. Its type is not a JavaScript MIME type, so the browser itself does not execute it. NOTE(review): binderOptions.repo looks like a generic example repository rather than this book's own; confirm before relying on live kernels. --><script type="text/x-thebe-config">
  929. {
  930. requestKernel: true,
  931. binderOptions: {
  932. repo: "binder-examples/jupyter-stacks-datascience",
  933. ref: "master",
  934. },
  935. codeMirrorConfig: {
  936. theme: "abcdef",
  937. mode: "python"
  938. },
  939. kernelOptions: {
  940. kernelName: "python3",
  941. path: "./chapters/ssm"
  942. },
  943. predefinedOutput: true
  944. }
  945. </script>
  946. <!-- Assigns kernelName without var/let/const, so the name is not declared by this script. --><script>kernelName = 'python3'</script>
  947. </div>
  948. <!-- Previous / next page links. The angle icons are decorative, so they are hidden from assistive technology; the link text supplies the accessible name. -->
  949. <div class="prev-next-area">
  950. <a class="left-prev" id="prev-link" href="nlds.html" title="previous page">
  951. <i class="fas fa-angle-left" aria-hidden="true"></i>
  952. <div class="prev-next-info">
  953. <p class="prev-next-subtitle">previous</p>
  954. <p class="prev-next-title">Nonlinear Gaussian SSMs</p>
  955. </div>
  956. </a>
  957. <a class="right-next" id="next-link" href="../hmm/hmm_index.html" title="next page">
  958. <div class="prev-next-info">
  959. <p class="prev-next-subtitle">next</p>
  960. <p class="prev-next-title">Hidden Markov Models</p>
  961. </div>
  962. <i class="fas fa-angle-right" aria-hidden="true"></i>
  963. </a>
  964. </div>
  965. </div>
  966. </div>
  967. <!-- Page footer: author credit and copyright notice. --><footer class="footer">
  968. <p>
  969. By Kevin Murphy, Scott Linderman, et al.<br />
  970. © Copyright 2021.<br />
  971. </p>
  972. </footer>
  973. </main>
  974. </div>
  975. </div>
  976. <!-- Script bundle from _static/js; the same URL is declared with rel="preload" as="script" in the head. --><script src="../../_static/js/index.be7d3bbb2ef33a8344ce.js"></script>
  977. </body>
  978. </html>