inference.html 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912
  1. <!DOCTYPE html>
  2. <html>
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>States estimation (inference) &#8212; State Space Models: A Modern Approach</title>
  7. <link href="../../_static/css/theme.css" rel="stylesheet">
  8. <link href="../../_static/css/index.ff1ffe594081f20da1ef19478df9384b.css" rel="stylesheet">
  9. <link rel="stylesheet"
  10. href="../../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  11. <link rel="preload" as="font" type="font/woff2" crossorigin
  12. href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  13. <link rel="preload" as="font" type="font/woff2" crossorigin
  14. href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
  15. <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" />
  16. <link rel="stylesheet" type="text/css" href="../../_static/sphinx-book-theme.css?digest=c3fdc42140077d1ad13ad2f1588a4309" />
  17. <link rel="stylesheet" type="text/css" href="../../_static/togglebutton.css" />
  18. <link rel="stylesheet" type="text/css" href="../../_static/copybutton.css" />
  19. <link rel="stylesheet" type="text/css" href="../../_static/mystnb.css" />
  20. <link rel="stylesheet" type="text/css" href="../../_static/sphinx-thebe.css" />
  21. <link rel="stylesheet" type="text/css" href="../../_static/panels-main.c949a650a448cc0ae9fd3441c0e17fb0.css" />
  22. <link rel="stylesheet" type="text/css" href="../../_static/panels-variables.06eb56fa6e07937060861dad626602ad.css" />
  23. <link rel="preload" as="script" href="../../_static/js/index.be7d3bbb2ef33a8344ce.js">
  24. <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
  25. <script src="../../_static/jquery.js"></script>
  26. <script src="../../_static/underscore.js"></script>
  27. <script src="../../_static/doctools.js"></script>
  28. <script src="../../_static/clipboard.min.js"></script>
  29. <script src="../../_static/copybutton.js"></script>
  30. <script>let toggleHintShow = 'Click to show';</script>
  31. <script>let toggleHintHide = 'Click to hide';</script>
  32. <script>let toggleOpenOnPrint = 'true';</script>
  33. <script src="../../_static/togglebutton.js"></script>
  34. <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
  35. <script src="../../_static/sphinx-book-theme.d59cb220de22ca1c485ebbdc042f0030.js"></script>
  36. <script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"
  37. const thebe_selector = ".thebe,.cell"
  38. const thebe_selector_input = "pre"
  39. const thebe_selector_output = ".output, .cell_output"
  40. </script>
  41. <script async="async" src="../../_static/sphinx-thebe.js"></script>
  42. <script>window.MathJax = {"tex": {"macros": {"covMat": "\\boldsymbol{\\Sigma}", "data": "\\mathcal{D}", "defeq": "\\triangleq", "diag": "\\mathrm{diag}", "discreteState": "s", "dotstar": "\\odot", "dynamicsFn": "\\mathbf{f}", "floor": ["\\lfloor#1\\rfloor", 1], "gainMatrix": "\\mathbf{K}", "gainMatrixReverse": "\\mathbf{G}", "gauss": "\\mathcal{N}", "gaussInfo": "\\mathcal{N}_{\\text{info}}", "hidden": "\\mathbf{x}", "hiddenScalar": "x", "hmmInit": "\\boldsymbol{\\pi}", "hmmInitScalar": "\\pi", "hmmObs": "\\mathbf{B}", "hmmObsScalar": "B", "hmmTrans": "\\mathbf{A}", "hmmTransScalar": "A", "infoMat": "\\precMat", "input": "\\mathbf{u}", "inputs": "\\input", "inv": ["{#1}^{-1}", 1], "keyword": ["\\textbf{#1}", 1], "ldsDyn": "\\mathbf{F}", "ldsDynIn": "\\mathbf{B}", "initMean": "\\boldsymbol{\\mean}_0", "initCov": "\\boldsymbol{\\covMat}_0", "ldsObs": "\\mathbf{H}", "ldsObsIn": "\\mathbf{D}", "ldsTrans": "\\ldsDyn", "ldsTransIn": "\\ldsDynIn", "obsCov": "\\mathbf{R}", "obsNoise": "\\boldsymbol{r}", "map": "\\mathrm{map}", "measurementFn": "\\mathbf{h}", "mean": "\\boldsymbol{\\mu}", "mle": "\\mathrm{mle}", "nlatents": "n_x", "nhidden": "\\nlatents", "ninputs": "n_u", "nobs": "n_y", "nsymbols": "n_y", "nstates": "n_s", "obs": "\\mathbf{y}", "obsScalar": "y", "observed": "\\obs", "obsFn": "\\measurementFn", "params": "\\boldsymbol{\\theta}", "precMean": "\\boldsymbol{\\eta}", "precMat": "\\boldsymbol{\\Lambda}", "real": "\\mathbb{R}", "sigmoid": "\\sigma", "softmax": "\\boldsymbol{\\sigma}", "trans": "\\mathsf{T}", "transpose": ["{#1}^{\\trans}", 1], "transCov": "\\mathbf{Q}", "transNoise": "\\mathbf{q}", "valpha": "\\boldsymbol{\\alpha}", "vbeta": "\\boldsymbol{\\beta}", "vdelta": "\\boldsymbol{\\delta}", "vepsilon": "\\boldsymbol{\\epsilon}", "vlambda": "\\boldsymbol{\\lambda}", "vLambda": "\\boldsymbol{\\Lambda}", "vmu": "\\boldsymbol{\\mu}", "vpi": "\\boldsymbol{\\pi}", "vsigma": "\\boldsymbol{\\sigma}", "vSigma": "\\boldsymbol{\\Sigma}", "vone": 
"\\boldsymbol{1}", "vzero": "\\boldsymbol{0}"}}, "options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
  43. <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
  44. <link rel="index" title="Index" href="../../genindex.html" />
  45. <link rel="search" title="Search" href="../../search.html" />
  46. <link rel="next" title="Parameter estimation (learning)" href="learning.html" />
  47. <link rel="prev" title="Nonlinear Gaussian SSMs" href="nlds.html" />
  48. <!-- viewport is declared once, near the top of <head>; the identical second declaration was removed -->
  49. <meta name="docsearch:language" content="en">
  50. <!-- Google Analytics -->
  51. </head>
  52. <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80">
  53. <div class="container-fluid" id="banner"></div>
  54. <div class="container-xl">
  55. <div class="row">
  56. <div class="col-12 col-md-3 bd-sidebar site-navigation show" id="site-navigation">
  57. <div class="navbar-brand-box">
  58. <a class="navbar-brand text-wrap" href="../../index.html">
  59. <h1 class="site-logo" id="site-title">State Space Models: A Modern Approach</h1>
  60. </a>
  61. </div><form class="bd-search d-flex align-items-center" action="../../search.html" method="get">
  62. <i class="icon fas fa-search"></i>
  63. <input type="search" class="form-control" name="q" id="search-input" placeholder="Search this book..." aria-label="Search this book..." autocomplete="off" >
  64. </form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
  65. <div class="bd-toc-item active">
  66. <ul class="nav bd-sidenav">
  67. <li class="toctree-l1">
  68. <a class="reference internal" href="../../root.html">
  69. State Space Models: A Modern Approach
  70. </a>
  71. </li>
  72. </ul>
  73. <ul class="current nav bd-sidenav">
  74. <li class="toctree-l1 current active has-children">
  75. <a class="reference internal" href="ssm_index.html">
  76. State Space Models
  77. </a>
  78. <input checked="" class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  79. <label for="toctree-checkbox-1">
  80. <i class="fas fa-chevron-down">
  81. </i>
  82. </label>
  83. <ul class="current">
  84. <li class="toctree-l2">
  85. <a class="reference internal" href="ssm_intro.html">
  86. What are State Space Models?
  87. </a>
  88. </li>
  89. <li class="toctree-l2">
  90. <a class="reference internal" href="hmm.html">
  91. Hidden Markov Models
  92. </a>
  93. </li>
  94. <li class="toctree-l2">
  95. <a class="reference internal" href="lds.html">
  96. Linear Gaussian SSMs
  97. </a>
  98. </li>
  99. <li class="toctree-l2">
  100. <a class="reference internal" href="nlds.html">
  101. Nonlinear Gaussian SSMs
  102. </a>
  103. </li>
  104. <li class="toctree-l2 current active">
  105. <a class="current reference internal" href="#">
  106. States estimation (inference)
  107. </a>
  108. </li>
  109. <li class="toctree-l2">
  110. <a class="reference internal" href="learning.html">
  111. Parameter estimation (learning)
  112. </a>
  113. </li>
  114. </ul>
  115. </li>
  116. <li class="toctree-l1 has-children">
  117. <a class="reference internal" href="../hmm/hmm_index.html">
  118. Hidden Markov Models
  119. </a>
  120. <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  121. <label for="toctree-checkbox-2">
  122. <i class="fas fa-chevron-down">
  123. </i>
  124. </label>
  125. <ul>
  126. <li class="toctree-l2">
  127. <a class="reference internal" href="../hmm/hmm_filter.html">
  128. HMM filtering (forwards algorithm)
  129. </a>
  130. </li>
  131. <li class="toctree-l2">
  132. <a class="reference internal" href="../hmm/hmm_smoother.html">
  133. HMM smoothing (forwards-backwards algorithm)
  134. </a>
  135. </li>
  136. <li class="toctree-l2">
  137. <a class="reference internal" href="../hmm/hmm_viterbi.html">
  138. Viterbi algorithm
  139. </a>
  140. </li>
  141. <li class="toctree-l2">
  142. <a class="reference internal" href="../hmm/hmm_parallel.html">
  143. Parallel HMM smoothing
  144. </a>
  145. </li>
  146. <li class="toctree-l2">
  147. <a class="reference internal" href="../hmm/hmm_sampling.html">
  148. Forwards-filtering backwards-sampling algorithm
  149. </a>
  150. </li>
  151. </ul>
  152. </li>
  153. <li class="toctree-l1 has-children">
  154. <a class="reference internal" href="../lgssm/lgssm_index.html">
  155. Linear-Gaussian SSMs
  156. </a>
  157. <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  158. <label for="toctree-checkbox-3">
  159. <i class="fas fa-chevron-down">
  160. </i>
  161. </label>
  162. <ul>
  163. <li class="toctree-l2">
  164. <a class="reference internal" href="../lgssm/kalman_filter.html">
  165. Kalman filtering
  166. </a>
  167. </li>
  168. <li class="toctree-l2">
  169. <a class="reference internal" href="../lgssm/kalman_smoother.html">
  170. Kalman (RTS) smoother
  171. </a>
  172. </li>
  173. <li class="toctree-l2">
  174. <a class="reference internal" href="../lgssm/kalman_parallel.html">
  175. Parallel Kalman Smoother
  176. </a>
  177. </li>
  178. <li class="toctree-l2">
  179. <a class="reference internal" href="../lgssm/kalman_sampling.html">
  180. Forwards-filtering backwards sampling
  181. </a>
  182. </li>
  183. </ul>
  184. </li>
  185. <li class="toctree-l1 has-children">
  186. <a class="reference internal" href="../extended/extended_index.html">
  187. Extended (linearized) methods
  188. </a>
  189. <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  190. <label for="toctree-checkbox-4">
  191. <i class="fas fa-chevron-down">
  192. </i>
  193. </label>
  194. <ul>
  195. <li class="toctree-l2">
  196. <a class="reference internal" href="../extended/extended_filter.html">
  197. Extended Kalman filtering
  198. </a>
  199. </li>
  200. <li class="toctree-l2">
  201. <a class="reference internal" href="../extended/extended_smoother.html">
  202. Extended Kalman smoother
  203. </a>
  204. </li>
  205. <li class="toctree-l2">
  206. <a class="reference internal" href="../extended/extended_parallel.html">
  207. Parallel extended Kalman smoothing
  208. </a>
  209. </li>
  210. </ul>
  211. </li>
  212. <li class="toctree-l1 has-children">
  213. <a class="reference internal" href="../unscented/unscented_index.html">
  214. Unscented methods
  215. </a>
  216. <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  217. <label for="toctree-checkbox-5">
  218. <i class="fas fa-chevron-down">
  219. </i>
  220. </label>
  221. <ul>
  222. <li class="toctree-l2">
  223. <a class="reference internal" href="../unscented/unscented_filter.html">
  224. Unscented filtering
  225. </a>
  226. </li>
  227. <li class="toctree-l2">
  228. <a class="reference internal" href="../unscented/unscented_smoother.html">
  229. Unscented smoothing
  230. </a>
  231. </li>
  232. </ul>
  233. </li>
  234. <li class="toctree-l1">
  235. <a class="reference internal" href="../quadrature/quadrature_index.html">
  236. Quadrature and cubature methods
  237. </a>
  238. </li>
  239. <li class="toctree-l1">
  240. <a class="reference internal" href="../postlin/postlin_index.html">
  241. Posterior linearization
  242. </a>
  243. </li>
  244. <li class="toctree-l1">
  245. <a class="reference internal" href="../adf/adf_index.html">
  246. Assumed Density Filtering
  247. </a>
  248. </li>
  249. <li class="toctree-l1">
  250. <a class="reference internal" href="../vi/vi_index.html">
  251. Variational inference
  252. </a>
  253. </li>
  254. <li class="toctree-l1">
  255. <a class="reference internal" href="../pf/pf_index.html">
  256. Particle filtering
  257. </a>
  258. </li>
  259. <li class="toctree-l1">
  260. <a class="reference internal" href="../smc/smc_index.html">
  261. Sequential Monte Carlo
  262. </a>
  263. </li>
  264. <li class="toctree-l1">
  265. <a class="reference internal" href="../learning/learning_index.html">
  266. Offline parameter estimation (learning)
  267. </a>
  268. </li>
  269. <li class="toctree-l1">
  270. <a class="reference internal" href="../tracking/tracking_index.html">
  271. Multi-target tracking
  272. </a>
  273. </li>
  274. <li class="toctree-l1">
  275. <a class="reference internal" href="../ensemble/ensemble_index.html">
  276. Data assimilation using Ensemble Kalman filter
  277. </a>
  278. </li>
  279. <li class="toctree-l1">
  280. <a class="reference internal" href="../bnp/bnp_index.html">
  281. Bayesian non-parametric SSMs
  282. </a>
  283. </li>
  284. <li class="toctree-l1">
  285. <a class="reference internal" href="../changepoint/changepoint_index.html">
  286. Changepoint detection
  287. </a>
  288. </li>
  289. <li class="toctree-l1">
  290. <a class="reference internal" href="../timeseries/timeseries_index.html">
  291. Timeseries forecasting
  292. </a>
  293. </li>
  294. <li class="toctree-l1">
  295. <a class="reference internal" href="../gp/gp_index.html">
  296. Markovian Gaussian processes
  297. </a>
  298. </li>
  299. <li class="toctree-l1">
  300. <a class="reference internal" href="../ode/ode_index.html">
  301. Differential equations and SSMs
  302. </a>
  303. </li>
  304. <li class="toctree-l1">
  305. <a class="reference internal" href="../control/control_index.html">
  306. Optimal control
  307. </a>
  308. </li>
  309. <li class="toctree-l1">
  310. <a class="reference internal" href="../../bib.html">
  311. Bibliography
  312. </a>
  313. </li>
  314. </ul>
  315. </div>
  316. </nav> <!-- To handle the deprecated key -->
  317. <div class="navbar_extra_footer">
  318. Powered by <a href="https://jupyterbook.org">Jupyter Book</a>
  319. </div>
  320. </div>
  321. <main class="col py-md-3 pl-md-4 bd-content overflow-auto" role="main">
  322. <div class="topbar container-xl fixed-top">
  323. <div class="topbar-contents row">
  324. <div class="col-12 col-md-3 bd-topbar-whitespace site-navigation show"></div>
  325. <div class="col pl-md-4 topbar-main">
  326. <!-- Sidebar toggle: one value per attribute (first-occurrence values kept); aria-controls names the toggled #site-navigation -->
  327. <button class="navbar-toggler ml-0" id="navbar-toggler" data-toggle="collapse" data-placement="bottom"
  328. data-target=".site-navigation" type="button" title="Toggle navigation"
  329. aria-controls="site-navigation" aria-expanded="true" aria-label="Toggle navigation">
  330. <i class="fas fa-bars" aria-hidden="true"></i>
  331. <i class="fas fa-arrow-left" aria-hidden="true"></i>
  332. <i class="fas fa-arrow-up" aria-hidden="true"></i>
  333. </button>
  334. <div class="dropdown-buttons-trigger">
  335. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn" aria-label="Download this page"><i
  336. class="fas fa-download"></i></button>
  337. <div class="dropdown-buttons">
  338. <!-- ipynb file if we had a myst markdown file -->
  339. <!-- Download raw file -->
  340. <a class="dropdown-buttons" href="../../_sources/chapters/ssm/inference.ipynb"><button type="button"
  341. class="btn btn-secondary topbarbtn" title="Download source file" data-toggle="tooltip"
  342. data-placement="left">.ipynb</button></a>
  343. <!-- Download PDF via print -->
  344. <button type="button" id="download-print" class="btn btn-secondary topbarbtn" title="Print to PDF"
  345. onclick="printPdf(this)" data-toggle="tooltip" data-placement="left">.pdf</button>
  346. </div>
  347. </div>
  348. <!-- Source interaction buttons -->
  349. <div class="dropdown-buttons-trigger">
  350. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  351. aria-label="Connect with source repository"><i class="fab fa-github"></i></button>
  352. <div class="dropdown-buttons sourcebuttons">
  353. <a class="repository-button"
  354. href="https://github.com/probml/ssm-book"><button type="button" class="btn btn-secondary topbarbtn"
  355. data-toggle="tooltip" data-placement="left" title="Source repository"><i
  356. class="fab fa-github"></i>repository</button></a>
  357. <a class="issues-button"
  358. href="https://github.com/probml/ssm-book/issues/new?title=Issue%20on%20page%20%2Fchapters/ssm/inference.html&body=Your%20issue%20content%20here."><button
  359. type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip" data-placement="left"
  360. title="Open an issue"><i class="fas fa-lightbulb"></i>open issue</button></a>
  361. </div>
  362. </div>
  363. <!-- Full screen (wrap in <a> to have style consistency -->
  364. <a class="full-screen-button"><button type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip"
  365. data-placement="bottom" onclick="toggleFullScreen()" aria-label="Fullscreen mode"
  366. title="Fullscreen mode"><i
  367. class="fas fa-expand"></i></button></a>
  368. <!-- Launch buttons -->
  369. <div class="dropdown-buttons-trigger">
  370. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  371. aria-label="Launch interactive content"><i class="fas fa-rocket"></i></button>
  372. <div class="dropdown-buttons">
  373. <a class="binder-button" href="https://mybinder.org/v2/gh/probml/ssm-book/main?urlpath=tree/chapters/ssm/inference.ipynb"><button type="button"
  374. class="btn btn-secondary topbarbtn" title="Launch Binder" data-toggle="tooltip"
  375. data-placement="left"><img class="binder-button-logo"
  376. src="../../_static/images/logo_binder.svg"
  377. alt="Interact on binder">Binder</button></a>
  378. <a class="colab-button" href="https://colab.research.google.com/github/probml/ssm-book/blob/main/chapters/ssm/inference.ipynb"><button type="button" class="btn btn-secondary topbarbtn"
  379. title="Launch Colab" data-toggle="tooltip" data-placement="left"><img class="colab-button-logo"
  380. src="../../_static/images/logo_colab.png"
  381. alt="Interact on Colab">Colab</button></a>
  382. </div>
  383. </div>
  384. </div>
  385. <!-- Table of contents -->
  386. <div class="d-none d-md-block col-md-2 bd-toc show noprint">
  387. <div class="tocsection onthispage pt-5 pb-3">
  388. <i class="fas fa-list"></i> Contents
  389. </div>
  390. <nav id="bd-toc-nav" aria-label="Page">
  391. <ul class="visible nav section-nav flex-column">
  392. <li class="toc-h2 nav-item toc-entry">
  393. <a class="reference internal nav-link" href="#example-inference-in-the-casino-hmm">
  394. Example: inference in the casino HMM
  395. </a>
  396. </li>
  397. <li class="toc-h2 nav-item toc-entry">
  398. <a class="reference internal nav-link" href="#example-inference-in-the-tracking-lg-ssm">
  399. Example: inference in the tracking LG-SSM
  400. </a>
  401. </li>
  402. </ul>
  403. </nav>
  404. </div>
  405. </div>
  406. </div>
  407. <div id="main-content" class="row">
  408. <div class="col-12 col-md-9 pl-md-3 pr-md-0">
  409. <!-- Table of contents that is only displayed when printing the page -->
  410. <div id="jb-print-docs-body" class="onlyprint">
  411. <h1>States estimation (inference)</h1>
  412. <!-- Table of contents -->
  413. <div id="print-main-content">
  414. <div id="jb-print-toc">
  415. <div>
  416. <h2> Contents </h2>
  417. </div>
  418. <nav aria-label="Page">
  419. <ul class="visible nav section-nav flex-column">
  420. <li class="toc-h2 nav-item toc-entry">
  421. <a class="reference internal nav-link" href="#example-inference-in-the-casino-hmm">
  422. Example: inference in the casino HMM
  423. </a>
  424. </li>
  425. <li class="toc-h2 nav-item toc-entry">
  426. <a class="reference internal nav-link" href="#example-inference-in-the-tracking-lg-ssm">
  427. Example: inference in the tracking LG-SSM
  428. </a>
  429. </li>
  430. </ul>
  431. </nav>
  432. </div>
  433. </div>
  434. </div>
  435. <div>
  436. <div class="cell docutils container">
  437. <div class="cell_input docutils container">
  438. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="p">{</span>
  439. <span class="s2">&quot;tags&quot;</span><span class="p">:</span> <span class="p">[</span>
  440. <span class="s2">&quot;hide-cell&quot;</span>
  441. <span class="p">]</span>
  442. <span class="p">}</span>
  443. <span class="c1">### Import standard libraries</span>
  444. <span class="kn">import</span> <span class="nn">abc</span>
  445. <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
  446. <span class="kn">import</span> <span class="nn">functools</span>
  447. <span class="kn">import</span> <span class="nn">itertools</span>
  448. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
  449. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
  450. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  451. <span class="kn">import</span> <span class="nn">jax</span>
  452. <span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
  453. <span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">lax</span><span class="p">,</span> <span class="n">vmap</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">grad</span>
  454. <span class="kn">from</span> <span class="nn">jax.scipy.special</span> <span class="kn">import</span> <span class="n">logit</span>
  455. <span class="kn">from</span> <span class="nn">jax.nn</span> <span class="kn">import</span> <span class="n">softmax</span>
  456. <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
  457. <span class="kn">from</span> <span class="nn">jax.random</span> <span class="kn">import</span> <span class="n">PRNGKey</span><span class="p">,</span> <span class="n">split</span>
  458. <span class="kn">import</span> <span class="nn">jsl</span>
  459. <span class="kn">import</span> <span class="nn">ssm_jax</span>
  460. </pre></div>
  461. </div>
  462. </div>
  463. </div>
  464. <div class="tex2jax_ignore mathjax_ignore section" id="states-estimation-inference">
  465. <span id="sec-inference"></span><h1>States estimation (inference)<a class="headerlink" href="#states-estimation-inference" title="Permalink to this headline">¶</a></h1>
  466. <p>Given the sequence of observations and a known model,
  467. one of the main tasks with SSMs
  468. is to perform posterior inference
  469. about the hidden states; this is also called
  470. state estimation.
  471. At each time step <span class="math notranslate nohighlight">\(t\)</span>,
  472. there are multiple forms of posterior we may be interested in computing,
  473. including the following:</p>
  474. <ul class="simple">
  475. <li><p>the filtering distribution
  476. <span class="math notranslate nohighlight">\(p(\hidden_t|\obs_{1:t})\)</span></p></li>
  477. <li><p>the smoothing distribution
  478. <span class="math notranslate nohighlight">\(p(\hidden_t|\obs_{1:T})\)</span> (note that this conditions on future data <span class="math notranslate nohighlight">\(T&gt;t\)</span>)</p></li>
  479. <li><p>the fixed-lag smoothing distribution
  480. <span class="math notranslate nohighlight">\(p(\hidden_{t-\ell}|\obs_{1:t})\)</span> (note that this
  481. infers <span class="math notranslate nohighlight">\(\ell\)</span> steps in the past given data up to the present).</p></li>
  482. </ul>
  483. <p>We may also want to compute the
  484. predictive distribution <span class="math notranslate nohighlight">\(h\)</span> steps into the future:</p>
  485. <div class="amsmath math notranslate nohighlight" id="equation-1f208dad-d4e8-412e-aed2-aa723b3f8bbf">
  486. <span class="eqno">(15)<a class="headerlink" href="#equation-1f208dad-d4e8-412e-aed2-aa723b3f8bbf" title="Permalink to this equation">¶</a></span>\[\begin{align}
  487. p(\obs_{t+h}|\obs_{1:t})
  488. = \sum_{\hidden_{t+h}} p(\obs_{t+h}|\hidden_{t+h}) p(\hidden_{t+h}|\obs_{1:t})
  489. \end{align}\]</div>
  490. <p>where the hidden state predictive distribution is</p>
  491. <div class="amsmath math notranslate nohighlight" id="equation-aeb53008-8650-4419-932d-1f31ed6427f1">
  492. <span class="eqno">(16)<a class="headerlink" href="#equation-aeb53008-8650-4419-932d-1f31ed6427f1" title="Permalink to this equation">¶</a></span>\[\begin{align}
  493. p(\hidden_{t+h}|\obs_{1:t})
  494. &amp;= \sum_{\hidden_{t:t+h-1}}
  495. p(\hidden_t|\obs_{1:t})
  496. p(\hidden_{t+1}|\hidden_{t})
  497. p(\hidden_{t+2}|\hidden_{t+1})
  498. \cdots
  499. p(\hidden_{t+h}|\hidden_{t+h-1})
  500. \end{align}\]</div>
  501. <p>See
  502. <a class="reference internal" href="#fig-dbn-inference"><span class="std std-numref">Fig. 5</span></a> for a summary of these distributions.</p>
  503. <div class="figure align-default" id="fig-dbn-inference">
  504. <!-- alt describes the diagram itself rather than repeating the file path; the link opens the full-size image --> <a class="reference internal image-reference" href="../../_images/inference-problems-tikz.png"><img alt="Diagram of the main kinds of inference in a state space model: filtering, smoothing, fixed-lag smoothing and prediction" src="../../_images/inference-problems-tikz.png" style="width: 352.2px; height: 269.4px;" /></a>
  505. <p class="caption"><span class="caption-number">Fig. 5 </span><span class="caption-text">Illustration of the different kinds of inference in an SSM.
  506. The main kinds of inference for state-space models.
  507. The shaded region is the interval for which we have data.
  508. The arrow represents the time step at which we want to perform inference.
  509. <span class="math notranslate nohighlight">\(t\)</span> is the current time, <span class="math notranslate nohighlight">\(T\)</span> is the sequence length,
  510. <span class="math notranslate nohighlight">\(\ell\)</span> is the lag and <span class="math notranslate nohighlight">\(h\)</span> is the prediction horizon.</span><a class="headerlink" href="#fig-dbn-inference" title="Permalink to this image">¶</a></p>
  511. </div>
  512. <p>In addition to computing posterior marginals,
  513. we may want to compute the most probable hidden sequence,
  514. i.e., the joint MAP estimate</p>
  515. <div class="math notranslate nohighlight">
  516. \[\arg \max_{\hidden_{1:T}} p(\hidden_{1:T}|\obs_{1:T})\]</div>
  517. <p>or sample sequences from the posterior</p>
  518. <div class="math notranslate nohighlight">
  519. \[\hidden_{1:T} \sim p(\hidden_{1:T}|\obs_{1:T})\]</div>
  520. <p>Algorithms for all these tasks are discussed in the following chapters,
  521. since the details depend on the form of the SSM.</p>
  522. <div class="section" id="example-inference-in-the-casino-hmm">
  523. <span id="sec-casino-inference"></span><h2>Example: inference in the casino HMM<a class="headerlink" href="#example-inference-in-the-casino-hmm" title="Permalink to this headline">¶</a></h2>
  524. <p>We now illustrate filtering, smoothing and MAP decoding applied
  525. to the casino HMM introduced in <a class="reference internal" href="hmm.html">Hidden Markov Models</a>.</p>
  526. <div class="cell docutils container">
  527. <div class="cell_input docutils container">
  528. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># state transition matrix</span>
  529. <span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  530. <span class="p">[</span><span class="mf">0.95</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">],</span>
  531. <span class="p">[</span><span class="mf">0.10</span><span class="p">,</span> <span class="mf">0.90</span><span class="p">]</span>
  532. <span class="p">])</span>
  533. <span class="c1"># observation matrix</span>
  534. <span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  535. <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">],</span> <span class="c1"># fair die</span>
  536. <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="o">/</span><span class="mi">10</span><span class="p">]</span> <span class="c1"># loaded die</span>
  537. <span class="p">])</span>
  538. <span class="n">pi</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">])</span>
  539. <span class="p">(</span><span class="n">nstates</span><span class="p">,</span> <span class="n">nobs</span><span class="p">)</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
  540. <span class="kn">import</span> <span class="nn">distrax</span>
  541. <span class="kn">from</span> <span class="nn">distrax</span> <span class="kn">import</span> <span class="n">HMM</span>
  542. <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
  543. <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">pi</span><span class="p">),</span>
  544. <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">B</span><span class="p">))</span>
  545. <span class="n">seed</span> <span class="o">=</span> <span class="mi">314</span>
  546. <span class="n">n_samples</span> <span class="o">=</span> <span class="mi">300</span>
  547. <span class="n">z_hist</span><span class="p">,</span> <span class="n">x_hist</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">n_samples</span><span class="p">)</span>
  548. </pre></div>
  549. </div>
  550. </div>
  551. <div class="cell_output docutils container">
  552. <div class="output stderr highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  553. </pre></div>
  554. </div>
  555. </div>
  556. </div>
  557. <div class="cell docutils container">
  558. <div class="cell_input docutils container">
  559. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Call inference engine</span>
  560. <span class="n">filtered_dist</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">smoothed_dist</span><span class="p">,</span> <span class="n">loglik</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">forward_backward</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span>
  561. <span class="n">map_path</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">viterbi</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span>
  562. </pre></div>
  563. </div>
  564. </div>
  565. <div class="cell_output docutils container">
  566. <div class="output stderr highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>/opt/anaconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4457: UserWarning: Explicitly requested dtype &lt;class &#39;jax.numpy.int64&#39;&gt; requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  567. lax_internal._check_user_dtype_supported(dtype, &quot;astype&quot;)
  568. </pre></div>
  569. </div>
  570. </div>
  571. </div>
  572. <div class="cell docutils container">
  573. <div class="cell_input docutils container">
  574. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Find the spans of timesteps during which the simulated system is in state 1</span>
  575. <span class="k">def</span> <span class="nf">find_dishonest_intervals</span><span class="p">(</span><span class="n">z_hist</span><span class="p">):</span>
  576. <span class="n">spans</span> <span class="o">=</span> <span class="p">[]</span>
  577. <span class="n">x_init</span> <span class="o">=</span> <span class="mi">0</span>
  578. <span class="k">for</span> <span class="n">t</span><span class="p">,</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">z_hist</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]):</span>
  579. <span class="k">if</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
  580. <span class="n">x_end</span> <span class="o">=</span> <span class="n">t</span>
  581. <span class="n">spans</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">x_init</span><span class="p">,</span> <span class="n">x_end</span><span class="p">))</span>
  582. <span class="k">elif</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  583. <span class="n">x_init</span> <span class="o">=</span> <span class="n">t</span> <span class="o">+</span> <span class="mi">1</span>
  584. <span class="k">return</span> <span class="n">spans</span>
  585. </pre></div>
  586. </div>
  587. </div>
  588. </div>
  589. <div class="cell docutils container">
  590. <div class="cell_input docutils container">
  591. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Plot posterior</span>
  592. <span class="k">def</span> <span class="nf">plot_inference</span><span class="p">(</span><span class="n">inference_values</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">map_estimate</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  593. <span class="n">n_samples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">inference_values</span><span class="p">)</span>
  594. <span class="n">xspan</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_samples</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
  595. <span class="n">spans</span> <span class="o">=</span> <span class="n">find_dishonest_intervals</span><span class="p">(</span><span class="n">z_hist</span><span class="p">)</span>
  596. <span class="k">if</span> <span class="n">map_estimate</span><span class="p">:</span>
  597. <span class="n">ax</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">xspan</span><span class="p">,</span> <span class="n">inference_values</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">&quot;post&quot;</span><span class="p">)</span>
  598. <span class="k">else</span><span class="p">:</span>
  599. <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">xspan</span><span class="p">,</span> <span class="n">inference_values</span><span class="p">[:,</span> <span class="n">state</span><span class="p">])</span>
  600. <span class="k">for</span> <span class="n">span</span> <span class="ow">in</span> <span class="n">spans</span><span class="p">:</span>
  601. <span class="n">ax</span><span class="o">.</span><span class="n">axvspan</span><span class="p">(</span><span class="o">*</span><span class="n">span</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">facecolor</span><span class="o">=</span><span class="s2">&quot;tab:gray&quot;</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span>
  602. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_samples</span><span class="p">)</span>
  603. <span class="c1"># ax.set_ylim(0, 1)</span>
  604. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylim</span><span class="p">(</span><span class="o">-</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">1.1</span><span class="p">)</span>
  605. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s2">&quot;Observation number&quot;</span><span class="p">)</span>
  606. </pre></div>
  607. </div>
  608. </div>
  609. </div>
  610. <div class="cell docutils container">
  611. <div class="cell_input docutils container">
  612. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span> <span class="c1"># Filtering</span>
  613. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  614. <span class="n">plot_inference</span><span class="p">(</span><span class="n">filtered_dist</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">)</span>
  615. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;p(loaded)&quot;</span><span class="p">)</span>
  616. <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Filtered&quot;</span><span class="p">)</span>
  617. </pre></div>
  618. </div>
  619. </div>
  620. <div class="cell_output docutils container">
  621. <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Text(0.5, 1.0, &#39;Filtered&#39;)
  622. </pre></div>
  623. </div>
  624. <img alt="../../_images/inference_7_1.png" src="../../_images/inference_7_1.png" />
  625. </div>
  626. </div>
  627. <div class="cell docutils container">
  628. <div class="cell_input docutils container">
  629. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Smoothing</span>
  630. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  631. <span class="n">plot_inference</span><span class="p">(</span><span class="n">smoothed_dist</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">)</span>
  632. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;p(loaded)&quot;</span><span class="p">)</span>
  633. <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Smoothed&quot;</span><span class="p">)</span>
  634. </pre></div>
  635. </div>
  636. </div>
  637. <div class="cell_output docutils container">
  638. <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Text(0.5, 1.0, &#39;Smoothed&#39;)
  639. </pre></div>
  640. </div>
  641. <img alt="../../_images/inference_8_1.png" src="../../_images/inference_8_1.png" />
  642. </div>
  643. </div>
  644. <div class="cell docutils container">
  645. <div class="cell_input docutils container">
  646. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># MAP estimation</span>
  647. <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  648. <span class="n">plot_inference</span><span class="p">(</span><span class="n">map_path</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">map_estimate</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  649. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;MAP state&quot;</span><span class="p">)</span>
  650. <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Viterbi&quot;</span><span class="p">)</span>
  651. </pre></div>
  652. </div>
  653. </div>
  654. <div class="cell_output docutils container">
  655. <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Text(0.5, 1.0, &#39;Viterbi&#39;)
  656. </pre></div>
  657. </div>
  658. <img alt="../../_images/inference_9_1.png" src="../../_images/inference_9_1.png" />
  659. </div>
  660. </div>
  661. <div class="cell docutils container">
  662. <div class="cell_input docutils container">
  663. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># TODO: posterior samples</span>
  664. </pre></div>
  665. </div>
  666. </div>
  667. </div>
  668. </div>
  669. <div class="section" id="example-inference-in-the-tracking-lg-ssm">
  670. <h2>Example: inference in the tracking LG-SSM<a class="headerlink" href="#example-inference-in-the-tracking-lg-ssm" title="Permalink to this headline">¶</a></h2>
  671. <p>We now illustrate filtering, smoothing and MAP decoding applied
  672. to the 2d tracking LG-SSM from <a class="reference internal" href="lds.html#sec-tracking-lds"><span class="std std-ref">Example: tracking a 2d point</span></a>.</p>
  673. <div class="cell docutils container">
  674. <div class="cell_input docutils container">
  675. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">314</span><span class="p">)</span>
  676. <span class="n">timesteps</span> <span class="o">=</span> <span class="mi">15</span>
  677. <span class="n">delta</span> <span class="o">=</span> <span class="mf">1.0</span>
  678. <span class="n">A</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  679. <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">delta</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
  680. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">delta</span><span class="p">],</span>
  681. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
  682. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>
  683. <span class="p">])</span>
  684. <span class="n">C</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
  685. <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
  686. <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>
  687. <span class="p">])</span>
  688. <span class="n">state_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">shape</span>
  689. <span class="n">observation_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">C</span><span class="o">.</span><span class="n">shape</span>
  690. <span class="n">Q</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">state_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.001</span>
  691. <span class="n">R</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">observation_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.0</span>
  692. <span class="n">mu0</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">8</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">float</span><span class="p">)</span>
  693. <span class="n">Sigma0</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">state_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.0</span>
  694. <span class="kn">from</span> <span class="nn">jsl.lds.kalman_filter</span> <span class="kn">import</span> <span class="n">LDS</span><span class="p">,</span> <span class="n">smooth</span><span class="p">,</span> <span class="nb">filter</span>
  695. <span class="n">lds</span> <span class="o">=</span> <span class="n">LDS</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">Q</span><span class="p">,</span> <span class="n">R</span><span class="p">,</span> <span class="n">mu0</span><span class="p">,</span> <span class="n">Sigma0</span><span class="p">)</span>
  696. <span class="n">z_hist</span><span class="p">,</span> <span class="n">x_hist</span> <span class="o">=</span> <span class="n">lds</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">timesteps</span><span class="p">)</span>
  697. </pre></div>
  698. </div>
  699. </div>
  700. </div>
  701. <div class="cell docutils container">
  702. <div class="cell_input docutils container">
  703. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">jsl.demos.plot_utils</span> <span class="kn">import</span> <span class="n">plot_ellipse</span>
  704. <span class="k">def</span> <span class="nf">plot_tracking_values</span><span class="p">(</span><span class="n">observed</span><span class="p">,</span> <span class="n">filtered</span><span class="p">,</span> <span class="n">cov_hist</span><span class="p">,</span> <span class="n">signal_label</span><span class="p">,</span> <span class="n">ax</span><span class="p">):</span>
  705. <span class="n">timesteps</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">observed</span><span class="o">.</span><span class="n">shape</span>
  706. <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">observed</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">observed</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">marker</span><span class="o">=</span><span class="s2">&quot;o&quot;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
  707. <span class="n">markerfacecolor</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span> <span class="n">markeredgewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">markersize</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;observed&quot;</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">&quot;tab:green&quot;</span><span class="p">)</span>
  708. <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="o">*</span><span class="n">filtered</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">signal_label</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">&quot;tab:red&quot;</span><span class="p">,</span> <span class="n">marker</span><span class="o">=</span><span class="s2">&quot;x&quot;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
  709. <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">timesteps</span><span class="p">,</span> <span class="mi">1</span><span class="p">):</span>
  710. <span class="n">covn</span> <span class="o">=</span> <span class="n">cov_hist</span><span class="p">[</span><span class="n">t</span><span class="p">][:</span><span class="mi">2</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span>
  711. <span class="n">plot_ellipse</span><span class="p">(</span><span class="n">covn</span><span class="p">,</span> <span class="n">filtered</span><span class="p">[</span><span class="n">t</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">ax</span><span class="p">,</span> <span class="n">n_std</span><span class="o">=</span><span class="mf">2.0</span><span class="p">,</span> <span class="n">plot_center</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  712. <span class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;equal&quot;</span><span class="p">)</span>
  713. <span class="n">ax</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
  714. </pre></div>
  715. </div>
  716. </div>
  717. </div>
  718. <div class="cell docutils container">
  719. <div class="cell_input docutils container">
  720. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Filtering</span>
  721. <span class="n">mu_hist</span><span class="p">,</span> <span class="n">Sigma_hist</span><span class="p">,</span> <span class="n">mu_cond_hist</span><span class="p">,</span> <span class="n">Sigma_cond_hist</span> <span class="o">=</span> <span class="nb">filter</span><span class="p">(</span><span class="n">lds</span><span class="p">,</span> <span class="n">x_hist</span><span class="p">)</span>
  722. <span class="n">l2_filter</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">z_hist</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">mu_hist</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="mi">2</span><span class="p">)</span>
  723. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;L2-filter: </span><span class="si">{</span><span class="n">l2_filter</span><span class="si">:</span><span class="s2">0.4f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  724. <span class="n">fig_filtered</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  725. <span class="n">plot_tracking_values</span><span class="p">(</span><span class="n">x_hist</span><span class="p">,</span> <span class="n">mu_hist</span><span class="p">,</span> <span class="n">Sigma_hist</span><span class="p">,</span> <span class="s2">&quot;filtered&quot;</span><span class="p">,</span> <span class="n">axs</span><span class="p">)</span>
  726. </pre></div>
  727. </div>
  728. </div>
  729. <div class="cell_output docutils container">
  730. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>L2-filter: 3.2481
  731. </pre></div>
  732. </div>
  733. <img alt="../../_images/inference_14_1.png" src="../../_images/inference_14_1.png" />
  734. </div>
  735. </div>
  736. <div class="cell docutils container">
  737. <div class="cell_input docutils container">
  738. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Smoothing</span>
  739. <span class="n">mu_hist_smooth</span><span class="p">,</span> <span class="n">Sigma_hist_smooth</span> <span class="o">=</span> <span class="n">smooth</span><span class="p">(</span><span class="n">lds</span><span class="p">,</span> <span class="n">mu_hist</span><span class="p">,</span> <span class="n">Sigma_hist</span><span class="p">,</span> <span class="n">mu_cond_hist</span><span class="p">,</span> <span class="n">Sigma_cond_hist</span><span class="p">)</span>
  740. <span class="n">l2_smooth</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">z_hist</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">mu_hist_smooth</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="mi">2</span><span class="p">)</span>
  741. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;L2-smooth: </span><span class="si">{</span><span class="n">l2_smooth</span><span class="si">:</span><span class="s2">0.4f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  742. <span class="n">fig_smoothed</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
  743. <span class="n">plot_tracking_values</span><span class="p">(</span><span class="n">x_hist</span><span class="p">,</span> <span class="n">mu_hist_smooth</span><span class="p">,</span> <span class="n">Sigma_hist_smooth</span><span class="p">,</span> <span class="s2">&quot;smoothed&quot;</span><span class="p">,</span> <span class="n">axs</span><span class="p">)</span>
  744. </pre></div>
  745. </div>
  746. </div>
  747. <div class="cell_output docutils container">
  748. <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>L2-smooth: 2.0450
  749. </pre></div>
  750. </div>
  751. <img alt="../../_images/inference_15_1.png" src="../../_images/inference_15_1.png" />
  752. </div>
  753. </div>
  754. </div>
  755. </div>
  756. <script type="text/x-thebe-config">
  757. {
  758. requestKernel: true,
  759. binderOptions: {
  760. repo: "binder-examples/jupyter-stacks-datascience",
  761. ref: "master",
  762. },
  763. codeMirrorConfig: {
  764. theme: "abcdef",
  765. mode: "python"
  766. },
  767. kernelOptions: {
  768. kernelName: "python3",
  769. path: "./chapters/ssm"
  770. },
  771. predefinedOutput: true
  772. }
  773. </script>
  774. <script>kernelName = 'python3'</script>
  775. </div>
  776. <!-- Previous / next buttons -->
  777. <div class='prev-next-area'>
  778. <a class='left-prev' id="prev-link" href="nlds.html" title="previous page">
  779. <i class="fas fa-angle-left"></i>
  780. <div class="prev-next-info">
  781. <p class="prev-next-subtitle">previous</p>
  782. <p class="prev-next-title">Nonlinear Gaussian SSMs</p>
  783. </div>
  784. </a>
  785. <a class='right-next' id="next-link" href="learning.html" title="next page">
  786. <div class="prev-next-info">
  787. <p class="prev-next-subtitle">next</p>
  788. <p class="prev-next-title">Parameter estimation (learning)</p>
  789. </div>
  790. <i class="fas fa-angle-right"></i>
  791. </a>
  792. </div>
  793. </div>
  794. </div>
  795. <footer class="footer">
  796. <p>
  797. By Kevin Murphy, Scott Linderman, et al.<br/>
  798. &copy; Copyright 2021.<br/>
  799. </p>
  800. </footer>
  801. </main>
  802. </div>
  803. </div>
  804. <script src="../../_static/js/index.be7d3bbb2ef33a8344ce.js"></script>
  805. </body>
  806. </html>