scratchpad.html 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856
  1. <!DOCTYPE html>
  2. <html>
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>&lt;no title&gt; &#8212; State Space Models: A Modern Approach</title>
  7. <link href="../_static/css/theme.css" rel="stylesheet">
  8. <link href="../_static/css/index.ff1ffe594081f20da1ef19478df9384b.css" rel="stylesheet">
  9. <link rel="stylesheet"
  10. href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  11. <link rel="preload" as="font" type="font/woff2" crossorigin
  12. href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  13. <link rel="preload" as="font" type="font/woff2" crossorigin
  14. href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
  15. <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
  16. <link rel="stylesheet" type="text/css" href="../_static/sphinx-book-theme.css?digest=c3fdc42140077d1ad13ad2f1588a4309" />
  17. <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
  18. <link rel="stylesheet" type="text/css" href="../_static/copybutton.css" />
  19. <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
  20. <link rel="stylesheet" type="text/css" href="../_static/sphinx-thebe.css" />
  21. <link rel="stylesheet" type="text/css" href="../_static/panels-main.c949a650a448cc0ae9fd3441c0e17fb0.css" />
  22. <link rel="stylesheet" type="text/css" href="../_static/panels-variables.06eb56fa6e07937060861dad626602ad.css" />
  23. <link rel="preload" as="script" href="../_static/js/index.be7d3bbb2ef33a8344ce.js">
  24. <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
  25. <script src="../_static/jquery.js"></script>
  26. <script src="../_static/underscore.js"></script>
  27. <script src="../_static/doctools.js"></script>
  28. <script src="../_static/clipboard.min.js"></script>
  29. <script src="../_static/copybutton.js"></script>
  30. <script>let toggleHintShow = 'Click to show';</script>
  31. <script>let toggleHintHide = 'Click to hide';</script>
  32. <script>let toggleOpenOnPrint = 'true';</script>
  33. <script src="../_static/togglebutton.js"></script>
  34. <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
  35. <script src="../_static/sphinx-book-theme.d59cb220de22ca1c485ebbdc042f0030.js"></script>
  36. <script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"
  37. const thebe_selector = ".thebe,.cell"
  38. const thebe_selector_input = "pre"
  39. const thebe_selector_output = ".output, .cell_output"
  40. </script>
  41. <script async="async" src="../_static/sphinx-thebe.js"></script>
  42. <link rel="index" title="Index" href="../genindex.html" />
  43. <link rel="search" title="Search" href="../search.html" />
  44. <meta name="viewport" content="width=device-width, initial-scale=1" />
  45. <meta name="docsearch:language" content="None">
  46. <!-- Google Analytics -->
  47. </head>
  48. <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80">
  49. <div class="container-fluid" id="banner"></div>
  50. <div class="container-xl">
  51. <div class="row">
  52. <div class="col-12 col-md-3 bd-sidebar site-navigation show" id="site-navigation">
  53. <div class="navbar-brand-box">
  54. <a class="navbar-brand text-wrap" href="../index.html">
  55. <h1 class="site-logo" id="site-title">State Space Models: A Modern Approach</h1>
  56. </a>
  57. </div><form class="bd-search d-flex align-items-center" action="../search.html" method="get">
  58. <i class="icon fas fa-search"></i>
  59. <input type="search" class="form-control" name="q" id="search-input" placeholder="Search this book..." aria-label="Search this book..." autocomplete="off" >
  60. </form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
  61. <div class="bd-toc-item active">
  62. <ul class="nav bd-sidenav">
  63. <li class="toctree-l1">
  64. <a class="reference internal" href="../root.html">
  65. State Space Models: A Modern Approach
  66. </a>
  67. </li>
  68. </ul>
  69. <ul class="nav bd-sidenav">
  70. <li class="toctree-l1 has-children">
  71. <a class="reference internal" href="ssm/ssm_index.html">
  72. State Space Models
  73. </a>
  74. <input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  75. <label for="toctree-checkbox-1">
  76. <i class="fas fa-chevron-down">
  77. </i>
  78. </label>
  79. <ul>
  80. <li class="toctree-l2">
  81. <a class="reference internal" href="ssm/ssm_intro.html">
  82. What are State Space Models?
  83. </a>
  84. </li>
  85. <li class="toctree-l2">
  86. <a class="reference internal" href="ssm/hmm.html">
  87. Hidden Markov Models
  88. </a>
  89. </li>
  90. <li class="toctree-l2">
  91. <a class="reference internal" href="ssm/lds.html">
  92. Linear Gaussian SSMs
  93. </a>
  94. </li>
  95. <li class="toctree-l2">
  96. <a class="reference internal" href="ssm/nlds.html">
  97. Nonlinear Gaussian SSMs
  98. </a>
  99. </li>
  100. <li class="toctree-l2">
  101. <a class="reference internal" href="ssm/inference.html">
  102. States estimation (inference)
  103. </a>
  104. </li>
  105. <li class="toctree-l2">
  106. <a class="reference internal" href="ssm/learning.html">
  107. Parameter estimation (learning)
  108. </a>
  109. </li>
  110. </ul>
  111. </li>
  112. <li class="toctree-l1 has-children">
  113. <a class="reference internal" href="hmm/hmm_index.html">
  114. Hidden Markov Models
  115. </a>
  116. <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  117. <label for="toctree-checkbox-2">
  118. <i class="fas fa-chevron-down">
  119. </i>
  120. </label>
  121. <ul>
  122. <li class="toctree-l2">
  123. <a class="reference internal" href="hmm/hmm_filter.html">
  124. HMM filtering (forwards algorithm)
  125. </a>
  126. </li>
  127. <li class="toctree-l2">
  128. <a class="reference internal" href="hmm/hmm_smoother.html">
  129. HMM smoothing (forwards-backwards algorithm)
  130. </a>
  131. </li>
  132. <li class="toctree-l2">
  133. <a class="reference internal" href="hmm/hmm_viterbi.html">
  134. Viterbi algorithm
  135. </a>
  136. </li>
  137. <li class="toctree-l2">
  138. <a class="reference internal" href="hmm/hmm_parallel.html">
  139. Parallel HMM smoothing
  140. </a>
  141. </li>
  142. <li class="toctree-l2">
  143. <a class="reference internal" href="hmm/hmm_sampling.html">
  144. Forwards-filtering backwards-sampling algorithm
  145. </a>
  146. </li>
  147. </ul>
  148. </li>
  149. <li class="toctree-l1 has-children">
  150. <a class="reference internal" href="lgssm/lgssm_index.html">
  151. Linear-Gaussian SSMs
  152. </a>
  153. <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  154. <label for="toctree-checkbox-3">
  155. <i class="fas fa-chevron-down">
  156. </i>
  157. </label>
  158. <ul>
  159. <li class="toctree-l2">
  160. <a class="reference internal" href="lgssm/kalman_filter.html">
  161. Kalman filtering
  162. </a>
  163. </li>
  164. <li class="toctree-l2">
  165. <a class="reference internal" href="lgssm/kalman_smoother.html">
  166. Kalman (RTS) smoother
  167. </a>
  168. </li>
  169. <li class="toctree-l2">
  170. <a class="reference internal" href="lgssm/kalman_parallel.html">
  171. Parallel Kalman Smoother
  172. </a>
  173. </li>
  174. <li class="toctree-l2">
  175. <a class="reference internal" href="lgssm/kalman_sampling.html">
  176. Forwards-filtering backwards sampling
  177. </a>
  178. </li>
  179. </ul>
  180. </li>
  181. <li class="toctree-l1 has-children">
  182. <a class="reference internal" href="extended/extended_index.html">
  183. Extended (linearized) methods
  184. </a>
  185. <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  186. <label for="toctree-checkbox-4">
  187. <i class="fas fa-chevron-down">
  188. </i>
  189. </label>
  190. <ul>
  191. <li class="toctree-l2">
  192. <a class="reference internal" href="extended/extended_filter.html">
  193. Extended Kalman filtering
  194. </a>
  195. </li>
  196. <li class="toctree-l2">
  197. <a class="reference internal" href="extended/extended_smoother.html">
  198. Extended Kalman smoother
  199. </a>
  200. </li>
  201. <li class="toctree-l2">
  202. <a class="reference internal" href="extended/extended_parallel.html">
  203. Parallel extended Kalman smoothing
  204. </a>
  205. </li>
  206. </ul>
  207. </li>
  208. <li class="toctree-l1 has-children">
  209. <a class="reference internal" href="unscented/unscented_index.html">
  210. Unscented methods
  211. </a>
  212. <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  213. <label for="toctree-checkbox-5">
  214. <i class="fas fa-chevron-down">
  215. </i>
  216. </label>
  217. <ul>
  218. <li class="toctree-l2">
  219. <a class="reference internal" href="unscented/unscented_filter.html">
  220. Unscented filtering
  221. </a>
  222. </li>
  223. <li class="toctree-l2">
  224. <a class="reference internal" href="unscented/unscented_smoother.html">
  225. Unscented smoothing
  226. </a>
  227. </li>
  228. </ul>
  229. </li>
  230. <li class="toctree-l1">
  231. <a class="reference internal" href="quadrature/quadrature_index.html">
  232. Quadrature and cubature methods
  233. </a>
  234. </li>
  235. <li class="toctree-l1">
  236. <a class="reference internal" href="postlin/postlin_index.html">
  237. Posterior linearization
  238. </a>
  239. </li>
  240. <li class="toctree-l1">
  241. <a class="reference internal" href="adf/adf_index.html">
  242. Assumed Density Filtering
  243. </a>
  244. </li>
  245. <li class="toctree-l1">
  246. <a class="reference internal" href="vi/vi_index.html">
  247. Variational inference
  248. </a>
  249. </li>
  250. <li class="toctree-l1">
  251. <a class="reference internal" href="pf/pf_index.html">
  252. Particle filtering
  253. </a>
  254. </li>
  255. <li class="toctree-l1">
  256. <a class="reference internal" href="smc/smc_index.html">
  257. Sequential Monte Carlo
  258. </a>
  259. </li>
  260. <li class="toctree-l1">
  261. <a class="reference internal" href="learning/learning_index.html">
  262. Offline parameter estimation (learning)
  263. </a>
  264. </li>
  265. <li class="toctree-l1">
  266. <a class="reference internal" href="tracking/tracking_index.html">
  267. Multi-target tracking
  268. </a>
  269. </li>
  270. <li class="toctree-l1">
  271. <a class="reference internal" href="ensemble/ensemble_index.html">
  272. Data assimilation using Ensemble Kalman filter
  273. </a>
  274. </li>
  275. <li class="toctree-l1">
  276. <a class="reference internal" href="bnp/bnp_index.html">
  277. Bayesian non-parametric SSMs
  278. </a>
  279. </li>
  280. <li class="toctree-l1">
  281. <a class="reference internal" href="changepoint/changepoint_index.html">
  282. Changepoint detection
  283. </a>
  284. </li>
  285. <li class="toctree-l1">
  286. <a class="reference internal" href="timeseries/timeseries_index.html">
  287. Timeseries forecasting
  288. </a>
  289. </li>
  290. <li class="toctree-l1">
  291. <a class="reference internal" href="gp/gp_index.html">
  292. Markovian Gaussian processes
  293. </a>
  294. </li>
  295. <li class="toctree-l1">
  296. <a class="reference internal" href="ode/ode_index.html">
  297. Differential equations and SSMs
  298. </a>
  299. </li>
  300. <li class="toctree-l1">
  301. <a class="reference internal" href="control/control_index.html">
  302. Optimal control
  303. </a>
  304. </li>
  305. <li class="toctree-l1">
  306. <a class="reference internal" href="../bib.html">
  307. Bibliography
  308. </a>
  309. </li>
  310. </ul>
  311. </div>
  312. </nav> <!-- To handle the deprecated key -->
  313. <div class="navbar_extra_footer">
  314. Powered by <a href="https://jupyterbook.org">Jupyter Book</a>
  315. </div>
  316. </div>
  317. <main class="col py-md-3 pl-md-4 bd-content overflow-auto" role="main">
  318. <div class="topbar container-xl fixed-top">
  319. <div class="topbar-contents row">
  320. <div class="col-12 col-md-3 bd-topbar-whitespace site-navigation show"></div>
  321. <div class="col pl-md-4 topbar-main">
  322. <button id="navbar-toggler" class="navbar-toggler ml-0" type="button" data-toggle="collapse"
  323. data-toggle="tooltip" data-placement="bottom" data-target=".site-navigation" aria-controls="navbar-menu"
  324. aria-expanded="true" aria-label="Toggle navigation" aria-controls="site-navigation"
  325. title="Toggle navigation" data-toggle="tooltip" data-placement="left">
  326. <i class="fas fa-bars"></i>
  327. <i class="fas fa-arrow-left"></i>
  328. <i class="fas fa-arrow-up"></i>
  329. </button>
  330. <div class="dropdown-buttons-trigger">
  331. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn" aria-label="Download this page"><i
  332. class="fas fa-download"></i></button>
  333. <div class="dropdown-buttons">
  334. <!-- ipynb file if we had a myst markdown file -->
  335. <!-- Download raw file -->
  336. <a class="dropdown-buttons" href="../_sources/chapters/scratchpad.ipynb"><button type="button"
  337. class="btn btn-secondary topbarbtn" title="Download source file" data-toggle="tooltip"
  338. data-placement="left">.ipynb</button></a>
  339. <!-- Download PDF via print -->
  340. <button type="button" id="download-print" class="btn btn-secondary topbarbtn" title="Print to PDF"
  341. onclick="printPdf(this)" data-toggle="tooltip" data-placement="left">.pdf</button>
  342. </div>
  343. </div>
  344. <!-- Source interaction buttons -->
  345. <div class="dropdown-buttons-trigger">
  346. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  347. aria-label="Connect with source repository"><i class="fab fa-github"></i></button>
  348. <div class="dropdown-buttons sourcebuttons">
  349. <a class="repository-button"
  350. href="https://github.com/probml/ssm-book"><button type="button" class="btn btn-secondary topbarbtn"
  351. data-toggle="tooltip" data-placement="left" title="Source repository"><i
  352. class="fab fa-github"></i>repository</button></a>
  353. <a class="issues-button"
  354. href="https://github.com/probml/ssm-book/issues/new?title=Issue%20on%20page%20%2Fchapters/scratchpad.html&body=Your%20issue%20content%20here."><button
  355. type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip" data-placement="left"
  356. title="Open an issue"><i class="fas fa-lightbulb"></i>open issue</button></a>
  357. </div>
  358. </div>
  359. <!-- Full screen (wrap in <a> to have style consistency -->
  360. <a class="full-screen-button"><button type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip"
  361. data-placement="bottom" onclick="toggleFullScreen()" aria-label="Fullscreen mode"
  362. title="Fullscreen mode"><i
  363. class="fas fa-expand"></i></button></a>
  364. <!-- Launch buttons -->
  365. <div class="dropdown-buttons-trigger">
  366. <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
  367. aria-label="Launch interactive content"><i class="fas fa-rocket"></i></button>
  368. <div class="dropdown-buttons">
  369. <a class="binder-button" href="https://mybinder.org/v2/gh/probml/ssm-book/main?urlpath=tree/chapters/scratchpad.ipynb"><button type="button"
  370. class="btn btn-secondary topbarbtn" title="Launch Binder" data-toggle="tooltip"
  371. data-placement="left"><img class="binder-button-logo"
  372. src="../_static/images/logo_binder.svg"
  373. alt="Interact on binder">Binder</button></a>
  374. <a class="colab-button" href="https://colab.research.google.com/github/probml/ssm-book/blob/main/chapters/scratchpad.ipynb"><button type="button" class="btn btn-secondary topbarbtn"
  375. title="Launch Colab" data-toggle="tooltip" data-placement="left"><img class="colab-button-logo"
  376. src="../_static/images/logo_colab.png"
  377. alt="Interact on Colab">Colab</button></a>
  378. </div>
  379. </div>
  380. </div>
  381. <!-- Table of contents -->
  382. <div class="d-none d-md-block col-md-2 bd-toc show noprint">
  383. <div class="tocsection onthispage pt-5 pb-3">
  384. <i class="fas fa-list"></i> Contents
  385. </div>
  386. <nav id="bd-toc-nav" aria-label="Page">
  387. <ul class="simple visible nav section-nav flex-column">
  388. </ul>
  389. </nav>
  390. </div>
  391. </div>
  392. </div>
  393. <div id="main-content" class="row">
  394. <div class="col-12 col-md-9 pl-md-3 pr-md-0">
  395. <!-- Table of contents that is only displayed when printing the page -->
  396. <div id="jb-print-docs-body" class="onlyprint">
  397. <h1><no title></h1>
  398. <!-- Table of contents -->
  399. <div id="print-main-content">
  400. <div id="jb-print-toc">
  401. <div>
  402. <h2> Contents </h2>
  403. </div>
  404. <nav aria-label="Page">
  405. <ul class="simple visible nav section-nav flex-column">
  406. </ul>
  407. </nav>
  408. </div>
  409. </div>
  410. </div>
  411. <div>
  412. <div class="cell docutils container">
  413. <div class="cell_input docutils container">
  414. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1">### Import standard libraries</span>
  415. <span class="kn">import</span> <span class="nn">abc</span>
  416. <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
  417. <span class="kn">import</span> <span class="nn">functools</span>
  418. <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
  419. <span class="kn">import</span> <span class="nn">itertools</span>
  420. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
  421. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  422. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
  423. <span class="kn">import</span> <span class="nn">jax</span>
  424. <span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
  425. <span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">lax</span><span class="p">,</span> <span class="n">vmap</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">grad</span>
  426. <span class="c1">#from jax.scipy.special import logit</span>
  427. <span class="c1">#from jax.nn import softmax</span>
  428. <span class="kn">import</span> <span class="nn">jax.random</span> <span class="k">as</span> <span class="nn">jr</span>
  429. <span class="kn">import</span> <span class="nn">distrax</span>
  430. <span class="kn">import</span> <span class="nn">optax</span>
  431. <span class="kn">import</span> <span class="nn">jsl</span>
  432. <span class="kn">import</span> <span class="nn">ssm_jax</span>
  433. </pre></div>
  434. </div>
  435. </div>
  436. </div>
  437. <div class="cell docutils container">
  438. <div class="cell_input docutils container">
  439. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">inspect</span>
  440. <span class="kn">import</span> <span class="nn">inspect</span> <span class="k">as</span> <span class="nn">py_inspect</span>
  441. <span class="kn">import</span> <span class="nn">rich</span>
  442. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="n">inspect</span> <span class="k">as</span> <span class="n">r_inspect</span>
  443. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="nb">print</span> <span class="k">as</span> <span class="n">r_print</span>
  444. <span class="k">def</span> <span class="nf">print_source</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span>
  445. <span class="n">r_print</span><span class="p">(</span><span class="n">py_inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">fname</span><span class="p">))</span>
  446. </pre></div>
  447. </div>
  448. </div>
  449. </div>
  450. <div class="cell docutils container">
  451. <div class="cell_input docutils container">
  452. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># meta-data does not work yet in VScode</span>
  453. <span class="c1"># https://github.com/microsoft/vscode-jupyter/issues/1121</span>
  454. <span class="p">{</span>
  455. <span class="s2">&quot;tags&quot;</span><span class="p">:</span> <span class="p">[</span>
  456. <span class="s2">&quot;hide-cell&quot;</span>
  457. <span class="p">]</span>
  458. <span class="p">}</span>
  459. <span class="c1">### Install necessary libraries</span>
  460. <span class="k">try</span><span class="p">:</span>
  461. <span class="kn">import</span> <span class="nn">jax</span>
  462. <span class="k">except</span><span class="p">:</span>
  463. <span class="c1"># For cuda version, see https://github.com/google/jax#installation</span>
  464. <span class="o">%</span><span class="k">pip</span> install --upgrade &quot;jax[cpu]&quot;
  465. <span class="kn">import</span> <span class="nn">jax</span>
  466. <span class="k">try</span><span class="p">:</span>
  467. <span class="kn">import</span> <span class="nn">distrax</span>
  468. <span class="k">except</span><span class="p">:</span>
  469. <span class="o">%</span><span class="k">pip</span> install --upgrade distrax
  470. <span class="kn">import</span> <span class="nn">distrax</span>
  471. <span class="k">try</span><span class="p">:</span>
  472. <span class="kn">import</span> <span class="nn">jsl</span>
  473. <span class="k">except</span><span class="p">:</span>
  474. <span class="o">%</span><span class="k">pip</span> install git+https://github.com/probml/jsl
  475. <span class="kn">import</span> <span class="nn">jsl</span>
  476. <span class="k">try</span><span class="p">:</span>
  477. <span class="kn">import</span> <span class="nn">rich</span>
  478. <span class="k">except</span><span class="p">:</span>
  479. <span class="o">%</span><span class="k">pip</span> install rich
  480. <span class="kn">import</span> <span class="nn">rich</span>
  481. </pre></div>
  482. </div>
  483. </div>
  484. </div>
  485. <div class="cell docutils container">
  486. <div class="cell_input docutils container">
  487. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="p">{</span>
  488. <span class="s2">&quot;tags&quot;</span><span class="p">:</span> <span class="p">[</span>
  489. <span class="s2">&quot;hide-cell&quot;</span>
  490. <span class="p">]</span>
  491. <span class="p">}</span>
  492. <span class="c1">### Import standard libraries</span>
  493. <span class="kn">import</span> <span class="nn">abc</span>
  494. <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
  495. <span class="kn">import</span> <span class="nn">functools</span>
  496. <span class="kn">import</span> <span class="nn">itertools</span>
  497. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
  498. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
  499. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  500. <span class="kn">import</span> <span class="nn">jax</span>
  501. <span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
  502. <span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">lax</span><span class="p">,</span> <span class="n">vmap</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">grad</span>
  503. <span class="kn">from</span> <span class="nn">jax.scipy.special</span> <span class="kn">import</span> <span class="n">logit</span>
  504. <span class="kn">from</span> <span class="nn">jax.nn</span> <span class="kn">import</span> <span class="n">softmax</span>
  505. <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
  506. <span class="kn">from</span> <span class="nn">jax.random</span> <span class="kn">import</span> <span class="n">PRNGKey</span><span class="p">,</span> <span class="n">split</span>
  507. <span class="kn">import</span> <span class="nn">inspect</span>
  508. <span class="kn">import</span> <span class="nn">inspect</span> <span class="k">as</span> <span class="nn">py_inspect</span>
  509. <span class="kn">import</span> <span class="nn">rich</span>
  510. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="n">inspect</span> <span class="k">as</span> <span class="n">r_inspect</span>
  511. <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="nb">print</span> <span class="k">as</span> <span class="n">r_print</span>
  512. <span class="k">def</span> <span class="nf">print_source</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span>
  513. <span class="n">r_print</span><span class="p">(</span><span class="n">py_inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">fname</span><span class="p">))</span>
  514. </pre></div>
  515. </div>
  516. </div>
  517. </div>
  518. <div class="cell docutils container">
  519. <div class="cell_input docutils container">
  520. <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">ssm_jax</span>
  521. <span class="kn">from</span> <span class="nn">ssm_jax.hmm.models</span> <span class="kn">import</span> <span class="n">GaussianHMM</span>
  522. <span class="n">print_source</span><span class="p">(</span><span class="n">GaussianHMM</span><span class="p">)</span>
  523. </pre></div>
  524. </div>
  525. </div>
  526. <div class="cell_output docutils container">
<div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">class GaussianHMM<span style="font-weight: bold">(</span>BaseHMM<span style="font-weight: bold">)</span>:
    def __init__<span style="font-weight: bold">(</span>self,
                 initial_probabilities,
                 transition_matrix,
                 emission_means,
                 emission_covariance_matrices<span style="font-weight: bold">)</span>:
        <span style="color: #008000; text-decoration-color: #008000">""</span>"_summary_
        Args:
            initial_probabilities <span style="font-weight: bold">(</span>_type_<span style="font-weight: bold">)</span>: _description_
            transition_matrix <span style="font-weight: bold">(</span>_type_<span style="font-weight: bold">)</span>: _description_
            emission_means <span style="font-weight: bold">(</span>_type_<span style="font-weight: bold">)</span>: _description_
            emission_covariance_matrices <span style="font-weight: bold">(</span>_type_<span style="font-weight: bold">)</span>: _description_
        <span style="color: #008000; text-decoration-color: #008000">""</span>"
        super<span style="font-weight: bold">()</span>.__init__<span style="font-weight: bold">(</span>initial_probabilities,
                         transition_matrix<span style="font-weight: bold">)</span>
        self._emission_distribution = tfd.MultivariateNormalFullCovariance<span style="font-weight: bold">(</span>
            emission_means, emission_covariance_matrices<span style="font-weight: bold">)</span>
    @classmethod
    def random_initialization<span style="font-weight: bold">(</span>cls, key, num_states, emission_dim<span style="font-weight: bold">)</span>:
        key1, key2, key3 = jr.split<span style="font-weight: bold">(</span>key, <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">3</span><span style="font-weight: bold">)</span>
        initial_probs = jr.dirichlet<span style="font-weight: bold">(</span>key1, jnp.ones<span style="font-weight: bold">(</span>num_states<span style="font-weight: bold">))</span>
        transition_matrix = jr.dirichlet<span style="font-weight: bold">(</span>key2, jnp.ones<span style="font-weight: bold">(</span>num_states<span style="font-weight: bold">)</span>, <span style="font-weight: bold">(</span>num_states,<span style="font-weight: bold">))</span>
        emission_means = jr.normal<span style="font-weight: bold">(</span>key3, <span style="font-weight: bold">(</span>num_states, emission_dim<span style="font-weight: bold">))</span>
        emission_covs = jnp.tile<span style="font-weight: bold">(</span>jnp.eye<span style="font-weight: bold">(</span>emission_dim<span style="font-weight: bold">)</span>, <span style="font-weight: bold">(</span>num_states, <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">1</span>, <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">1</span><span style="font-weight: bold">))</span>
        return cls<span style="font-weight: bold">(</span>initial_probs, transition_matrix, emission_means, emission_covs<span style="font-weight: bold">)</span>
    # Properties to get various parameters of the model
    @property
    def emission_distribution<span style="font-weight: bold">(</span>self<span style="font-weight: bold">)</span>:
        return self._emission_distribution
    @property
    def emission_means<span style="font-weight: bold">(</span>self<span style="font-weight: bold">)</span>:
        return self.emission_distribution.mean<span style="font-weight: bold">()</span>
    @property
    def emission_covariance_matrices<span style="font-weight: bold">(</span>self<span style="font-weight: bold">)</span>:
        return self.emission_distribution.covariance<span style="font-weight: bold">()</span>
    @property
    def unconstrained_params<span style="font-weight: bold">(</span>self<span style="font-weight: bold">)</span>:
        <span style="color: #008000; text-decoration-color: #008000">""</span>"Helper property to get a PyTree of unconstrained parameters.
        <span style="color: #008000; text-decoration-color: #008000">""</span>"
        return tfb.SoftmaxCentered<span style="font-weight: bold">()</span>.inverse<span style="font-weight: bold">(</span>self.initial_probabilities<span style="font-weight: bold">)</span>, \
               tfb.SoftmaxCentered<span style="font-weight: bold">()</span>.inverse<span style="font-weight: bold">(</span>self.transition_matrix<span style="font-weight: bold">)</span>, \
               self.emission_means, \
               PSDToRealBijector.forward<span style="font-weight: bold">(</span>self.emission_covariance_matrices<span style="font-weight: bold">)</span>
    @classmethod
    def from_unconstrained_params<span style="font-weight: bold">(</span>cls, unconstrained_params, hypers<span style="font-weight: bold">)</span>:
        initial_probabilities = tfb.SoftmaxCentered<span style="font-weight: bold">()</span>.forward<span style="font-weight: bold">(</span>unconstrained_params<span style="font-weight: bold">[</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">0</span><span style="font-weight: bold">])</span>
        transition_matrix = tfb.SoftmaxCentered<span style="font-weight: bold">()</span>.forward<span style="font-weight: bold">(</span>unconstrained_params<span style="font-weight: bold">[</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">1</span><span style="font-weight: bold">])</span>
        emission_means = unconstrained_params<span style="font-weight: bold">[</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">2</span><span style="font-weight: bold">]</span>
        emission_covs = PSDToRealBijector.inverse<span style="font-weight: bold">(</span>unconstrained_params<span style="font-weight: bold">[</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">3</span><span style="font-weight: bold">])</span>
        return cls<span style="font-weight: bold">(</span>initial_probabilities, transition_matrix, emission_means, emission_covs,
                   *hypers<span style="font-weight: bold">)</span>
</pre>
</div></div>
</div>
<div class="cell docutils container">
<div class="cell_input docutils container">
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Set dimensions</span>
<span class="n">num_states</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">emission_dim</span> <span class="o">=</span> <span class="mi">2</span>
<span class="c1"># Specify parameters of the HMM</span>
<span class="n">initial_probs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_states</span><span class="p">)</span> <span class="o">/</span> <span class="n">num_states</span>
<span class="n">transition_matrix</span> <span class="o">=</span> <span class="mf">0.95</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">num_states</span><span class="p">)</span> <span class="o">+</span> <span class="mf">0.05</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">roll</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">num_states</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">emission_means</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">column_stack</span><span class="p">([</span>
    <span class="n">jnp</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span><span class="p">,</span> <span class="n">num_states</span><span class="o">+</span><span class="mi">1</span><span class="p">))[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span>
    <span class="n">jnp</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span><span class="p">,</span> <span class="n">num_states</span><span class="o">+</span><span class="mi">1</span><span class="p">))[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="p">])</span>
<span class="n">emission_covs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="mf">0.1</span><span class="o">**</span><span class="mi">2</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">emission_dim</span><span class="p">),</span> <span class="p">(</span><span class="n">num_states</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">hmm</span> <span class="o">=</span> <span class="n">GaussianHMM</span><span class="p">(</span><span class="n">initial_probs</span><span class="p">,</span>
                  <span class="n">transition_matrix</span><span class="p">,</span>
                  <span class="n">emission_means</span><span class="p">,</span>
                  <span class="n">emission_covs</span><span class="p">)</span>
<span class="n">print_source</span><span class="p">(</span><span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="cell_output docutils container">
<div class="output stderr highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
</pre></div>
</div>
<div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">    def sample<span style="font-weight: bold">(</span>self, key, num_timesteps<span style="font-weight: bold">)</span>:
        <span style="color: #008000; text-decoration-color: #008000">""</span>"Sample a sequence of latent states and emissions.
        Args:
            key <span style="font-weight: bold">(</span>_type_<span style="font-weight: bold">)</span>: _description_
            num_timesteps <span style="font-weight: bold">(</span>_type_<span style="font-weight: bold">)</span>: _description_
        <span style="color: #008000; text-decoration-color: #008000">""</span>"
        def _step<span style="font-weight: bold">(</span>state, key<span style="font-weight: bold">)</span>:
            key1, key2 = jr.split<span style="font-weight: bold">(</span>key, <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">2</span><span style="font-weight: bold">)</span>
            emission = self.emission_distribution<span style="font-weight: bold">[</span>state<span style="font-weight: bold">]</span>.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key1</span><span style="font-weight: bold">)</span>
            next_state = self.transition_distribution<span style="font-weight: bold">[</span>state<span style="font-weight: bold">]</span>.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key2</span><span style="font-weight: bold">)</span>
            return next_state, <span style="font-weight: bold">(</span>state, emission<span style="font-weight: bold">)</span>
        # Sample the initial state
        key1, key = jr.split<span style="font-weight: bold">(</span>key, <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">2</span><span style="font-weight: bold">)</span>
        initial_state = self.initial_distribution.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key1</span><span style="font-weight: bold">)</span>
        # Sample the remaining emissions and states
        keys = jr.split<span style="font-weight: bold">(</span>key, num_timesteps<span style="font-weight: bold">)</span>
        _, <span style="font-weight: bold">(</span>states, emissions<span style="font-weight: bold">)</span> = lax.scan<span style="font-weight: bold">(</span>_step, initial_state, keys<span style="font-weight: bold">)</span>
        return states, emissions
</pre>
</div></div>
</div>
<div class="cell docutils container">
<div class="cell_input docutils container">
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">distrax</span>
<span class="kn">from</span> <span class="nn">distrax</span> <span class="kn">import</span> <span class="n">HMM</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
    <span class="p">[</span><span class="mf">0.95</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">],</span>
    <span class="p">[</span><span class="mf">0.10</span><span class="p">,</span> <span class="mf">0.90</span><span class="p">]</span>
<span class="p">])</span>
<span class="c1"># observation matrix</span>
<span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
    <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">],</span> <span class="c1"># fair die</span>
    <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="o">/</span><span class="mi">10</span><span class="p">]</span> <span class="c1"># loaded die</span>
<span class="p">])</span>
<span class="n">pi</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">])</span>
<span class="p">(</span><span class="n">nstates</span><span class="p">,</span> <span class="n">nobs</span><span class="p">)</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
<span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
          <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">pi</span><span class="p">),</span>
          <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">B</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="n">hmm</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="cell_output docutils container">
<div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>&lt;distrax._src.utils.hmm.HMM object at 0x7fde82c856d0&gt;
</pre></div>
</div>
</div>
</div>
<div class="cell docutils container">
<div class="cell_input docutils container">
<div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">print_source</span><span class="p">(</span><span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="cell_output docutils container">
<div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">  def sample<span style="font-weight: bold">(</span>self,
             *,
             seed: chex.PRNGKey,
             seq_len: chex.Array<span style="font-weight: bold">)</span> -&gt; Tuple:
    <span style="color: #008000; text-decoration-color: #008000">""</span>"Sample from this HMM.
    Samples an observation of given length according to this
    Hidden Markov Model and gives the sequence of the hidden states
    as well as the observation.
    Args:
      seed: Random key of shape <span style="font-weight: bold">(</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">2</span>,<span style="font-weight: bold">)</span> and dtype uint32.
      seq_len: The length of the observation sequence.
    Returns:
      Tuple of hidden state sequence, and observation sequence.
    <span style="color: #008000; text-decoration-color: #008000">""</span>"
    rng_key, rng_init = jax.random.split<span style="font-weight: bold">(</span>seed<span style="font-weight: bold">)</span>
    initial_state = self._init_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">rng_init</span><span style="font-weight: bold">)</span>
    def draw_state<span style="font-weight: bold">(</span>prev_state, key<span style="font-weight: bold">)</span>:
      state = self._trans_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key</span><span style="font-weight: bold">)[</span>prev_state<span style="font-weight: bold">]</span>
      return state, state
    rng_state, rng_obs = jax.random.split<span style="font-weight: bold">(</span>rng_key<span style="font-weight: bold">)</span>
    keys = jax.random.split<span style="font-weight: bold">(</span>rng_state, seq_len - <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">1</span><span style="font-weight: bold">)</span>
    _, states = jax.lax.scan<span style="font-weight: bold">(</span>draw_state, initial_state, keys<span style="font-weight: bold">)</span>
    states = jnp.append<span style="font-weight: bold">(</span>initial_state, states<span style="font-weight: bold">)</span>
    def draw_obs<span style="font-weight: bold">(</span>state, key<span style="font-weight: bold">)</span>:
      return self._obs_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key</span><span style="font-weight: bold">)[</span>state<span style="font-weight: bold">]</span>
    keys = jax.random.split<span style="font-weight: bold">(</span>rng_obs, seq_len<span style="font-weight: bold">)</span>
    obs_seq = jax.vmap<span style="font-weight: bold">(</span>draw_obs, <span style="color: #808000; text-decoration-color: #808000">in_axes</span>=<span style="font-weight: bold">(</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">0</span>, <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">0</span><span style="font-weight: bold">))(</span>states, keys<span style="font-weight: bold">)</span>
    return states, obs_seq
</pre>
</div></div>
</div>
  693. <script type="text/x-thebe-config">
  694. {
  695. requestKernel: true,
  696. binderOptions: {
  697. repo: "binder-examples/jupyter-stacks-datascience",
  698. ref: "master",
  699. },
  700. codeMirrorConfig: {
  701. theme: "abcdef",
  702. mode: "python"
  703. },
  704. kernelOptions: {
  705. kernelName: "python3",
  706. path: "./chapters"
  707. },
  708. predefinedOutput: true
  709. }
  710. </script>
  711. <script>kernelName = 'python3'</script>
</div>
<!-- Previous / next buttons -->
<div class="prev-next-area">
</div>
</div>
</div>
<footer class="footer">
<p>
By Kevin Murphy, Scott Linderman, et al.<br>
&copy; Copyright 2021.<br>
</p>
</footer>
</main>
</div>
</div>
<script src="../_static/js/index.be7d3bbb2ef33a8344ce.js"></script>
</body>
</html>