瀏覽代碼

fix typos

Kevin P Murphy 3 年之前
父節點
當前提交
4ab7e92009
共有 100 個文件被更改,包括 5088 次插入和 4933 次刪除
  1. 二進制
      _build/.doctrees/bib.doctree
  2. 二進制
      _build/.doctrees/chapters/hmm/hmm_filter.doctree
  3. 二進制
      _build/.doctrees/chapters/hmm/hmm_index.doctree
  4. 二進制
      _build/.doctrees/chapters/scratch.doctree
  5. 二進制
      _build/.doctrees/chapters/scratchpad.doctree
  6. 二進制
      _build/.doctrees/chapters/ssm/hmm.doctree
  7. 二進制
      _build/.doctrees/chapters/ssm/inference.doctree
  8. 二進制
      _build/.doctrees/chapters/ssm/lds.doctree
  9. 二進制
      _build/.doctrees/chapters/ssm/learning.doctree
  10. 二進制
      _build/.doctrees/chapters/ssm/nlds.doctree
  11. 二進制
      _build/.doctrees/chapters/ssm/ssm_index.doctree
  12. 二進制
      _build/.doctrees/chapters/ssm/ssm_intro.doctree
  13. 二進制
      _build/.doctrees/environment.pickle
  14. 二進制
      _build/.doctrees/root.doctree
  15. 1 1
      _build/html/.buildinfo
  16. 8 35
      _build/html/README.html
  17. 二進制
      _build/html/_images/hmmDgmPlatesY.png
  18. 二進制
      _build/html/_images/hmm_15_0.png
  19. 二進制
      _build/html/_images/hmm_16_1.png
  20. 二進制
      _build/html/_images/inference_14_1.png
  21. 二進制
      _build/html/_images/inference_15_1.png
  22. 二進制
      _build/html/_images/inference_7_1.png
  23. 二進制
      _build/html/_images/inference_8_1.png
  24. 二進制
      _build/html/_images/inference_9_1.png
  25. 二進制
      _build/html/_images/lds_5_1.png
  26. 二進制
      _build/html/_images/lds_6_1.png
  27. 253 405
      _build/html/_sources/chapters/hmm/hmm_filter.ipynb
  28. 1 1
      _build/html/_sources/chapters/hmm/hmm_index.md
  29. 6 6
      _build/html/_sources/chapters/scratch.ipynb
  30. 1 1
      _build/html/_sources/chapters/scratch.md
  31. 552 0
      _build/html/_sources/chapters/scratchpad.ipynb
  32. 38 243
      _build/html/_sources/chapters/ssm/hmm.ipynb
  33. 30 244
      _build/html/_sources/chapters/ssm/inference.ipynb
  34. 65 311
      _build/html/_sources/chapters/ssm/lds.ipynb
  35. 115 0
      _build/html/_sources/chapters/ssm/learning.ipynb
  36. 18 189
      _build/html/_sources/chapters/ssm/nlds.ipynb
  37. 21 186
      _build/html/_sources/chapters/ssm/ssm_intro.ipynb
  38. 11 38
      _build/html/bib.html
  39. 8 35
      _build/html/chapters/adf/adf_index.html
  40. 2 29
      _build/html/chapters/blank.html
  41. 8 35
      _build/html/chapters/bnp/bnp_index.html
  42. 8 35
      _build/html/chapters/changepoint/changepoint_index.html
  43. 8 35
      _build/html/chapters/control/control_index.html
  44. 8 35
      _build/html/chapters/ensemble/ensemble_index.html
  45. 8 35
      _build/html/chapters/extended/extended_filter.html
  46. 8 35
      _build/html/chapters/extended/extended_index.html
  47. 8 35
      _build/html/chapters/extended/extended_parallel.html
  48. 8 35
      _build/html/chapters/extended/extended_smoother.html
  49. 8 35
      _build/html/chapters/gp/gp_index.html
  50. 2 29
      _build/html/chapters/hmm/hmm.html
  51. 283 311
      _build/html/chapters/hmm/hmm_filter.html
  52. 11 11
      _build/html/chapters/hmm/hmm_index.html
  53. 8 35
      _build/html/chapters/hmm/hmm_parallel.html
  54. 10 37
      _build/html/chapters/hmm/hmm_sampling.html
  55. 8 35
      _build/html/chapters/hmm/hmm_smoother.html
  56. 8 35
      _build/html/chapters/hmm/hmm_viterbi.html
  57. 6 6
      _build/html/chapters/learning/learning_index.html
  58. 10 37
      _build/html/chapters/lgssm/kalman_filter.html
  59. 8 35
      _build/html/chapters/lgssm/kalman_parallel.html
  60. 8 35
      _build/html/chapters/lgssm/kalman_sampling.html
  61. 8 35
      _build/html/chapters/lgssm/kalman_smoother.html
  62. 6 6
      _build/html/chapters/lgssm/lgssm_index.html
  63. 8 35
      _build/html/chapters/ode/ode_index.html
  64. 8 35
      _build/html/chapters/pf/pf_index.html
  65. 8 35
      _build/html/chapters/postlin/postlin_index.html
  66. 8 35
      _build/html/chapters/quadrature/quadrature_index.html
  67. 18 61
      _build/html/chapters/scratch.html
  68. 856 0
      _build/html/chapters/scratchpad.html
  69. 8 35
      _build/html/chapters/smc/smc_index.html
  70. 63 255
      _build/html/chapters/ssm/hmm.html
  71. 50 230
      _build/html/chapters/ssm/inference.html
  72. 87 256
      _build/html/chapters/ssm/lds.html
  73. 596 0
      _build/html/chapters/ssm/learning.html
  74. 42 212
      _build/html/chapters/ssm/nlds.html
  75. 11 10
      _build/html/chapters/ssm/ssm_index.html
  76. 34 191
      _build/html/chapters/ssm/ssm_intro.html
  77. 8 35
      _build/html/chapters/timeseries/timeseries_index.html
  78. 11 38
      _build/html/chapters/tracking/tracking_index.html
  79. 8 35
      _build/html/chapters/unscented/unscented_filter.html
  80. 8 35
      _build/html/chapters/unscented/unscented_index.html
  81. 8 35
      _build/html/chapters/unscented/unscented_smoother.html
  82. 8 35
      _build/html/chapters/vi/vi_index.html
  83. 6 6
      _build/html/genindex.html
  84. 二進制
      _build/html/objects.inv
  85. 75 0
      _build/html/reports/hmm.log
  86. 37 0
      _build/html/reports/hmm_filter.log
  87. 104 0
      _build/html/reports/inference.log
  88. 154 0
      _build/html/reports/lds.log
  89. 102 0
      _build/html/reports/learning.log
  90. 154 0
      _build/html/reports/nlds.log
  91. 102 0
      _build/html/reports/scratchpad.log
  92. 11 11
      _build/html/root.html
  93. 6 6
      _build/html/search.html
  94. 1 1
      _build/html/searchindex.js
  95. 252 346
      _build/jupyter_execute/chapters/hmm/hmm_filter.ipynb
  96. 217 267
      _build/jupyter_execute/chapters/hmm/hmm_filter.py
  97. 6 6
      _build/jupyter_execute/chapters/scratch.ipynb
  98. 1 1
      _build/jupyter_execute/chapters/scratch.py
  99. 434 0
      _build/jupyter_execute/chapters/scratchpad.ipynb
  100. 0 0
      _build/jupyter_execute/chapters/scratchpad.py

二進制
_build/.doctrees/bib.doctree


二進制
_build/.doctrees/chapters/hmm/hmm_filter.doctree


二進制
_build/.doctrees/chapters/hmm/hmm_index.doctree


二進制
_build/.doctrees/chapters/scratch.doctree


二進制
_build/.doctrees/chapters/scratchpad.doctree


二進制
_build/.doctrees/chapters/ssm/hmm.doctree


二進制
_build/.doctrees/chapters/ssm/inference.doctree


二進制
_build/.doctrees/chapters/ssm/lds.doctree


二進制
_build/.doctrees/chapters/ssm/learning.doctree


二進制
_build/.doctrees/chapters/ssm/nlds.doctree


二進制
_build/.doctrees/chapters/ssm/ssm_index.doctree


二進制
_build/.doctrees/chapters/ssm/ssm_intro.doctree


二進制
_build/.doctrees/environment.pickle


二進制
_build/.doctrees/root.doctree


+ 1 - 1
_build/html/.buildinfo

@@ -1,4 +1,4 @@
 # Sphinx build info version 1
 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: e1b22480b825cc4d693369dcb5b3d1e3
+config: 70a1b193a2154c8b9a08d7b32062e5d6
 tags: 645f666f9bcd5a90fca523b33c5a78b7

+ 8 - 35
_build/html/README.html

@@ -92,11 +92,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="chapters/scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="chapters/ssm/ssm_index.html">
    State Space Models
@@ -129,7 +124,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="chapters/ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="chapters/ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -173,7 +173,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="chapters/lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -282,37 +282,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="chapters/learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="chapters/learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="chapters/learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="chapters/learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="chapters/learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="chapters/tracking/tracking_index.html">

二進制
_build/html/_images/hmmDgmPlatesY.png


二進制
_build/html/_images/hmm_15_0.png


二進制
_build/html/_images/hmm_16_1.png


二進制
_build/html/_images/inference_14_1.png


二進制
_build/html/_images/inference_15_1.png


二進制
_build/html/_images/inference_7_1.png


二進制
_build/html/_images/inference_8_1.png


二進制
_build/html/_images/inference_9_1.png


二進制
_build/html/_images/lds_5_1.png


二進制
_build/html/_images/lds_6_1.png


+ 253 - 405
_build/html/_sources/chapters/hmm/hmm_filter.ipynb

@@ -4,171 +4,8 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand{\\defeq}{\\triangleq}\n",
-    "\\newcommand{\\trans}{{\\mkern-1.5mu\\mathsf{T}}}\n",
-    "\\newcommand{\\transpose}[1]{{#1}^{\\trans}}\n",
-    "\n",
-    "\\newcommand{\\inv}[1]{{#1}^{-1}}\n",
-    "\\DeclareMathOperator{\\dotstar}{\\odot}\n",
-    "\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
+    "(sec:forwards)=\n",
+    "# HMM filtering (forwards algorithm)"
    ]
   },
   {
@@ -177,141 +14,85 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "#try:\n",
-    "#    import ssm_jax\n",
-    "##except:\n",
-    "#    %pip install git+https://github.com/probml/ssm-jax\n",
-    "#    import ssm_jax\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
     "### Import standard libraries\n",
     "\n",
     "import abc\n",
     "from dataclasses import dataclass\n",
     "import functools\n",
+    "from functools import partial\n",
     "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
     "\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
     "\n",
-    "import inspect\n",
-    "import inspect as py_inspect\n",
-    "import rich\n",
-    "from rich import inspect as r_inspect\n",
-    "from rich import print as r_print\n",
     "\n",
-    "def print_source(fname):\n",
-    "    r_print(py_inspect.getsource(fname))"
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "(sec:forwards)=\n",
-    "# HMM filtering (forwards algorithm)\n",
+    "\n",
+    "## Introduction\n",
     "\n",
     "\n",
-    "The  **Bayes filter** is an algorithm for recursively computing\n",
+    "The  $\\keyword{Bayes filter}$ is an algorithm for recursively computing\n",
     "the belief state\n",
     "$p(\\hidden_t|\\obs_{1:t})$ given\n",
     "the prior belief from the previous step,\n",
     "$p(\\hidden_{t-1}|\\obs_{1:t-1})$,\n",
     "the new observation $\\obs_t$,\n",
     "and the model.\n",
-    "This can be done using **sequential Bayesian updating**.\n",
+    "This can be done using $\\keyword{sequential Bayesian updating}$.\n",
     "For a dynamical model, this reduces to the\n",
-    "**predict-update** cycle described below.\n",
+    "$\\keyword{predict-update}$ cycle described below.\n",
     "\n",
-    "\n",
-    "The **prediction step** is just the **Chapman-Kolmogorov equation**:\n",
-    "```{math}\n",
+    "The $\\keyword{prediction step}$ is just the $\\keyword{Chapman-Kolmogorov equation}$:\n",
+    "\\begin{align}\n",
     "p(\\hidden_t|\\obs_{1:t-1})\n",
     "= \\int p(\\hidden_t|\\hidden_{t-1}) p(\\hidden_{t-1}|\\obs_{1:t-1}) d\\hidden_{t-1}\n",
-    "```\n",
+    "\\end{align}\n",
     "The prediction step computes\n",
-    "the one-step-ahead predictive distribution\n",
-    "for the latent state, which updates\n",
-    "the posterior from the previous time step into the prior\n",
+    "the $\\keyword{one-step-ahead predictive distribution}$\n",
+    "for the latent state, which converts\n",
+    "the posterior from the previous time step to become the prior\n",
     "for the current step.\n",
     "\n",
     "\n",
-    "The **update step**\n",
+    "The $\\keyword{update step}$\n",
     "is just Bayes rule:\n",
-    "```{math}\n",
+    "\\begin{align}\n",
     "p(\\hidden_t|\\obs_{1:t}) = \\frac{1}{Z_t}\n",
     "p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1})\n",
-    "```\n",
+    "\\end{align}\n",
     "where the normalization constant is\n",
-    "```{math}\n",
+    "\\begin{align}\n",
     "Z_t = \\int p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1}) d\\hidden_{t}\n",
     "= p(\\obs_t|\\obs_{1:t-1})\n",
-    "```\n",
+    "\\end{align}\n",
     "\n",
+    "Note that we can derive the log marginal likelihood from these normalization constants\n",
+    "as follows:\n",
+    "```{math}\n",
+    ":label: eqn:logZ\n",
     "\n",
+    "\\log p(\\obs_{1:T})\n",
+    "= \\sum_{t=1}^{T} \\log p(\\obs_t|\\obs_{1:t-1})\n",
+    "= \\sum_{t=1}^{T} \\log Z_t\n",
+    "```\n",
     "\n"
    ]
   },
@@ -323,10 +104,11 @@
     "When the latent states $\\hidden_t$ are discrete, as in HMM,\n",
     "the above integrals become sums.\n",
     "In particular, suppose we define\n",
-    "the belief state as $\\alpha_t(j) \\defeq p(\\hidden_t=j|\\obs_{1:t})$,\n",
-    "the local evidence as $\\lambda_t(j) \\defeq p(\\obs_t|\\hidden_t=j)$,\n",
-    "and the transition matrix\n",
-    "$A(i,j)  = p(\\hidden_t=j|\\hidden_{t-1}=i)$.\n",
+    "the $\\keyword{belief state}$ as $\\alpha_t(j) \\defeq p(\\hidden_t=j|\\obs_{1:t})$,\n",
+    "the  $\\keyword{local evidence}$ (or $\\keyword{local likelihood}$)\n",
+    "as $\\lambda_t(j) \\defeq p(\\obs_t|\\hidden_t=j)$,\n",
+    "and the transition matrix as\n",
+    "$\\hmmTrans(i,j)  = p(\\hidden_t=j|\\hidden_{t-1}=i)$.\n",
     "Then the predict step becomes\n",
     "```{math}\n",
     ":label: eqn:predictiveHMM\n",
@@ -338,7 +120,7 @@
     ":label: eqn:fwdsEqn\n",
     "\\alpha_t(j)\n",
     "= \\frac{1}{Z_t} \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
-    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) A(i,j)  \\right]\n",
+    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) \\hmmTrans(i,j)  \\right]\n",
     "```\n",
     "where\n",
     "the  normalization constant for each time step is given by\n",
@@ -350,20 +132,33 @@
     "&=  \\sum_{j=1}^K \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
     "\\end{align}\n",
     "```\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
     "\n",
     "Since all the quantities are finite length vectors and matrices,\n",
-    "we can write the update equation\n",
-    "in matrix-vector notation as follows:\n",
+    "we can implement the whole procedure using matrix vector multiplication:\n",
     "```{math}\n",
+    ":label: eqn:fwdsAlgoMatrixForm\n",
     "\\valpha_t =\\text{normalize}\\left(\n",
-    "\\vlambda_t \\dotstar  (\\vA^{\\trans} \\valpha_{t-1}) \\right)\n",
-    "\\label{eqn:fwdsAlgoMatrixForm}\n",
+    "\\vlambda_t \\dotstar  (\\hmmTrans^{\\trans} \\valpha_{t-1}) \\right)\n",
     "```\n",
     "where $\\dotstar$ represents\n",
     "elementwise vector multiplication,\n",
-    "and the $\\text{normalize}$ function just ensures its argument sums to one.\n",
+    "and the $\\text{normalize}$ function just ensures its argument sums to one.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Example\n",
     "\n",
-    "In {ref}(sec:casino-inference)\n",
+    "In {ref}`sec:casino-inference`\n",
     "we illustrate\n",
     "filtering for the casino HMM,\n",
     "applied to a random sequence $\\obs_{1:T}$ of length $T=300$.\n",
@@ -371,156 +166,209 @@
     "based on the evidence seen so far.\n",
     "The gray bars indicate time intervals during which the generative\n",
     "process actually switched to the loaded dice.\n",
-    "We see that the probability generally increases in the right places.\n"
+    "We see that the probability generally increases in the right places."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Normalization constants\n",
+    "\n",
+    "In most publications on HMMs,\n",
+    "such as {cite}`Rabiner89`,\n",
+    "the forwards message is defined\n",
+    "as the following unnormalized joint probability:\n",
+    "```{math}\n",
+    "\\alpha'_t(j) = p(\\hidden_t=j,\\obs_{1:t}) \n",
+    "= \\lambda_t(j) \\left[\\sum_i \\alpha'_{t-1}(i) \\hmmTrans(i,j)  \\right]\n",
+    "```\n",
+    "In this book we define the forwards message   as the normalized\n",
+    "conditional probability\n",
+    "```{math}\n",
+    "\\alpha_t(j) = p(\\hidden_t=j|\\obs_{1:t}) \n",
+    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) \\hmmTrans(i,j)  \\right]\n",
+    "```\n",
+    "where $Z_t = p(\\obs_t|\\obs_{1:t-1})$.\n",
+    "\n",
+    "The \"traditional\" unnormalized form has several problems.\n",
+    "First, it rapidly suffers from numerical underflow,\n",
+    "since the probability of\n",
+    "the joint event that $(\\hidden_t=j,\\obs_{1:t})$\n",
+    "is vanishingly small. \n",
+    "To see why, suppose the observations are independent of the states.\n",
+    "In this case, the unnormalized joint has the form\n",
+    "\\begin{align}\n",
+    "p(\\hidden_t=j,\\obs_{1:t}) = p(\\hidden_t=j)\\prod_{i=1}^t p(\\obs_i)\n",
+    "\\end{align}\n",
+    "which becomes exponentially small with $t$, because we multiply\n",
+    "many probabilities which are less than one.\n",
+    "Second, the unnormalized probability is less interpretable,\n",
+    "since it is a joint distribution over states and observations,\n",
+    "rather than a conditional probability of states given observations.\n",
+    "Third, the unnormalized joint form is harder to approximate\n",
+    "than the normalized form.\n",
+    "Of course,\n",
+    "the two definitions only differ by a\n",
+    "multiplicative constant\n",
+    "{cite}`Devijver85`,\n",
+    "so the algorithmic difference is just\n",
+    "one line of code (namely the presence or absence of a call to the `normalize` function).\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Here is a JAX implementation of the forwards algorithm."
+    "## Naive implementation\n",
+    "\n",
+    "Below we give a simple numpy implementation of the forwards algorithm.\n",
+    "We assume the HMM uses categorical observations, for simplicity.\n",
+    "\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">@jit\n",
-       "def <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">hmm_forwards_jax</span><span style=\"font-weight: bold\">(</span>params, obs_seq, <span style=\"color: #808000; text-decoration-color: #808000\">length</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span><span style=\"font-weight: bold\">)</span>:\n",
-       "    <span style=\"color: #008000; text-decoration-color: #008000\">''</span>'\n",
-       "    Calculates a belief state\n",
-       "\n",
-       "    Parameters\n",
-       "    ----------\n",
-       "    params : HMMJax\n",
-       "        Hidden Markov Model\n",
-       "\n",
-       "    obs_seq: <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">array</span><span style=\"font-weight: bold\">(</span>seq_len<span style=\"font-weight: bold\">)</span>\n",
-       "        History of observable events\n",
-       "\n",
-       "    Returns\n",
-       "    -------\n",
-       "    * float\n",
-       "        The loglikelihood giving <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">log</span><span style=\"font-weight: bold\">(</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">p</span><span style=\"font-weight: bold\">(</span>x|model<span style=\"font-weight: bold\">))</span>\n",
-       "\n",
-       "    * <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">array</span><span style=\"font-weight: bold\">(</span>seq_len, n_hidden<span style=\"font-weight: bold\">)</span> :\n",
-       "        All alpha values found for each sample\n",
-       "    <span style=\"color: #008000; text-decoration-color: #008000\">''</span>'\n",
-       "    seq_len = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">len</span><span style=\"font-weight: bold\">(</span>obs_seq<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    if length is <span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>:\n",
-       "        length = seq_len\n",
-       "\n",
-       "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-       "\n",
-       "    trans_mat = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>trans_mat<span style=\"font-weight: bold\">)</span>\n",
-       "    obs_mat = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>obs_mat<span style=\"font-weight: bold\">)</span>\n",
-       "    init_dist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.array</span><span style=\"font-weight: bold\">(</span>init_dist<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    n_states, n_obs = obs_mat.shape\n",
-       "\n",
-       "    def <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">scan_fn</span><span style=\"font-weight: bold\">(</span>carry, t<span style=\"font-weight: bold\">)</span>:\n",
-       "        <span style=\"font-weight: bold\">(</span>alpha_prev, log_ll_prev<span style=\"font-weight: bold\">)</span> = carry\n",
-       "        alpha_n = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.where</span><span style=\"font-weight: bold\">(</span>t &lt; length,\n",
-       "                            obs_mat<span style=\"font-weight: bold\">[</span>:, obs_seq<span style=\"font-weight: bold\">]</span> * <span style=\"font-weight: bold\">(</span>alpha_prev<span style=\"font-weight: bold\">[</span>:, <span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span><span style=\"font-weight: bold\">]</span> * \n",
-       "trans_mat<span style=\"font-weight: bold\">)</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">.sum</span><span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">axis</span>=<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"font-weight: bold\">)</span>,\n",
-       "                            <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.zeros_like</span><span style=\"font-weight: bold\">(</span>alpha_prev<span style=\"font-weight: bold\">))</span>\n",
-       "\n",
-       "        alpha_n, cn = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">normalize</span><span style=\"font-weight: bold\">(</span>alpha_n<span style=\"font-weight: bold\">)</span>\n",
-       "        carry = <span style=\"font-weight: bold\">(</span>alpha_n, <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.log</span><span style=\"font-weight: bold\">(</span>cn<span style=\"font-weight: bold\">)</span> + log_ll_prev<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "        return carry, alpha_n\n",
-       "\n",
-       "    # initial belief state\n",
-       "    alpha_0, c0 = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">normalize</span><span style=\"font-weight: bold\">(</span>init_dist * obs_mat<span style=\"font-weight: bold\">[</span>:, obs_seq<span style=\"font-weight: bold\">[</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"font-weight: bold\">]])</span>\n",
-       "\n",
-       "    # setup scan loop\n",
-       "    init_state = <span style=\"font-weight: bold\">(</span>alpha_0, <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.log</span><span style=\"font-weight: bold\">(</span>c0<span style=\"font-weight: bold\">))</span>\n",
-       "    ts = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.arange</span><span style=\"font-weight: bold\">(</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, seq_len<span style=\"font-weight: bold\">)</span>\n",
-       "    carry, alpha_hist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">lax.scan</span><span style=\"font-weight: bold\">(</span>scan_fn, init_state, ts<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    # post-process\n",
-       "    alpha_hist = <span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">jnp.vstack</span><span style=\"font-weight: bold\">()</span>\n",
-       "    <span style=\"font-weight: bold\">(</span>alpha_final, log_ll<span style=\"font-weight: bold\">)</span> = carry\n",
-       "    return log_ll, alpha_hist\n",
-       "\n",
-       "</pre>\n"
-      ],
-      "text/plain": [
-       "@jit\n",
-       "def \u001b[1;35mhmm_forwards_jax\u001b[0m\u001b[1m(\u001b[0mparams, obs_seq, \u001b[33mlength\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m:\n",
-       "    \u001b[32m''\u001b[0m'\n",
-       "    Calculates a belief state\n",
-       "\n",
-       "    Parameters\n",
-       "    ----------\n",
-       "    params : HMMJax\n",
-       "        Hidden Markov Model\n",
-       "\n",
-       "    obs_seq: \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0mseq_len\u001b[1m)\u001b[0m\n",
-       "        History of observable events\n",
-       "\n",
-       "    Returns\n",
-       "    -------\n",
-       "    * float\n",
-       "        The loglikelihood giving \u001b[1;35mlog\u001b[0m\u001b[1m(\u001b[0m\u001b[1;35mp\u001b[0m\u001b[1m(\u001b[0mx|model\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    * \u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0mseq_len, n_hidden\u001b[1m)\u001b[0m :\n",
-       "        All alpha values found for each sample\n",
-       "    \u001b[32m''\u001b[0m'\n",
-       "    seq_len = \u001b[1;35mlen\u001b[0m\u001b[1m(\u001b[0mobs_seq\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    if length is \u001b[3;35mNone\u001b[0m:\n",
-       "        length = seq_len\n",
-       "\n",
-       "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-       "\n",
-       "    trans_mat = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0mtrans_mat\u001b[1m)\u001b[0m\n",
-       "    obs_mat = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0mobs_mat\u001b[1m)\u001b[0m\n",
-       "    init_dist = \u001b[1;35mjnp.array\u001b[0m\u001b[1m(\u001b[0minit_dist\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    n_states, n_obs = obs_mat.shape\n",
-       "\n",
-       "    def \u001b[1;35mscan_fn\u001b[0m\u001b[1m(\u001b[0mcarry, t\u001b[1m)\u001b[0m:\n",
-       "        \u001b[1m(\u001b[0malpha_prev, log_ll_prev\u001b[1m)\u001b[0m = carry\n",
-       "        alpha_n = \u001b[1;35mjnp.where\u001b[0m\u001b[1m(\u001b[0mt < length,\n",
-       "                            obs_mat\u001b[1m[\u001b[0m:, obs_seq\u001b[1m]\u001b[0m * \u001b[1m(\u001b[0malpha_prev\u001b[1m[\u001b[0m:, \u001b[3;35mNone\u001b[0m\u001b[1m]\u001b[0m * \n",
-       "trans_mat\u001b[1m)\u001b[0m\u001b[1;35m.sum\u001b[0m\u001b[1m(\u001b[0m\u001b[33maxis\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m,\n",
-       "                            \u001b[1;35mjnp.zeros_like\u001b[0m\u001b[1m(\u001b[0malpha_prev\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
-       "\n",
-       "        alpha_n, cn = \u001b[1;35mnormalize\u001b[0m\u001b[1m(\u001b[0malpha_n\u001b[1m)\u001b[0m\n",
-       "        carry = \u001b[1m(\u001b[0malpha_n, \u001b[1;35mjnp.log\u001b[0m\u001b[1m(\u001b[0mcn\u001b[1m)\u001b[0m + log_ll_prev\u001b[1m)\u001b[0m\n",
-       "\n",
-       "        return carry, alpha_n\n",
-       "\n",
-       "    # initial belief state\n",
-       "    alpha_0, c0 = \u001b[1;35mnormalize\u001b[0m\u001b[1m(\u001b[0minit_dist * obs_mat\u001b[1m[\u001b[0m:, obs_seq\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    # setup scan loop\n",
-       "    init_state = \u001b[1m(\u001b[0malpha_0, \u001b[1;35mjnp.log\u001b[0m\u001b[1m(\u001b[0mc0\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
-       "    ts = \u001b[1;35mjnp.arange\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m, seq_len\u001b[1m)\u001b[0m\n",
-       "    carry, alpha_hist = \u001b[1;35mlax.scan\u001b[0m\u001b[1m(\u001b[0mscan_fn, init_state, ts\u001b[1m)\u001b[0m\n",
-       "\n",
-       "    # post-process\n",
-       "    alpha_hist = \u001b[1;35mjnp.vstack\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n",
-       "    \u001b[1m(\u001b[0malpha_final, log_ll\u001b[1m)\u001b[0m = carry\n",
-       "    return log_ll, alpha_hist\n",
-       "\n"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "import jsl.hmm.hmm_lib as hmm_lib\n",
-    "print_source(hmm_lib.hmm_forwards_jax)\n",
-    "#https://github.com/probml/JSL/blob/main/jsl/hmm/hmm_lib.py#L189\n",
-    "\n"
+    "\n",
+    "\n",
+    "def normalize_np(u, axis=0, eps=1e-15):\n",
+    "    u = np.where(u == 0, 0, np.where(u < eps, eps, u))\n",
+    "    c = u.sum(axis=axis)\n",
+    "    c = np.where(c == 0, 1, c)\n",
+    "    return u / c, c\n",
+    "\n",
+    "def hmm_forwards_np(trans_mat, obs_mat, init_dist, obs_seq):\n",
+    "    n_states, n_obs = obs_mat.shape\n",
+    "    seq_len = len(obs_seq)\n",
+    "\n",
+    "    alpha_hist = np.zeros((seq_len, n_states))\n",
+    "    ll_hist = np.zeros(seq_len)  # loglikelihood history\n",
+    "\n",
+    "    alpha_n = init_dist * obs_mat[:, obs_seq[0]]\n",
+    "    alpha_n, cn = normalize_np(alpha_n)\n",
+    "\n",
+    "    alpha_hist[0] = alpha_n\n",
+    "    log_normalizer = np.log(cn)\n",
+    "\n",
+    "    for t in range(1, seq_len):\n",
+    "        alpha_n = obs_mat[:, obs_seq[t]] * (alpha_n[:, None] * trans_mat).sum(axis=0)\n",
+    "        alpha_n, zn = normalize_np(alpha_n)\n",
+    "\n",
+    "        alpha_hist[t] = alpha_n\n",
+    "        log_normalizer = np.log(zn) + log_normalizer\n",
+    "\n",
+    "    return  log_normalizer, alpha_hist"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Numerically stable implementation \n",
+    "\n",
+    "\n",
+    "\n",
+    "In practice it is more numerically stable to compute\n",
+    "the log likelihoods $\\ell_t(j) = \\log p(\\obs_t|\\hidden_t=j)$,\n",
+    "rather than the likelihoods $\\lambda_t(j) = p(\\obs_t|\\hidden_t=j)$.\n",
+    "In this case, we can perform the posterior updating in a numerically stable way as follows.\n",
+    "Define $L_t = \\max_j \\ell_t(j)$ and\n",
+    "\\begin{align}\n",
+    "\\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n",
+    "&\\defeq p(\\hidden_t=j|\\obs_{1:t-1}) p(\\obs_t|\\hidden_t=j) e^{-L_t} \\\\\n",
+    " &= p(\\hidden_t=j|\\obs_{1:t-1}) e^{\\ell_t(j) - L_t}\n",
+    "\\end{align}\n",
+    "Then we have\n",
+    "\\begin{align}\n",
+    "p(\\hidden_t=j|\\obs_t,\\obs_{1:t-1})\n",
+    "  &= \\frac{1}{\\tilde{Z}_t} \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1}) \\\\\n",
+    "\\tilde{Z}_t &= \\sum_j \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n",
+    "= p(\\obs_t|\\obs_{1:t-1}) e^{-L_t} \\\\\n",
+    "\\log Z_t &= \\log p(\\obs_t|\\obs_{1:t-1}) = \\log \\tilde{Z}_t + L_t\n",
+    "\\end{align}\n",
+    "\n",
+    "Below we show some JAX code that implements this core operation.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "def _condition_on(probs, ll):\n",
+    "    ll_max = ll.max()\n",
+    "    new_probs = probs * jnp.exp(ll - ll_max)\n",
+    "    norm = new_probs.sum()\n",
+    "    new_probs /= norm\n",
+    "    log_norm = jnp.log(norm) + ll_max\n",
+    "    return new_probs, log_norm"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "With the above function, we can implement a more numerically stable version of the forwards filter,\n",
+    "that works for any likelihood function, as shown below. It takes in the prior predictive distribution,\n",
+    "$\\alpha_{t|t-1}$,\n",
+    "stored in `predicted_probs`, and conditions it on the log-likelihood for each time step $\\ell_t$ to get the\n",
+    "posterior, $\\alpha_t$, stored in `filtered_probs`,\n",
+    "which is then converted to the prediction for the next state, $\\alpha_{t+1|t}$."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def _predict(probs, A):\n",
+    "    return A.T @ probs\n",
+    "\n",
+    "\n",
+    "def hmm_filter(initial_distribution,\n",
+    "               transition_matrix,\n",
+    "               log_likelihoods):\n",
+    "    def _step(carry, t):\n",
+    "        log_normalizer, predicted_probs = carry\n",
+    "\n",
+    "        # Get parameters for time t\n",
+    "        get = lambda x: x[t] if x.ndim == 3 else x\n",
+    "        A = get(transition_matrix)\n",
+    "        ll = log_likelihoods[t]\n",
+    "\n",
+    "        # Condition on emissions at time t, being careful not to overflow\n",
+    "        filtered_probs, log_norm = _condition_on(predicted_probs, ll)\n",
+    "        # Update the log normalizer\n",
+    "        log_normalizer += log_norm\n",
+    "        # Predict the next state\n",
+    "        predicted_probs = _predict(filtered_probs, A)\n",
+    "\n",
+    "        return (log_normalizer, predicted_probs), (filtered_probs, predicted_probs)\n",
+    "\n",
+    "    num_timesteps = len(log_likelihoods)\n",
+    "    carry = (0.0, initial_distribution)\n",
+    "    (log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(\n",
+    "        _step, carry, jnp.arange(num_timesteps))\n",
+    "    return log_normalizer, filtered_probs, predicted_probs"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "TODO: check equivalence of these two implementations!"
    ]
   }
  ],

+ 1 - 1
_build/html/_sources/chapters/hmm/hmm_index.md

@@ -2,7 +2,7 @@
 # Hidden Markov Models 
 
 This chapter discusses Hidden Markov Models (HMMs), which are state space models
-in which the latent state $z_t \in \{1,\ldots,K\}$ is discrete.
+in which the latent state $\hidden_t \in \{1,\ldots,\nstates\}$ is discrete.
 
 
 ```{tableofcontents}

+ 6 - 6
_build/html/_sources/chapters/scratch.ipynb

@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "d79c3788",
+   "id": "9d5991e9",
    "metadata": {},
    "source": [
     "(ch:intro)=\n",
@@ -28,7 +28,7 @@
   {
    "cell_type": "code",
    "execution_count": 1,
-   "id": "28c55864",
+   "id": "08375669",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -41,7 +41,7 @@
   {
    "cell_type": "code",
    "execution_count": 2,
-   "id": "b0bc0f28",
+   "id": "b5b2f4ec",
    "metadata": {},
    "outputs": [
     {
@@ -84,7 +84,7 @@
   {
    "cell_type": "code",
    "execution_count": 3,
-   "id": "902c33c0",
+   "id": "06fcf1f1",
    "metadata": {},
    "outputs": [
     {
@@ -113,7 +113,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "945c9164",
+   "id": "8688e6c2",
    "metadata": {},
    "source": [
     "## Images\n",
@@ -151,7 +151,7 @@
     "\n",
     "## Math\n",
     "\n",
-    "Here is $\\N=10$ and blah. $\\floor{42.3}= 42$.\n",
+    "Here is $\\N=10$ and blah. $\\floor{42.3}= 42$. Let's try again.\n",
     "\n",
     "We have $E= mc^2$, and also\n",
     "\n",

+ 1 - 1
_build/html/_sources/chapters/scratch.md

@@ -102,7 +102,7 @@ I am a useful note!
 
 ## Math
 
-Here is $\N=10$ and blah. $\floor{42.3}= 42$.
+Here is $\N=10$ and blah. $\floor{42.3}= 42$. Let's try again.
 
 We have $E= mc^2$, and also
 

文件差異過大導致無法顯示
+ 552 - 0
_build/html/_sources/chapters/scratchpad.ipynb


+ 38 - 243
_build/html/_sources/chapters/ssm/hmm.ipynb

@@ -1,49 +1,26 @@
 {
  "cells": [
   {
-   "cell_type": "code",
-   "execution_count": 1,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
+    "(sec:hmm-intro)=\n",
+    "# Hidden Markov Models\n",
     "\n",
-    "\n"
+    "In this section, we discuss the\n",
+    "hidden Markov model or HMM,\n",
+    "which is a state space model in which the hidden states\n",
+    "are discrete, so $\\hidden_t \\in \\{1,\\ldots, \\nstates\\}$.\n",
+    "The observations may be discrete,\n",
+    "$\\obs_t \\in \\{1,\\ldots, \\nsymbols\\}$,\n",
+    "or continuous,\n",
+    "$\\obs_t \\in \\real^D$,\n",
+    "or some combination,\n",
+    "as we illustrate below.\n",
+    "More details can be found in e.g., \n",
+    "{cite}`Rabiner89,Fraser08,Cappe05`.\n",
+    "For an interactive introduction,\n",
+    "see https://nipunbatra.github.io/hmm/."
    ]
   },
   {
@@ -64,21 +41,26 @@
     "import abc\n",
     "from dataclasses import dataclass\n",
     "import functools\n",
+    "from functools import partial\n",
     "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
     "\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
+    "\n",
+    "\n",
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax\n",
     "\n",
     "import inspect\n",
     "import inspect as py_inspect\n",
@@ -94,193 +76,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "(sec:hmm-intro)=\n",
-    "# Hidden Markov Models\n",
-    "\n",
-    "In this section, we discuss the\n",
-    "hidden Markov model or HMM,\n",
-    "which is a state space model in which the hidden states\n",
-    "are discrete, so $\\hmmhid_t \\in \\{1,\\ldots, K\\}$.\n",
-    "The observations may be discrete,\n",
-    "$\\hmmobs_t \\in \\{1,\\ldots, C\\}$,\n",
-    "or continuous,\n",
-    "$\\hmmobs_t \\in \\real^D$,\n",
-    "or some combination,\n",
-    "as we illustrate below.\n",
-    "More details can be found in e.g., \n",
-    "{cite}`Rabiner89,Fraser08,Cappe05`.\n",
-    "For an interactive introduction,\n",
-    "see https://nipunbatra.github.io/hmm/."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "(sec:casino)=\n",
     "### Example: Casino HMM\n",
     "\n",
@@ -291,9 +86,9 @@
     "\n",
     "The transition model is denoted by\n",
     "```{math}\n",
-    "p(z_t=j|z_{t-1}=i) = \\hmmTrans_{ij}\n",
+    "p(\\hidden_t=j|\\hidden_{t-1}=i) = \\hmmTransScalar_{ij}\n",
     "```\n",
-    "Here the $i$'th row of $\\vA$ corresponds to the outgoing distribution from state $i$.\n",
+    "Here the $i$'th row of $\\hmmTrans$ corresponds to the outgoing distribution from state $i$.\n",
     "This is  a row stochastic matrix,\n",
     "meaning each row sums to one.\n",
     "We can visualize\n",
@@ -311,18 +106,18 @@
     "The  observation model\n",
     "$p(\\obs_t|\\hidden_t=j)$ has the form\n",
     "```{math}\n",
-    "p(\\obs_t=k|\\hidden_t=j) = \\hmmObs_{jk} \n",
+    "p(\\obs_t=k|\\hidden_t=j) = \\hmmObsScalar_{jk} \n",
     "```\n",
     "This is represented by the histograms associated with each\n",
-    "state in  {numref}`casino-fig`.\n",
+    "state in  {numref}`fig:casino`.\n",
     "\n",
     "Finally,\n",
     "the initial state distribution is denoted by\n",
     "```{math}\n",
-    "p(z_1=j) = \\hmmInit_j\n",
+    "p(\\hidden_1=j) = \\hmmInitScalar_j\n",
     "```\n",
     "\n",
-    "Collectively we denote all the parameters by $\\vtheta=(\\hmmTrans, \\hmmObs, \\hmmInit)$.\n",
+    "Collectively we denote all the parameters by $\\params=(\\hmmTrans, \\hmmObs, \\hmmInit)$.\n",
     "\n",
     "Now let us implement this model in code."
    ]
@@ -387,7 +182,7 @@
    "metadata": {},
    "source": [
     "\n",
-    "Let's sample from the model. We will generate a sequence of latent states, $\\hid_{1:T}$,\n",
+    "Let's sample from the model. We will generate a sequence of latent states, $\\hidden_{1:T}$,\n",
     "which we then convert to a sequence of observations, $\\obs_{1:T}$."
    ]
   },
@@ -412,7 +207,7 @@
     "\n",
     "seed = 314\n",
     "n_samples = 300\n",
-    "z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)\n",
+    "z_hist, x_hist = hmm.sample(seed=jr.PRNGKey(seed), seq_len=n_samples)\n",
     "\n",
     "z_hist_str = \"\".join((np.array(z_hist) + 1).astype(str))[:60]\n",
     "x_hist_str = \"\".join((np.array(x_hist) + 1).astype(str))[:60]\n",

+ 30 - 244
_build/html/_sources/chapters/ssm/inference.ipynb

@@ -2,52 +2,6 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
    "outputs": [],
@@ -80,178 +34,8 @@
     "from functools import partial\n",
     "from jax.random import PRNGKey, split\n",
     "\n",
-    "import inspect\n",
-    "import inspect as py_inspect\n",
-    "import rich\n",
-    "from rich import inspect as r_inspect\n",
-    "from rich import print as r_print\n",
-    "\n",
-    "def print_source(fname):\n",
-    "    r_print(py_inspect.getsource(fname))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
+    "import jsl\n",
+    "import ssm_jax\n"
    ]
   },
   {
@@ -259,19 +43,9 @@
    "metadata": {},
    "source": [
     "(sec:inference)=\n",
-    "# Inferential goals\n",
+    "# States estimation (inference)\n",
     "\n",
-    "```{figure} /figures/inference-problems-tikz.png\n",
-    ":scale: 30%\n",
-    ":name: fig:dbn-inference\n",
     "\n",
-    "Illustration of the different kinds of inference in an SSM.\n",
-    " The main kinds of inference for state-space models.\n",
-    "    The shaded region is the interval for which we have data.\n",
-    "    The arrow represents the time step at which we want to perform inference.\n",
-    "    $t$ is the current time,  $T$ is the sequence length,\n",
-    "$\\ell$ is the lag and $h$ is the prediction horizon.\n",
-    "```\n",
     "\n",
     "\n",
     "\n",
@@ -284,41 +58,53 @@
     "there are multiple forms of posterior we may be interested in computing,\n",
     "including the following:\n",
     "- the filtering distribution\n",
-    "$p(\\hmmhid_t|\\hmmobs_{1:t})$\n",
+    "$p(\\hidden_t|\\obs_{1:t})$\n",
     "- the smoothing distribution\n",
-    "$p(\\hmmhid_t|\\hmmobs_{1:T})$ (note that this conditions on future data $T>t$)\n",
+    "$p(\\hidden_t|\\obs_{1:T})$ (note that this conditions on future data $T>t$)\n",
     "- the fixed-lag smoothing distribution\n",
-    "$p(\\hmmhid_{t-\\ell}|\\hmmobs_{1:t})$ (note that this\n",
+    "$p(\\hidden_{t-\\ell}|\\obs_{1:t})$ (note that this\n",
     "infers $\\ell$ steps in the past given data up to the present).\n",
     "\n",
     "We may also want to compute the\n",
     "predictive distribution $h$ steps into the future:\n",
-    "```{math}\n",
-    "p(\\hmmobs_{t+h}|\\hmmobs_{1:t})\n",
-    "= \\sum_{\\hmmhid_{t+h}} p(\\hmmobs_{t+h}|\\hmmhid_{t+h}) p(\\hmmhid_{t+h}|\\hmmobs_{1:t})\n",
-    "```\n",
+    "\\begin{align}\n",
+    "p(\\obs_{t+h}|\\obs_{1:t})\n",
+    "= \\sum_{\\hidden_{t+h}} p(\\obs_{t+h}|\\hidden_{t+h}) p(\\hidden_{t+h}|\\obs_{1:t})\n",
+    "\\end{align}\n",
     "where the hidden state predictive distribution is\n",
     "\\begin{align}\n",
-    "p(\\hmmhid_{t+h}|\\hmmobs_{1:t})\n",
-    "&= \\sum_{\\hmmhid_{t:t+h-1}}\n",
-    " p(\\hmmhid_t|\\hmmobs_{1:t}) \n",
-    " p(\\hmmhid_{t+1}|\\hmmhid_{t})\n",
-    " p(\\hmmhid_{t+2}|\\hmmhid_{t+1})\n",
+    "p(\\hidden_{t+h}|\\obs_{1:t})\n",
+    "&= \\sum_{\\hidden_{t:t+h-1}}\n",
+    " p(\\hidden_t|\\obs_{1:t}) \n",
+    " p(\\hidden_{t+1}|\\hidden_{t})\n",
+    " p(\\hidden_{t+2}|\\hidden_{t+1})\n",
     "\\cdots\n",
-    " p(\\hmmhid_{t+h}|\\hmmhid_{t+h-1})\n",
+    " p(\\hidden_{t+h}|\\hidden_{t+h-1})\n",
     "\\end{align}\n",
     "See \n",
     "{numref}`fig:dbn-inference` for a summary of these distributions.\n",
     "\n",
+    "```{figure} /figures/inference-problems-tikz.png\n",
+    ":scale: 30%\n",
+    ":name: fig:dbn-inference\n",
+    "\n",
+    "Illustration of the different kinds of inference in an SSM.\n",
+    " The main kinds of inference for state-space models.\n",
+    "    The shaded region is the interval for which we have data.\n",
+    "    The arrow represents the time step at which we want to perform inference.\n",
+    "    $t$ is the current time,  $T$ is the sequence length,\n",
+    "$\\ell$ is the lag and $h$ is the prediction horizon.\n",
+    "```\n",
+    "\n",
     "In addition  to comuting posterior marginals,\n",
     "we may want to compute the most probable hidden sequence,\n",
     "i.e., the joint MAP estimate\n",
     "```{math}\n",
-    "\\arg \\max_{\\hmmhid_{1:T}} p(\\hmmhid_{1:T}|\\hmmobs_{1:T})\n",
+    "\\arg \\max_{\\hidden_{1:T}} p(\\hidden_{1:T}|\\obs_{1:T})\n",
     "```\n",
     "or sample sequences from the posterior\n",
     "```{math}\n",
-    "\\hmmhid_{1:T} \\sim p(\\hmmhid_{1:T}|\\hmmobs_{1:T})\n",
+    "\\hidden_{1:T} \\sim p(\\hidden_{1:T}|\\obs_{1:T})\n",
     "```\n",
     "\n",
     "Algorithms for all these task are discussed in the following chapters,\n",

+ 65 - 311
_build/html/_sources/chapters/ssm/lds.ipynb

@@ -2,125 +2,35 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Requirement already satisfied: distrax in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (0.0.1)\n",
-      "Collecting distrax\n",
-      "  Downloading distrax-0.1.2-py3-none-any.whl (272 kB)\n",
-      "\u001b[K     |████████████████████████████████| 272 kB 6.9 MB/s eta 0:00:01\n",
-      "\u001b[?25hRequirement already satisfied: jax>=0.1.55 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.2.11)\n",
-      "Requirement already satisfied: absl-py>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.12.0)\n",
-      "Requirement already satisfied: chex>=0.0.7 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.0.8)\n",
-      "Requirement already satisfied: jaxlib>=0.1.67 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.1.70)\n",
-      "Requirement already satisfied: numpy>=1.18.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (1.19.5)\n",
-      "Collecting tensorflow-probability>=0.15.0\n",
-      "  Using cached tensorflow_probability-0.16.0-py2.py3-none-any.whl (6.3 MB)\n",
-      "Requirement already satisfied: six in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from absl-py>=0.9.0->distrax) (1.15.0)\n",
-      "Requirement already satisfied: dm-tree>=0.1.5 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.1.6)\n",
-      "Requirement already satisfied: toolz>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.11.1)\n",
-      "Requirement already satisfied: opt-einsum in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jax>=0.1.55->distrax) (3.3.0)\n",
-      "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.12)\n",
-      "Requirement already satisfied: scipy in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.6.3)\n",
-      "Requirement already satisfied: cloudpickle>=1.3 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (1.6.0)\n",
-      "Requirement already satisfied: decorator in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (4.4.2)\n",
-      "Requirement already satisfied: gast>=0.3.2 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (0.4.0)\n",
-      "Installing collected packages: tensorflow-probability, distrax\n",
-      "  Attempting uninstall: tensorflow-probability\n",
-      "    Found existing installation: tensorflow-probability 0.13.0\n",
-      "    Uninstalling tensorflow-probability-0.13.0:\n",
-      "      Successfully uninstalled tensorflow-probability-0.13.0\n",
-      "  Attempting uninstall: distrax\n",
-      "    Found existing installation: distrax 0.0.1\n",
-      "    Uninstalling distrax-0.0.1:\n",
-      "      Successfully uninstalled distrax-0.0.1\n",
-      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
-      "jsl 0.0.0 requires dataclasses, which is not installed.\u001b[0m\n",
-      "Successfully installed distrax-0.1.2 tensorflow-probability-0.16.0\n",
-      "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 22.0.4 is available.\n",
-      "You should consider upgrading via the '/opt/anaconda3/envs/spyder-dev/bin/python -m pip install --upgrade pip' command.\u001b[0m\n",
-      "Note: you may need to restart the kernel to use updated packages.\n"
-     ]
-    }
-   ],
-   "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
     "### Import standard libraries\n",
     "\n",
     "import abc\n",
     "from dataclasses import dataclass\n",
     "import functools\n",
+    "from functools import partial\n",
     "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
     "\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
+    "\n",
+    "\n",
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax\n",
     "\n",
     "import inspect\n",
     "import inspect as py_inspect\n",
@@ -136,170 +46,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "(sec:lds-intro)=\n",
     "# Linear Gaussian SSMs\n",
     "\n",
@@ -310,50 +56,53 @@
     "hidden states and inputs (i.e. there are no auto-regressive dependencies\n",
     "between the observables).\n",
     "We can rewrite this model as \n",
-    "a stochastic nonlinear dynamical system (NLDS)\n",
+    "a stochastic $\\keyword{nonlinear dynamical system}$ or $\\keyword{NLDS}$\n",
     "by defining the distribution of the next hidden state \n",
+    "$\\hidden_t \\in \\real^{\\nhidden}$\n",
     "as a deterministic function of the past state\n",
-    "plus random process noise $\\vepsilon_t$ \n",
+    "$\\hidden_{t-1}$,\n",
+    "the input $\\inputs_t \\in \\real^{\\ninputs}$,\n",
+    "and some random $\\keyword{process noise}$ $\\transNoise_t \\in \\real^{\\nhidden}$ \n",
     "\\begin{align}\n",
-    "\\hmmhid_t &= \\ssmDynFn(\\hmmhid_{t-1}, \\inputs_t, \\vepsilon_t)  \n",
+    "\\hidden_t &= \\dynamicsFn(\\hidden_{t-1}, \\inputs_t, \\transNoise_t)  \n",
     "\\end{align}\n",
-    "where $\\vepsilon_t$ is drawn from the distribution such\n",
+    "where $\\transNoise_t$ is drawn from the distribution such\n",
     "that the induced distribution\n",
-    "on $\\hmmhid_t$ matches $p(\\hmmhid_t|\\hmmhid_{t-1}, \\inputs_t)$.\n",
-    "Similarly we can rewrite the observation distributions\n",
+    "on $\\hidden_t$ matches $p(\\hidden_t|\\hidden_{t-1}, \\inputs_t)$.\n",
+    "Similarly we can rewrite the observation distribution\n",
     "as a deterministic function of the hidden state\n",
-    "plus observation noise $\\veta_t$:\n",
+    "plus $\\keyword{observation noise}$ $\\obsNoise_t \\in \\real^{\\nobs}$:\n",
     "\\begin{align}\n",
-    "\\hmmobs_t &= \\ssmObsFn(\\hmmhid_{t}, \\inputs_t, \\veta_t)\n",
+    "\\obs_t &= \\measurementFn(\\hidden_{t}, \\inputs_t, \\obsNoise_t)\n",
     "\\end{align}\n",
     "\n",
     "\n",
     "If we assume additive Gaussian noise,\n",
     "the model becomes\n",
     "\\begin{align}\n",
-    "\\hmmhid_t &= \\ssmDynFn(\\hmmhid_{t-1}, \\inputs_t) +  \\vepsilon_t  \\\\\n",
-    "\\hmmobs_t &= \\ssmObsFn(\\hmmhid_{t}, \\inputs_t) + \\veta_t\n",
+    "\\hidden_t &= \\dynamicsFn(\\hidden_{t-1}, \\inputs_t) +  \\transNoise_t  \\\\\n",
+    "\\obs_t &= \\measurementFn(\\hidden_{t}, \\inputs_t) + \\obsNoise_t\n",
     "\\end{align}\n",
-    "where $\\vepsilon_t \\sim \\gauss(\\vzero,\\vQ_t)$\n",
-    "and $\\veta_t \\sim \\gauss(\\vzero,\\vR_t)$.\n",
-    "We will call these Gaussian SSMs.\n",
+    "where $\\transNoise_t \\sim \\gauss(\\vzero,\\transCov_t)$\n",
+    "and $\\obsNoise_t \\sim \\gauss(\\vzero,\\obsCov_t)$.\n",
+    "We will call these $\\keyword{Gaussian SSMs}$.\n",
     "\n",
     "If we additionally assume\n",
-    "the transition function $\\ssmDynFn$\n",
-    "and the observation function $\\ssmObsFn$ are both linear,\n",
+    "the transition function $\\dynamicsFn$\n",
+    "and the observation function $\\measurementFn$ are both linear,\n",
     "then we can rewrite the model as follows:\n",
     "\\begin{align}\n",
-    "p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) &= \\gauss(\\hmmhid_t|\\ldsDyn_t \\hmmhid_{t-1}\n",
-    "+ \\ldsDynIn_t \\inputs_t, \\vQ_t)\n",
+    "p(\\hidden_t|\\hidden_{t-1},\\inputs_t) &= \\gauss(\\hidden_t|\\ldsDyn \\hidden_{t-1}\n",
+    "+ \\ldsDynIn \\inputs_t, \\transCov)\n",
     "\\\\\n",
-    "p(\\hmmobs_t|\\hmmhid_t,\\inputs_t) &= \\gauss(\\hmmobs_t|\\ldsObs_t \\hmmhid_{t}\n",
-    "+ \\ldsObsIn_t \\inputs_t, \\vR_t)\n",
+    "p(\\obs_t|\\hidden_t,\\inputs_t) &= \\gauss(\\obs_t|\\ldsObs \\hidden_{t}\n",
+    "+ \\ldsObsIn \\inputs_t, \\obsCov)\n",
     "\\end{align}\n",
     "This is called a \n",
-    "linear-Gaussian state space model\n",
-    "(LG-SSM),\n",
-    "or a\n",
-    "linear dynamical system (LDS).\n",
+    "$\\keyword{linear-Gaussian state space model}$\n",
+    "or $\\keyword{LG-SSM}$;\n",
+    "it is also called \n",
+    "a $\\keyword{linear dynamical system}$ or $\\keyword{LDS}$.\n",
     "We usually assume the parameters are independent of time, in which case\n",
     "the model is said to be time-invariant or homogeneous.\n"
    ]
@@ -372,13 +121,13 @@
     "Consider an object moving in $\\real^2$.\n",
     "Let the state be\n",
     "the position and velocity of the object,\n",
-    "$\\vz_t =\\begin{pmatrix} u_t & \\dot{u}_t & v_t & \\dot{v}_t \\end{pmatrix}$.\n",
+    "$\\hidden_t =\\begin{pmatrix} u_t & \\dot{u}_t & v_t & \\dot{v}_t \\end{pmatrix}$.\n",
     "(We use $u$ and $v$ for the two coordinates,\n",
     "to avoid confusion with the state and observation variables.)\n",
     "If we use Euler discretization,\n",
     "the dynamics become\n",
     "\\begin{align}\n",
-    "\\underbrace{\\begin{pmatrix} u_t\\\\ \\dot{u}_t \\\\ v_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\vz_t}\n",
+    "\\underbrace{\\begin{pmatrix} u_t\\\\ \\dot{u}_t \\\\ v_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\hidden_t}\n",
     "  = \n",
     "\\underbrace{\n",
     "\\begin{pmatrix}\n",
@@ -388,20 +137,18 @@
     "0 & 0 & 0 & 1\n",
     "\\end{pmatrix}\n",
     "}_{\\ldsDyn}\n",
-    "\\\n",
-    "\\underbrace{\\begin{pmatrix} u_{t-1} \\\\ \\dot{u}_{t-1} \\\\ v_{t-1} \\\\ \\dot{v}_{t-1} \\end{pmatrix}}_{\\vz_{t-1}}\n",
-    "+ \\vepsilon_t\n",
+    "\n",
+    "\\underbrace{\\begin{pmatrix} u_{t-1} \\\\ \\dot{u}_{t-1} \\\\ v_{t-1} \\\\ \\dot{v}_{t-1} \\end{pmatrix}}_{\\hidden_{t-1}}\n",
+    "+ \\transNoise_t\n",
     "\\end{align}\n",
-    "where $\\vepsilon_t \\sim \\gauss(\\vzero,\\vQ)$ is\n",
+    "where $\\transNoise_t \\sim \\gauss(\\vzero,\\transCov)$ is\n",
     "the process noise.\n",
-    "\n",
-    "Let us assume\n",
+    "We assume\n",
     "that the process noise is \n",
     "a white noise process added to the velocity components\n",
-    "of the state, but not to the location.\n",
-    "(This is known as a random accelerations model.)\n",
-    "We can approximate the resulting process in discrete time by assuming\n",
-    "$\\vQ = \\diag(0, q, 0, q)$.\n",
+    "of the state, but not to the location,\n",
+    "so $\\transCov = \\diag(0, q, 0, q)$.\n",
+    "This is known as a random accelerations model.\n",
     "(See  {cite}`Sarkka13` p60 for a more accurate way\n",
     "to convert the continuous time process to discrete time.)\n",
     "\n",
@@ -411,7 +158,7 @@
     "corrupted by  Gaussian noise.\n",
     "Thus the observation model becomes\n",
     "\\begin{align}\n",
-    "\\underbrace{\\begin{pmatrix}  y_{1,t} \\\\  y_{2,t} \\end{pmatrix}}_{\\vy_t}\n",
+    "\\underbrace{\\begin{pmatrix}  \\obs_{1,t} \\\\  \\obs_{2,t} \\end{pmatrix}}_{\\obs_t}\n",
     "  &=\n",
     "    \\underbrace{\n",
     "    \\begin{pmatrix}\n",
@@ -419,19 +166,19 @@
     "0 & 0 & 1 & 0\n",
     "    \\end{pmatrix}\n",
     "    }_{\\ldsObs}\n",
-    "    \\\n",
-    "\\underbrace{\\begin{pmatrix} u_t\\\\ \\dot{u}_t \\\\ v_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\vz_t}    \n",
-    " + \\veta_t\n",
+    "    \n",
+    "\\underbrace{\\begin{pmatrix} u_t\\\\ \\dot{u}_t \\\\ v_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\hidden_t}    \n",
+    " + \\obsNoise_t\n",
     "\\end{align}\n",
-    "where $\\veta_t \\sim \\gauss(\\vzero,\\vR)$ is the \\keywordDef{observation noise}.\n",
+    "where $\\obsNoise_t \\sim \\gauss(\\vzero,\\obsCov)$ is the observation noise.\n",
     "We see that the observation matrix $\\ldsObs$ simply ``extracts'' the\n",
     "relevant parts  of the state vector.\n",
     "\n",
     "Suppose we sample a trajectory and corresponding set\n",
     "of noisy observations from this model,\n",
-    "$(\\vz_{1:T}, \\vy_{1:T}) \\sim p(\\vz,\\vy|\\vtheta)$.\n",
+    "$(\\hidden_{1:T}, \\obs_{1:T}) \\sim p(\\hidden,\\obs|\\params)$.\n",
     "(We use diagonal observation noise,\n",
-    "$\\vR = \\diag(\\sigma_1^2, \\sigma_2^2)$.)\n",
+    "$\\obsCov = \\diag(\\sigma_1^2, \\sigma_2^2)$.)\n",
     "The results are shown below. \n"
    ]
   },
@@ -561,7 +308,14 @@
    "metadata": {},
    "source": [
     "The main task is to infer the hidden states given the noisy\n",
-    "observations, i.e., $p(\\vz|\\vy,\\vtheta)$. We discuss the topic of inference in {ref}`sec:inference`."
+    "observations, i.e., $p(\\hidden_t|\\obs_{1:t},\\params)$\n",
+    "or $p(\\hidden_t|\\obs_{1:T}, \\params)$ in the offline case.\n",
+    "We discuss the topic of inference in {ref}`sec:inference`.\n",
+    "We will usually represent this belief state by a Gaussian distribution,\n",
+    "$p(\\hidden_t|\\obs_{1:s},\\params) = \\gauss(\\hidden_t| \\mean_{t|s}, \\covMat_{t|s})$,\n",
+    "where usually $s=t$ or $s=T$.\n",
+    "Sometimes we use information form, \n",
+    "$p(\\hidden_t|\\obs_{1:s},\\params) = \\gaussInfo(\\hidden_t|\\precMean_{t|s}, \\precMat_{t|s})$."
    ]
   }
  ],

+ 115 - 0
_build/html/_sources/chapters/ssm/learning.ipynb

@@ -0,0 +1,115 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Import standard libraries\n",
+    "\n",
+    "import abc\n",
+    "from dataclasses import dataclass\n",
+    "import functools\n",
+    "from functools import partial\n",
+    "import itertools\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
+    "\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "from jax import lax, vmap, jit, grad\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
+    "\n",
+    "\n",
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "(sec:learning)=\n",
+    "# Parameter estimation (learning)\n",
+    "\n",
+    "\n",
+    "So far, we have assumed that the parameters $\\params$ of the SSM are known.\n",
+    "For example, in the case of an HMM with categorical observations\n",
+    "we have $\\params = (\\hmmInit, \\hmmTrans, \\hmmObs)$,\n",
+    "and in the case of an LDS, we have $\\params = \n",
+    "(\\ldsTrans, \\ldsObs, \\ldsTransIn, \\ldsObsIn, \\transCov, \\obsCov, \\initMean, \\initCov)$.\n",
+    "If we adopt a Bayesian perspective, we can view these parameters as random variables that are\n",
+    "shared across all time steps, and across all sequences.\n",
+    "This is shown in {numref}`fig:hmm-plates`, where we adopt $\\keyword{plate notation}$\n",
+    "to represent repetitive structure.\n",
+    "\n",
+    "```{figure} /figures/hmmDgmPlatesY.png\n",
+    ":scale: 100%\n",
+    ":name: fig:hmm-plates\n",
+    "\n",
+    "Illustration of an HMM using plate notation, where we show the parameter\n",
+    "nodes which are shared across all the sequences.\n",
+    "```\n",
+    "\n",
+    "Suppose we observe $N$ sequences $\\data = \\{\\obs_{n,1:T_n}: n=1:N\\}$.\n",
+    "Then the goal of $\\keyword{parameter estimation}$, also called $\\keyword{model learning}$\n",
+    "or $\\keyword{model fitting}$, is to approximate the posterior\n",
+    "\\begin{align}\n",
+    "p(\\params|\\data) \\propto p(\\params) \\prod_{n=1}^N p(\\obs_{n,1:T_n} | \\params)\n",
+    "\\end{align}\n",
+    "where $p(\\obs_{n,1:T_n} | \\params)$ is the marginal likelihood of sequence $n$:\n",
+    "\\begin{align}\n",
+    "p(\\obs_{1:T} | \\params) = \\int  p(\\hidden_{1:T}, \\obs_{1:T} | \\params) d\\hidden_{1:T}\n",
+    "\\end{align}\n",
+    "\n",
+    "Since computing the full posterior is computationally difficult, we often settle for computing\n",
+    "a point estimate such as the MAP (maximum a posteriori) estimate\n",
+    "\\begin{align}\n",
+    "\\params_{\\map} = \\arg \\max_{\\params} \\log p(\\params) + \\sum_{n=1}^N \\log p(\\obs_{n,1:T_n} | \\params)\n",
+    "\\end{align}\n",
+    "If we ignore the prior term, we get the maximum likelihood estimate or MLE:\n",
+    "\\begin{align}\n",
+    "\\params_{\\mle} = \\arg \\max_{\\params}  \\sum_{n=1}^N \\log p(\\obs_{n,1:T_n} | \\params)\n",
+    "\\end{align}\n",
+    "In practice, the MAP estimate often works better than the MLE, since the prior can regularize\n",
+    "the estimate to ensure the model is numerically stable and does not overfit the training set.\n",
+    "\n",
+    "We will discuss a variety of algorithms for parameter estimation in later chapters.\n",
+    "\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "6407c60499271029b671b4ff687c4ed4626355c45fd34c44476827f4be42c4d7"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.9.2 ('spyder-dev')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

+ 18 - 189
_build/html/_sources/chapters/ssm/nlds.ipynb

@@ -136,170 +136,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "(sec:nlds-intro)=\n",
     "# Nonlinear Gaussian SSMs\n",
     "\n",
@@ -307,11 +143,11 @@
     "but the process noise and observation noise are Gaussian.\n",
     "That is, \n",
     "\\begin{align}\n",
-    "\\hmmhid_t &= \\ssmDynFn(\\hmmhid_{t-1}, \\inputs_t) +  \\vepsilon_t  \\\\\n",
-    "\\hmmobs_t &= \\ssmObsFn(\\hmmhid_{t}, \\inputs_t) + \\veta_t\n",
+    "\\hidden_t &= \\dynamicsFn(\\hidden_{t-1}, \\inputs_t) +  \\transNoise_t  \\\\\n",
+    "\\obs_t &= \\obsFn(\\hidden_{t}, \\inputs_t) + \\obsNoise_t\n",
     "\\end{align}\n",
-    "where $\\vepsilon_t \\sim \\gauss(\\vzero,\\vQ_t)$\n",
-    "and $\\veta_t \\sim \\gauss(\\vzero,\\vR_t)$.\n",
+    "where $\\transNoise_t \\sim \\gauss(\\vzero,\\transCov)$\n",
+    "and $\\obsNoise_t \\sim \\gauss(\\vzero,\\obsCov)$.\n",
     "This is a very widely used model class. We give some examples below."
    ]
   },
@@ -337,7 +173,7 @@
     "% Sarka p45, p74\n",
     "Consider a simple pendulum of unit mass and length swinging from\n",
     "a fixed attachment, as in\n",
-    "{numref}`Figure %s <fig:pendulum>`.\n",
+    "{numref}`fig:pendulum`.\n",
     "Such an object is in principle entirely deterministic in its behavior.\n",
     "However, in the real world, there are often unknown forces at work\n",
     "(e.g., air turbulence, friction).\n",
@@ -348,11 +184,11 @@
     "= -g \\sin(\\alpha) + w(t)\n",
     "\\end{align}\n",
     "We can write this as a nonlinear SSM by defining the state to be\n",
-    "$z_1(t) = \\alpha(t)$ and $z_2(t) = d\\alpha(t)/dt$.\n",
+    "$\\hiddenScalar_1(t) = \\alpha(t)$ and $\\hiddenScalar_2(t) = d\\alpha(t)/dt$.\n",
     "Thus\n",
     "\\begin{align}\n",
-    "\\frac{d \\vz}{dt}\n",
-    "= \\begin{pmatrix} z_2 \\\\ -g \\sin(z_1) \\end{pmatrix}\n",
+    "\\frac{d \\hidden}{dt}\n",
+    "= \\begin{pmatrix} \\hiddenScalar_2 \\\\ -g \\sin(\\hiddenScalar_1) \\end{pmatrix}\n",
     "+ \\begin{pmatrix} 0 \\\\ 1 \\end{pmatrix} w(t)\n",
     "\\end{align}\n",
     "If we discretize this step size $\\Delta$,\n",
@@ -360,18 +196,18 @@
     "formulation {cite}`Sarkka13` p74:\n",
     "\\begin{align}\n",
     "\\underbrace{\n",
-    "  \\begin{pmatrix} z_{1,t} \\\\ z_{2,t} \\end{pmatrix}\n",
-    "  }_{\\hmmhid_t}\n",
+    "  \\begin{pmatrix} \\hiddenScalar_{1,t} \\\\ \\hiddenScalar_{2,t} \\end{pmatrix}\n",
+    "  }_{\\hidden_t}\n",
     "=\n",
     "\\underbrace{\n",
-    "  \\begin{pmatrix} z_{1,t-1} + z_{2,t-1} \\Delta  \\\\\n",
-    "    z_{2,t-1} -g \\sin(z_{1,t-1}) \\Delta  \\end{pmatrix}\n",
-    "  }_{\\vf(\\hmmhid_{t-1})}\n",
-    "+\\vq_{t-1}\n",
+    "  \\begin{pmatrix} \\hiddenScalar_{1,t-1} + \\hiddenScalar_{2,t-1} \\Delta  \\\\\n",
+    "    \\hiddenScalar_{2,t-1} -g \\sin(\\hiddenScalar_{1,t-1}) \\Delta  \\end{pmatrix}\n",
+    "  }_{\\dynamicsFn(\\hidden_{t-1})}\n",
+    "+\\transNoise_{t-1}\n",
     "\\end{align}\n",
-    "where $\\vq_{t-1} \\sim \\gauss(\\vzero,\\vQ)$ with\n",
+    "where $\\transNoise_{t-1} \\sim \\gauss(\\vzero,\\transCov)$ with\n",
     "\\begin{align}\n",
-    "\\vQ = q^c \\begin{pmatrix}\n",
+    "\\transCov = q^c \\begin{pmatrix}\n",
     "  \\frac{\\Delta^3}{3} &   \\frac{\\Delta^2}{2} \\\\\n",
     "  \\frac{\\Delta^2}{2} & \\Delta\n",
     "  \\end{pmatrix}\n",
@@ -382,17 +218,10 @@
     "\n",
     "If we observe the angular position, we\n",
     "get the linear observation model\n",
-    "\\begin{align}\n",
-    "y_t = \\alpha_t + r_t =  h(\\hmmhid_t) + r_t\n",
-    "\\end{align}\n",
-    "where $h(\\hmmhid_t) = z_{1,t}$\n",
-    "and $r_t$ is the observation noise.\n",
+    "$\\obsFn(\\hidden_t)  = \\alpha_t = \\hiddenScalar_{1,t}$.\n",
     "If we only observe  the horizontal position,\n",
     "we get the nonlinear observation model\n",
-    "\\begin{align}\n",
-    "y_t = \\sin(\\alpha_t) + r_t =  h(\\hmmhid_t) + r_t\n",
-    "\\end{align}\n",
-    "where $h(\\hmmhid_t) = \\sin(z_{1,t})$.\n",
+    "$\\obsFn(\\hidden_t) = \\sin(\\alpha_t) = \\sin(\\hiddenScalar_{1,t})$.\n",
     "\n",
     "\n",
     "\n",

+ 21 - 186
_build/html/_sources/chapters/ssm/ssm_intro.ipynb

@@ -4,175 +4,6 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\\newcommand{\\hmmhid}{\\hidden}\n",
-    "\\newcommand{\\hmmobs}{\\obs}\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": []
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "(sec:ssm-intro)=\n",
     "# What are State Space Models?\n",
     "\n",
@@ -197,7 +28,7 @@
     "unlike a standard Markov model.\n",
     "\n",
     "```{figure} /figures/SSM-AR-inputs.png\n",
-    ":height: 300px\n",
+    ":height: 150px\n",
     ":name: fig:ssm-ar\n",
     "\n",
     "Illustration of an SSM as a graphical model.\n",
@@ -208,14 +39,14 @@
     "as the following joint distribution:\n",
     "```{math}\n",
     ":label: eq:SSM-ar\n",
-    "p(\\hmmobs_{1:T},\\hmmhid_{1:T}|\\inputs_{1:T})\n",
-    " = \\left[ p(\\hmmhid_1|\\inputs_1) \\prod_{t=2}^{T}\n",
-    " p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) \\right]\n",
-    " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t, \\inputs_t, \\hmmobs_{t-1}) \\right]\n",
+    "p(\\obs_{1:T},\\hidden_{1:T}|\\inputs_{1:T})\n",
+    " = \\left[ p(\\hidden_1|\\inputs_1) \\prod_{t=2}^{T}\n",
+    " p(\\hidden_t|\\hidden_{t-1},\\inputs_t) \\right]\n",
+    " \\left[ \\prod_{t=1}^T p(\\obs_t|\\hidden_t, \\inputs_t, \\obs_{t-1}) \\right]\n",
     "```\n",
-    "where $p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t)$ is the\n",
+    "where $p(\\hidden_t|\\hidden_{t-1},\\inputs_t)$ is the\n",
     "transition model,\n",
-    "$p(\\hmmobs_t|\\hmmhid_t, \\inputs_t, \\hmmobs_{t-1})$ is the\n",
+    "$p(\\obs_t|\\hidden_t, \\inputs_t, \\obs_{t-1})$ is the\n",
     "observation model,\n",
     "and $\\inputs_{t}$ is an optional input or action.\n",
     "See {numref}`fig:ssm-ar` \n",
@@ -228,31 +59,35 @@
     "In this case the joint simplifies to \n",
     "```{math}\n",
     ":label: eq:SSM-input\n",
-    "p(\\hmmobs_{1:T},\\hmmhid_{1:T}|\\inputs_{1:T})\n",
-    " = \\left[ p(\\hmmhid_1|\\inputs_1) \\prod_{t=2}^{T}\n",
-    " p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) \\right]\n",
-    " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t, \\inputs_t) \\right]\n",
+    "p(\\obs_{1:T},\\hidden_{1:T}|\\inputs_{1:T})\n",
+    " = \\left[ p(\\hidden_1|\\inputs_1) \\prod_{t=2}^{T}\n",
+    " p(\\hidden_t|\\hidden_{t-1},\\inputs_t) \\right]\n",
+    " \\left[ \\prod_{t=1}^T p(\\obs_t|\\hidden_t, \\inputs_t) \\right]\n",
     "```\n",
     "Sometimes there are no external inputs, so the model further\n",
     "simplifies to the following unconditional generative model: \n",
     "```{math}\n",
     ":label: eq:SSM-no-input\n",
-    "p(\\hmmobs_{1:T},\\hmmhid_{1:T})\n",
-    " = \\left[ p(\\hmmhid_1) \\prod_{t=2}^{T}\n",
-    " p(\\hmmhid_t|\\hmmhid_{t-1}) \\right]\n",
-    " \\left[ \\prod_{t=1}^T p(\\hmmobs_t|\\hmmhid_t) \\right]\n",
+    "p(\\obs_{1:T},\\hidden_{1:T})\n",
+    " = \\left[ p(\\hidden_1) \\prod_{t=2}^{T}\n",
+    " p(\\hidden_t|\\hidden_{t-1}) \\right]\n",
+    " \\left[ \\prod_{t=1}^T p(\\obs_t|\\hidden_t) \\right]\n",
     "```\n",
     "See {numref}`ssm-simplified` \n",
     "for an illustration of the corresponding graphical model.\n",
     "\n",
     "\n",
     "```{figure} /figures/SSM-simplified.png\n",
-    ":scale: 100%\n",
+    ":height: 150px\n",
     ":name: ssm-simplified\n",
     "\n",
     "Illustration of a simplified SSM.\n",
     "```\n",
-    "\n"
+    "\n",
+    "SSMs are widely used in many areas of science, engineering, finance, economics, etc.\n",
+    "The main applications are state estimation (i.e., inferring the underlying hidden state of the system given the observations),\n",
+    "forecasting (i.e., predicting future states and observations), and control (i.e., inferring the sequence of inputs that will\n",
+    "give rise to a desired target state). We will discuss these applications in later chapters."
    ]
   },
   {

+ 11 - 38
_build/html/bib.html

@@ -93,11 +93,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="chapters/scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="chapters/ssm/ssm_index.html">
    State Space Models
@@ -130,7 +125,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="chapters/ssm/inference.html">
-     Inferential goals
+     State estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="chapters/ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -174,7 +174,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="chapters/lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -283,37 +283,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="chapters/learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="chapters/learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="chapters/learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="chapters/learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="chapters/learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="chapters/tracking/tracking_index.html">
@@ -484,6 +457,9 @@ const thebe_selector_output = ".output, .cell_output"
 <dt class="label" id="id13"><span class="brackets">CWSchon21</span></dt>
 <dd><p>Jarrad Courts, Adrian G Wills, and Thomas B Schön. Gaussian variational state estimation for nonlinear State-Space models. <em>IEEE Trans. Signal Process.</em>, 69:5979–5993, 2021. URL: <a class="reference external" href="https://arxiv.org/abs/2002.02620">https://arxiv.org/abs/2002.02620</a>.</p>
 </dd>
+<dt class="label" id="id35"><span class="brackets">Dev85</span></dt>
+<dd><p>Pierre A Devijver. Baum's forward-backward algorithm revisited. <em>Pattern Recognition Letters</em>, 3(6):369–373, 1985.</p>
+</dd>
 <dt class="label" id="id3"><span class="brackets">DEKM98</span></dt>
 <dd><p>R. Durbin, S. Eddy, A. Krogh, and G. Mitchison. <em>Biological Sequence Analysis: Probabilistic Models of Proteins and Nucleic Acids</em>. Cambridge University Press, 1998.</p>
 </dd>
@@ -511,9 +487,6 @@ const thebe_selector_output = ".output, .cell_output"
 <dt class="label" id="id11"><span class="brackets">GarciaFernandezTSarkka19</span></dt>
 <dd><p>Ángel F García-Fernández, Filip Tronarp, and Simo Särkkä. Gaussian process classification using posterior linearization. <em>IEEE Signal Process. Lett.</em>, 26(5):735–739, May 2019. URL: <a class="reference external" href="http://dx.doi.org/10.1109/LSP.2019.2906929">http://dx.doi.org/10.1109/LSP.2019.2906929</a>.</p>
 </dd>
-<dt class="label" id="id8"><span class="brackets">GH96</span></dt>
-<dd><p>Z. Ghahramani and G. Hinton. Parameter estimation for linear dynamical systems. Technical Report CRG-TR-96-2, Dept. Comp. Sci., Univ. Toronto, 1996. URL: <a class="reference external" href="http://www.gatsby.ucl.ac.uk/~zoubin/papers/tr-96-2.ps.gz">http://www.gatsby.ucl.ac.uk/~zoubin/papers/tr-96-2.ps.gz</a>.</p>
-</dd>
 <dt class="label" id="id7"><span class="brackets">HSarkkaGarciaFernandez21</span></dt>
 <dd><p>Sakira Hassan, Simo Särkkä, and Ángel F García-Fernández. Temporal parallelization of inference in hidden markov models. <em>IEEE Trans. Signal Processing</em>, 69:4875–4887, February 2021. URL: <a class="reference external" href="http://arxiv.org/abs/2102.05743">http://arxiv.org/abs/2102.05743</a>.</p>
 </dd>

+ 8 - 35
_build/html/chapters/adf/adf_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 2 - 29
_build/html/chapters/blank.html

@@ -173,7 +173,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -282,37 +282,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/bnp/bnp_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/changepoint/changepoint_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/control/control_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/ensemble/ensemble_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/extended/extended_filter.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/extended/extended_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/extended/extended_parallel.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/extended/extended_smoother.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/gp/gp_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 2 - 29
_build/html/chapters/hmm/hmm.html

@@ -173,7 +173,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -282,37 +282,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

文件差異過大導致無法顯示
+ 283 - 311
_build/html/chapters/hmm/hmm_filter.html


文件差異過大導致無法顯示
+ 11 - 11
_build/html/chapters/hmm/hmm_index.html


+ 8 - 35
_build/html/chapters/hmm/hmm_parallel.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 10 - 37
_build/html/chapters/hmm/hmm_sampling.html

@@ -53,7 +53,7 @@ const thebe_selector_output = ".output, .cell_output"
     <script async="async" src="../../_static/sphinx-thebe.js"></script>
     <link rel="index" title="Index" href="../../genindex.html" />
     <link rel="search" title="Search" href="../../search.html" />
-    <link rel="next" title="Inference in linear-Gaussian SSMs" href="../lgssm/lgssm_index.html" />
+    <link rel="next" title="Linear-Gaussian SSMs" href="../lgssm/lgssm_index.html" />
     <link rel="prev" title="Parallel HMM smoothing" href="hmm_parallel.html" />
     <meta name="viewport" content="width=device-width, initial-scale=1" />
     <meta name="docsearch:language" content="None">
@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">
@@ -522,7 +495,7 @@ const thebe_selector_output = ".output, .cell_output"
     <a class='right-next' id="next-link" href="../lgssm/lgssm_index.html" title="next page">
     <div class="prev-next-info">
         <p class="prev-next-subtitle">next</p>
-        <p class="prev-next-title">Inference in linear-Gaussian SSMs</p>
+        <p class="prev-next-title">Linear-Gaussian SSMs</p>
     </div>
     <i class="fas fa-angle-right"></i>
     </a>

+ 8 - 35
_build/html/chapters/hmm/hmm_smoother.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/hmm/hmm_viterbi.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 6 - 6
_build/html/chapters/learning/learning_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>

+ 10 - 37
_build/html/chapters/lgssm/kalman_filter.html

@@ -54,7 +54,7 @@ const thebe_selector_output = ".output, .cell_output"
     <link rel="index" title="Index" href="../../genindex.html" />
     <link rel="search" title="Search" href="../../search.html" />
     <link rel="next" title="Kalman (RTS) smoother" href="kalman_smoother.html" />
-    <link rel="prev" title="Inference in linear-Gaussian SSMs" href="lgssm_index.html" />
+    <link rel="prev" title="Linear-Gaussian SSMs" href="lgssm_index.html" />
     <meta name="viewport" content="width=device-width, initial-scale=1" />
     <meta name="docsearch:language" content="None">
     
@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 current active has-children">
   <a class="reference internal" href="lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input checked="" class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">
@@ -516,7 +489,7 @@ const thebe_selector_output = ".output, .cell_output"
         <i class="fas fa-angle-left"></i>
         <div class="prev-next-info">
             <p class="prev-next-subtitle">previous</p>
-            <p class="prev-next-title">Inference in linear-Gaussian SSMs</p>
+            <p class="prev-next-title">Linear-Gaussian SSMs</p>
         </div>
     </a>
     <a class='right-next' id="next-link" href="kalman_smoother.html" title="next page">

+ 8 - 35
_build/html/chapters/lgssm/kalman_parallel.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 current active has-children">
   <a class="reference internal" href="lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input checked="" class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/lgssm/kalman_sampling.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 current active has-children">
   <a class="reference internal" href="lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input checked="" class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/lgssm/kalman_smoother.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 current active has-children">
   <a class="reference internal" href="lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input checked="" class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 6 - 6
_build/html/chapters/lgssm/lgssm_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>

+ 8 - 35
_build/html/chapters/ode/ode_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/pf/pf_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/postlin/postlin_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/quadrature/quadrature_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

文件差異過大導致無法顯示
+ 18 - 61
_build/html/chapters/scratch.html


文件差異過大導致無法顯示
+ 856 - 0
_build/html/chapters/scratchpad.html


+ 8 - 35
_build/html/chapters/smc/smc_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

文件差異過大導致無法顯示
+ 63 - 255
_build/html/chapters/ssm/hmm.html


文件差異過大導致無法顯示
+ 50 - 230
_build/html/chapters/ssm/inference.html


文件差異過大導致無法顯示
+ 87 - 256
_build/html/chapters/ssm/lds.html


文件差異過大導致無法顯示
+ 596 - 0
_build/html/chapters/ssm/learning.html


文件差異過大導致無法顯示
+ 42 - 212
_build/html/chapters/ssm/nlds.html


+ 11 - 10
_build/html/chapters/ssm/ssm_index.html

@@ -54,7 +54,7 @@ const thebe_selector_output = ".output, .cell_output"
     <link rel="index" title="Index" href="../../genindex.html" />
     <link rel="search" title="Search" href="../../search.html" />
     <link rel="next" title="What are State Space Models?" href="ssm_intro.html" />
-    <link rel="prev" title="Scratchpad" href="../scratch.html" />
+    <link rel="prev" title="State Space Models: A Modern Approach" href="../../root.html" />
     <meta name="viewport" content="width=device-width, initial-scale=1" />
     <meta name="docsearch:language" content="None">
     
@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 current active has-children">
   <a class="current reference internal" href="#">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -443,7 +443,8 @@ const thebe_selector_output = ".output, .cell_output"
 <li class="toctree-l1"><a class="reference internal" href="hmm.html">Hidden Markov Models</a></li>
 <li class="toctree-l1"><a class="reference internal" href="lds.html">Linear Gaussian SSMs</a></li>
 <li class="toctree-l1"><a class="reference internal" href="nlds.html">Nonlinear Gaussian SSMs</a></li>
-<li class="toctree-l1"><a class="reference internal" href="inference.html">Inferential goals</a></li>
+<li class="toctree-l1"><a class="reference internal" href="inference.html">States estimation (inference)</a></li>
+<li class="toctree-l1"><a class="reference internal" href="learning.html">Parameter estimation (learning)</a></li>
 </ul>
 </div>
 </div>
@@ -473,11 +474,11 @@ const thebe_selector_output = ".output, .cell_output"
             
                 <!-- Previous / next buttons -->
 <div class='prev-next-area'> 
-    <a class='left-prev' id="prev-link" href="../scratch.html" title="previous page">
+    <a class='left-prev' id="prev-link" href="../../root.html" title="previous page">
         <i class="fas fa-angle-left"></i>
         <div class="prev-next-info">
             <p class="prev-next-subtitle">previous</p>
-            <p class="prev-next-title">Scratchpad</p>
+            <p class="prev-next-title">State Space Models: A Modern Approach</p>
         </div>
     </a>
     <a class='right-next' id="next-link" href="ssm_intro.html" title="next page">

文件差異過大導致無法顯示
+ 34 - 191
_build/html/chapters/ssm/ssm_intro.html


+ 8 - 35
_build/html/chapters/timeseries/timeseries_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 11 - 38
_build/html/chapters/tracking/tracking_index.html

@@ -54,7 +54,7 @@ const thebe_selector_output = ".output, .cell_output"
     <link rel="index" title="Index" href="../../genindex.html" />
     <link rel="search" title="Search" href="../../search.html" />
     <link rel="next" title="Data assimilation using Ensemble Kalman filter" href="../ensemble/ensemble_index.html" />
-    <link rel="prev" title="Markov Chain Monte Carlo (MCMC)" href="../learning/mcmc.html" />
+    <link rel="prev" title="Offline parameter estimation (learning)" href="../learning/learning_index.html" />
     <meta name="viewport" content="width=device-width, initial-scale=1" />
     <meta name="docsearch:language" content="None">
     
@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1 current active">
   <a class="current reference internal" href="#">
@@ -490,11 +463,11 @@ const thebe_selector_output = ".output, .cell_output"
             
                 <!-- Previous / next buttons -->
 <div class='prev-next-area'> 
-    <a class='left-prev' id="prev-link" href="../learning/mcmc.html" title="previous page">
+    <a class='left-prev' id="prev-link" href="../learning/learning_index.html" title="previous page">
         <i class="fas fa-angle-left"></i>
         <div class="prev-next-info">
             <p class="prev-next-subtitle">previous</p>
-            <p class="prev-next-title">Markov Chain Monte Carlo (MCMC)</p>
+            <p class="prev-next-title">Offline parameter estimation (learning)</p>
         </div>
     </a>
     <a class='right-next' id="next-link" href="../ensemble/ensemble_index.html" title="next page">

+ 8 - 35
_build/html/chapters/unscented/unscented_filter.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/unscented/unscented_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/unscented/unscented_smoother.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 8 - 35
_build/html/chapters/vi/vi_index.html

@@ -94,11 +94,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="current nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="../scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../ssm/ssm_index.html">
    State Space Models
@@ -131,7 +126,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="../ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="../ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -175,7 +175,7 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="../lgssm/lgssm_index.html">
-   Inference in linear-Gaussian SSMs
+   Linear-Gaussian SSMs
   </a>
   <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
   <label for="toctree-checkbox-3">
@@ -284,37 +284,10 @@ const thebe_selector_output = ".output, .cell_output"
    Sequential Monte Carlo
   </a>
  </li>
- <li class="toctree-l1 has-children">
+ <li class="toctree-l1">
   <a class="reference internal" href="../learning/learning_index.html">
    Offline parameter estimation (learning)
   </a>
-  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
-  <label for="toctree-checkbox-6">
-   <i class="fas fa-chevron-down">
-   </i>
-  </label>
-  <ul>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/em.html">
-     Expectation Maximization (EM)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/sgd.html">
-     Stochastic Gradient Descent (SGD)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/vb.html">
-     Variational Bayes (VB)
-    </a>
-   </li>
-   <li class="toctree-l2">
-    <a class="reference internal" href="../learning/mcmc.html">
-     Markov Chain Monte Carlo (MCMC)
-    </a>
-   </li>
-  </ul>
  </li>
  <li class="toctree-l1">
   <a class="reference internal" href="../tracking/tracking_index.html">

+ 6 - 6
_build/html/genindex.html

@@ -92,11 +92,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="chapters/scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="chapters/ssm/ssm_index.html">
    State Space Models
@@ -129,7 +124,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="chapters/ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="chapters/ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>

二進制
_build/html/objects.inv


+ 75 - 0
_build/html/reports/hmm.log

@@ -0,0 +1,75 @@
+Traceback (most recent call last):
+  File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution
+    executenb(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute
+    return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped
+    return just_run(coro(*args, **kwargs))
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run
+    return loop.run_until_complete(coro)
+  File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
+    return future.result()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute
+    await self.async_execute_cell(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell
+    self._check_raise_for_error(cell, exec_reply)
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error
+    raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
+nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
+------------------
+
+
+import collections
+def compute_counts(state_seq, nstates):
+    wseq = np.array(state_seq)
+    word_pairs = [pair for pair in zip(wseq[:-1], wseq[1:])]
+    counter_pairs = collections.Counter(word_pairs)
+    counts = np.zeros((nstates, nstates))
+    for (k,v) in counter_pairs.items():
+        counts[k[0], k[1]] = v
+    return counts
+
+
+def normalize(u, axis=0, eps=1e-15):
+    u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))
+    c = u.sum(axis=axis)
+    c = jnp.where(c == 0, 1, c)
+    return u / c, c
+
+def normalize_counts(counts):
+    ncounts = vmap(lambda v: normalize(v)[0], in_axes=0)(counts)
+    return ncounts
+
+init_dist = jnp.array([1.0, 0.0])
+trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]])
+obs_mat = jnp.eye(2)
+
+hmm = HMM(trans_dist=distrax.Categorical(probs=trans_mat),
+            init_dist=distrax.Categorical(probs=init_dist),
+            obs_dist=distrax.Categorical(probs=obs_mat))
+
+rng_key = jax.random.PRNGKey(0)
+seq_len = 500
+state_seq, _ = hmm.sample(seed=PRNGKey(seed), seq_len=seq_len)
+
+counts = compute_counts(state_seq, nstates=2)
+print(counts)
+
+trans_mat_empirical = normalize_counts(counts)
+print(trans_mat_empirical)
+
+assert jnp.allclose(trans_mat, trans_mat_empirical, atol=1e-1)
+------------------
+
+---------------------------------------------------------------------------
+NameError                                 Traceback (most recent call last)
+<ipython-input-6-f054683fcd82> in <module>
+     30 rng_key = jax.random.PRNGKey(0)
+     31 seq_len = 500
+---> 32 state_seq, _ = hmm.sample(seed=PRNGKey(seed), seq_len=seq_len)
+     33 
+     34 counts = compute_counts(state_seq, nstates=2)
+
+NameError: name 'PRNGKey' is not defined
+NameError: name 'PRNGKey' is not defined
+

+ 37 - 0
_build/html/reports/hmm_filter.log

@@ -0,0 +1,37 @@
+Traceback (most recent call last):
+  File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution
+    executenb(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute
+    return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped
+    return just_run(coro(*args, **kwargs))
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run
+    return loop.run_until_complete(coro)
+  File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
+    return future.result()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute
+    await self.async_execute_cell(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell
+    self._check_raise_for_error(cell, exec_reply)
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error
+    raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
+nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
+------------------
+import jsl.hmm.hmm_lib as hmm_lib
+print_source(hmm_lib.hmm_forwards_jax)
+#https://github.com/probml/JSL/blob/main/jsl/hmm/hmm_lib.py#L189
+
+
+------------------
+
+---------------------------------------------------------------------------
+NameError                                 Traceback (most recent call last)
+<ipython-input-1-7da36fbd64d0> in <module>
+      1 import jsl.hmm.hmm_lib as hmm_lib
+----> 2 print_source(hmm_lib.hmm_forwards_jax)
+      3 #https://github.com/probml/JSL/blob/main/jsl/hmm/hmm_lib.py#L189
+      4 
+
+NameError: name 'print_source' is not defined
+NameError: name 'print_source' is not defined
+

+ 104 - 0
_build/html/reports/inference.log

@@ -0,0 +1,104 @@
+Traceback (most recent call last):
+  File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution
+    executenb(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute
+    return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped
+    return just_run(coro(*args, **kwargs))
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run
+    return loop.run_until_complete(coro)
+  File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
+    return future.result()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute
+    await self.async_execute_cell(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell
+    self._check_raise_for_error(cell, exec_reply)
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error
+    raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
+nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
+------------------
+# state transition matrix
+A = np.array([
+    [0.95, 0.05],
+    [0.10, 0.90]
+])
+
+# observation matrix
+B = np.array([
+    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
+    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
+])
+
+pi = np.array([0.5, 0.5])
+
+(nstates, nobs) = np.shape(B)
+
+import distrax
+from distrax import HMM
+
+
+hmm = HMM(trans_dist=distrax.Categorical(probs=A),
+            init_dist=distrax.Categorical(probs=pi),
+            obs_dist=distrax.Categorical(probs=B))
+
+
+seed = 314
+n_samples = 300
+z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)
+
+------------------
+
+---------------------------------------------------------------------------
+AttributeError                            Traceback (most recent call last)
+<ipython-input-2-155367c5b347> in <module>
+     15 (nstates, nobs) = np.shape(B)
+     16 
+---> 17 import distrax
+     18 from distrax import HMM
+     19 
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in <module>
+     18 from distrax._src.bijectors.bijector import Bijector
+     19 from distrax._src.bijectors.bijector import BijectorLike
+---> 20 from distrax._src.bijectors.block import Block
+     21 from distrax._src.bijectors.chain import Chain
+     22 from distrax._src.bijectors.gumbel_cdf import GumbelCDF
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in <module>
+     18 
+     19 from distrax._src.bijectors import bijector as base
+---> 20 from distrax._src.utils import conversion
+     21 from distrax._src.utils import math
+     22 
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in <module>
+     17 from typing import Optional, Union
+     18 
+---> 19 import chex
+     20 from distrax._src.bijectors import bijector
+     21 from distrax._src.bijectors import bijector_from_tfp
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in <module>
+     15 """Chex: Testing made fun, in JAX!"""
+     16 
+---> 17 from chex._src.asserts import assert_axis_dimension
+     18 from chex._src.asserts import assert_axis_dimension_gt
+     19 from chex._src.asserts import assert_devices_available
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in <module>
+     26 
+     27 from chex._src import asserts_internal as ai
+---> 28 from chex._src import pytypes
+     29 import jax
+     30 import jax.numpy as jnp
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in <module>
+     37 Shape = Tuple[int, ...]
+     38 
+---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice
+     40 GpuDevice = jax.lib.xla_extension.GpuDevice
+     41 TpuDevice = jax.lib.xla_extension.TpuDevice
+
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+

+ 154 - 0
_build/html/reports/lds.log

@@ -0,0 +1,154 @@
+Traceback (most recent call last):
+  File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution
+    executenb(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute
+    return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped
+    return just_run(coro(*args, **kwargs))
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run
+    return loop.run_until_complete(coro)
+  File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
+    return future.result()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute
+    await self.async_execute_cell(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell
+    self._check_raise_for_error(cell, exec_reply)
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error
+    raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
+nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
+------------------
+# meta-data does not work yet in VScode
+# https://github.com/microsoft/vscode-jupyter/issues/1121
+
+{
+    "tags": [
+        "hide-cell"
+    ]
+}
+
+
+### Install necessary libraries
+
+try:
+    import jax
+except:
+    # For cuda version, see https://github.com/google/jax#installation
+    %pip install --upgrade "jax[cpu]" 
+    import jax
+
+try:
+    import distrax
+except:
+    %pip install --upgrade  distrax
+    import distrax
+
+try:
+    import jsl
+except:
+    %pip install git+https://github.com/probml/jsl
+    import jsl
+
+try:
+    import rich
+except:
+    %pip install rich
+    import rich
+
+
+
+------------------
+
+---------------------------------------------------------------------------
+AttributeError                            Traceback (most recent call last)
+<ipython-input-1-035f9f944d53> in <module>
+     20 try:
+---> 21     import distrax
+     22 except:
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in <module>
+     19 from distrax._src.bijectors.bijector import BijectorLike
+---> 20 from distrax._src.bijectors.block import Block
+     21 from distrax._src.bijectors.chain import Chain
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in <module>
+     19 from distrax._src.bijectors import bijector as base
+---> 20 from distrax._src.utils import conversion
+     21 from distrax._src.utils import math
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in <module>
+     18 
+---> 19 import chex
+     20 from distrax._src.bijectors import bijector
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in <module>
+     16 
+---> 17 from chex._src.asserts import assert_axis_dimension
+     18 from chex._src.asserts import assert_axis_dimension_gt
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in <module>
+     27 from chex._src import asserts_internal as ai
+---> 28 from chex._src import pytypes
+     29 import jax
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in <module>
+     38 
+---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice
+     40 GpuDevice = jax.lib.xla_extension.GpuDevice
+
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+
+During handling of the above exception, another exception occurred:
+
+AttributeError                            Traceback (most recent call last)
+<ipython-input-1-035f9f944d53> in <module>
+     22 except:
+     23     get_ipython().run_line_magic('pip', 'install --upgrade  distrax')
+---> 24     import distrax
+     25 
+     26 try:
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in <module>
+     18 from distrax._src.bijectors.bijector import Bijector
+     19 from distrax._src.bijectors.bijector import BijectorLike
+---> 20 from distrax._src.bijectors.block import Block
+     21 from distrax._src.bijectors.chain import Chain
+     22 from distrax._src.bijectors.gumbel_cdf import GumbelCDF
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in <module>
+     18 
+     19 from distrax._src.bijectors import bijector as base
+---> 20 from distrax._src.utils import conversion
+     21 from distrax._src.utils import math
+     22 
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in <module>
+     17 from typing import Optional, Union
+     18 
+---> 19 import chex
+     20 from distrax._src.bijectors import bijector
+     21 from distrax._src.bijectors import bijector_from_tfp
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in <module>
+     15 """Chex: Testing made fun, in JAX!"""
+     16 
+---> 17 from chex._src.asserts import assert_axis_dimension
+     18 from chex._src.asserts import assert_axis_dimension_gt
+     19 from chex._src.asserts import assert_devices_available
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in <module>
+     26 
+     27 from chex._src import asserts_internal as ai
+---> 28 from chex._src import pytypes
+     29 import jax
+     30 import jax.numpy as jnp
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in <module>
+     37 Shape = Tuple[int, ...]
+     38 
+---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice
+     40 GpuDevice = jax.lib.xla_extension.GpuDevice
+     41 TpuDevice = jax.lib.xla_extension.TpuDevice
+
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+

+ 102 - 0
_build/html/reports/learning.log

@@ -0,0 +1,102 @@
+Traceback (most recent call last):
+  File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution
+    executenb(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute
+    return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped
+    return just_run(coro(*args, **kwargs))
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run
+    return loop.run_until_complete(coro)
+  File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
+    return future.result()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute
+    await self.async_execute_cell(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell
+    self._check_raise_for_error(cell, exec_reply)
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error
+    raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
+nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
+------------------
+### Import standard libraries
+
+import abc
+from dataclasses import dataclass
+import functools
+from functools import partial
+import itertools
+import matplotlib.pyplot as plt
+import numpy as np
+from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
+
+import jax
+import jax.numpy as jnp
+from jax import lax, vmap, jit, grad
+#from jax.scipy.special import logit
+#from jax.nn import softmax
+import jax.random as jr
+
+
+
+import distrax
+import optax
+
+import jsl
+import ssm_jax
+
+
+------------------
+
+---------------------------------------------------------------------------
+AttributeError                            Traceback (most recent call last)
+<ipython-input-1-00ce8083638c> in <module>
+     19 
+     20 
+---> 21 import distrax
+     22 import optax
+     23 
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in <module>
+     18 from distrax._src.bijectors.bijector import Bijector
+     19 from distrax._src.bijectors.bijector import BijectorLike
+---> 20 from distrax._src.bijectors.block import Block
+     21 from distrax._src.bijectors.chain import Chain
+     22 from distrax._src.bijectors.gumbel_cdf import GumbelCDF
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in <module>
+     18 
+     19 from distrax._src.bijectors import bijector as base
+---> 20 from distrax._src.utils import conversion
+     21 from distrax._src.utils import math
+     22 
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in <module>
+     17 from typing import Optional, Union
+     18 
+---> 19 import chex
+     20 from distrax._src.bijectors import bijector
+     21 from distrax._src.bijectors import bijector_from_tfp
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in <module>
+     15 """Chex: Testing made fun, in JAX!"""
+     16 
+---> 17 from chex._src.asserts import assert_axis_dimension
+     18 from chex._src.asserts import assert_axis_dimension_gt
+     19 from chex._src.asserts import assert_devices_available
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in <module>
+     26 
+     27 from chex._src import asserts_internal as ai
+---> 28 from chex._src import pytypes
+     29 import jax
+     30 import jax.numpy as jnp
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in <module>
+     37 Shape = Tuple[int, ...]
+     38 
+---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice
+     40 GpuDevice = jax.lib.xla_extension.GpuDevice
+     41 TpuDevice = jax.lib.xla_extension.TpuDevice
+
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+

+ 154 - 0
_build/html/reports/nlds.log

@@ -0,0 +1,154 @@
+Traceback (most recent call last):
+  File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution
+    executenb(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute
+    return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped
+    return just_run(coro(*args, **kwargs))
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run
+    return loop.run_until_complete(coro)
+  File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
+    return future.result()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute
+    await self.async_execute_cell(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell
+    self._check_raise_for_error(cell, exec_reply)
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error
+    raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
+nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
+------------------
+# meta-data does not work yet in VScode
+# https://github.com/microsoft/vscode-jupyter/issues/1121
+
+{
+    "tags": [
+        "hide-cell"
+    ]
+}
+
+
+### Install necessary libraries
+
+try:
+    import jax
+except:
+    # For cuda version, see https://github.com/google/jax#installation
+    %pip install --upgrade "jax[cpu]" 
+    import jax
+
+try:
+    import distrax
+except:
+    %pip install --upgrade  distrax
+    import distrax
+
+try:
+    import jsl
+except:
+    %pip install git+https://github.com/probml/jsl
+    import jsl
+
+try:
+    import rich
+except:
+    %pip install rich
+    import rich
+
+
+
+------------------
+
+---------------------------------------------------------------------------
+AttributeError                            Traceback (most recent call last)
+<ipython-input-1-035f9f944d53> in <module>
+     20 try:
+---> 21     import distrax
+     22 except:
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in <module>
+     19 from distrax._src.bijectors.bijector import BijectorLike
+---> 20 from distrax._src.bijectors.block import Block
+     21 from distrax._src.bijectors.chain import Chain
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in <module>
+     19 from distrax._src.bijectors import bijector as base
+---> 20 from distrax._src.utils import conversion
+     21 from distrax._src.utils import math
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in <module>
+     18 
+---> 19 import chex
+     20 from distrax._src.bijectors import bijector
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in <module>
+     16 
+---> 17 from chex._src.asserts import assert_axis_dimension
+     18 from chex._src.asserts import assert_axis_dimension_gt
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in <module>
+     27 from chex._src import asserts_internal as ai
+---> 28 from chex._src import pytypes
+     29 import jax
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in <module>
+     38 
+---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice
+     40 GpuDevice = jax.lib.xla_extension.GpuDevice
+
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+
+During handling of the above exception, another exception occurred:
+
+AttributeError                            Traceback (most recent call last)
+<ipython-input-1-035f9f944d53> in <module>
+     22 except:
+     23     get_ipython().run_line_magic('pip', 'install --upgrade  distrax')
+---> 24     import distrax
+     25 
+     26 try:
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in <module>
+     18 from distrax._src.bijectors.bijector import Bijector
+     19 from distrax._src.bijectors.bijector import BijectorLike
+---> 20 from distrax._src.bijectors.block import Block
+     21 from distrax._src.bijectors.chain import Chain
+     22 from distrax._src.bijectors.gumbel_cdf import GumbelCDF
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in <module>
+     18 
+     19 from distrax._src.bijectors import bijector as base
+---> 20 from distrax._src.utils import conversion
+     21 from distrax._src.utils import math
+     22 
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in <module>
+     17 from typing import Optional, Union
+     18 
+---> 19 import chex
+     20 from distrax._src.bijectors import bijector
+     21 from distrax._src.bijectors import bijector_from_tfp
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in <module>
+     15 """Chex: Testing made fun, in JAX!"""
+     16 
+---> 17 from chex._src.asserts import assert_axis_dimension
+     18 from chex._src.asserts import assert_axis_dimension_gt
+     19 from chex._src.asserts import assert_devices_available
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in <module>
+     26 
+     27 from chex._src import asserts_internal as ai
+---> 28 from chex._src import pytypes
+     29 import jax
+     30 import jax.numpy as jnp
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in <module>
+     37 Shape = Tuple[int, ...]
+     38 
+---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice
+     40 GpuDevice = jax.lib.xla_extension.GpuDevice
+     41 TpuDevice = jax.lib.xla_extension.TpuDevice
+
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+

+ 102 - 0
_build/html/reports/scratchpad.log

@@ -0,0 +1,102 @@
+Traceback (most recent call last):
+  File "/opt/anaconda3/lib/python3.8/site-packages/jupyter_cache/executors/utils.py", line 51, in single_nb_execution
+    executenb(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 1087, in execute
+    return NotebookClient(nb=nb, resources=resources, km=km, **kwargs).execute()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 74, in wrapped
+    return just_run(coro(*args, **kwargs))
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/util.py", line 53, in just_run
+    return loop.run_until_complete(coro)
+  File "/opt/anaconda3/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
+    return future.result()
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 540, in async_execute
+    await self.async_execute_cell(
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 832, in async_execute_cell
+    self._check_raise_for_error(cell, exec_reply)
+  File "/opt/anaconda3/lib/python3.8/site-packages/nbclient/client.py", line 740, in _check_raise_for_error
+    raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
+nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
+------------------
+### Import standard libraries
+
+import abc
+from dataclasses import dataclass
+import functools
+from functools import partial
+import itertools
+import matplotlib.pyplot as plt
+import numpy as np
+from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
+
+import jax
+import jax.numpy as jnp
+from jax import lax, vmap, jit, grad
+#from jax.scipy.special import logit
+#from jax.nn import softmax
+import jax.random as jr
+
+
+
+import distrax
+import optax
+
+import jsl
+import ssm_jax
+
+
+------------------
+
+---------------------------------------------------------------------------
+AttributeError                            Traceback (most recent call last)
+<ipython-input-1-00ce8083638c> in <module>
+     19 
+     20 
+---> 21 import distrax
+     22 import optax
+     23 
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/__init__.py in <module>
+     18 from distrax._src.bijectors.bijector import Bijector
+     19 from distrax._src.bijectors.bijector import BijectorLike
+---> 20 from distrax._src.bijectors.block import Block
+     21 from distrax._src.bijectors.chain import Chain
+     22 from distrax._src.bijectors.gumbel_cdf import GumbelCDF
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/bijectors/block.py in <module>
+     18 
+     19 from distrax._src.bijectors import bijector as base
+---> 20 from distrax._src.utils import conversion
+     21 from distrax._src.utils import math
+     22 
+
+/opt/anaconda3/lib/python3.8/site-packages/distrax/_src/utils/conversion.py in <module>
+     17 from typing import Optional, Union
+     18 
+---> 19 import chex
+     20 from distrax._src.bijectors import bijector
+     21 from distrax._src.bijectors import bijector_from_tfp
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/__init__.py in <module>
+     15 """Chex: Testing made fun, in JAX!"""
+     16 
+---> 17 from chex._src.asserts import assert_axis_dimension
+     18 from chex._src.asserts import assert_axis_dimension_gt
+     19 from chex._src.asserts import assert_devices_available
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py in <module>
+     26 
+     27 from chex._src import asserts_internal as ai
+---> 28 from chex._src import pytypes
+     29 import jax
+     30 import jax.numpy as jnp
+
+/opt/anaconda3/lib/python3.8/site-packages/chex/_src/pytypes.py in <module>
+     37 Shape = Tuple[int, ...]
+     38 
+---> 39 CpuDevice = jax.lib.xla_extension.CpuDevice
+     40 GpuDevice = jax.lib.xla_extension.GpuDevice
+     41 TpuDevice = jax.lib.xla_extension.TpuDevice
+
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
+

+ 11 - 11
_build/html/root.html

@@ -53,7 +53,7 @@ const thebe_selector_output = ".output, .cell_output"
     <script async="async" src="_static/sphinx-thebe.js"></script>
     <link rel="index" title="Index" href="genindex.html" />
     <link rel="search" title="Search" href="search.html" />
-    <link rel="next" title="Scratchpad" href="chapters/scratch.html" />
+    <link rel="next" title="State Space Models" href="chapters/ssm/ssm_index.html" />
     <meta name="viewport" content="width=device-width, initial-scale=1" />
     <meta name="docsearch:language" content="None">
     
@@ -93,11 +93,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="chapters/scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="chapters/ssm/ssm_index.html">
    State Space Models
@@ -130,7 +125,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="chapters/ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="chapters/ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>
@@ -445,13 +445,13 @@ exploiting recent progress
 in automatic differentiation and parallel computing.</p>
 <div class="toctree-wrapper compound">
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="chapters/scratch.html">Scratchpad</a></li>
 <li class="toctree-l1"><a class="reference internal" href="chapters/ssm/ssm_index.html">State Space Models</a><ul>
 <li class="toctree-l2"><a class="reference internal" href="chapters/ssm/ssm_intro.html">What are State Space Models?</a></li>
 <li class="toctree-l2"><a class="reference internal" href="chapters/ssm/hmm.html">Hidden Markov Models</a></li>
 <li class="toctree-l2"><a class="reference internal" href="chapters/ssm/lds.html">Linear Gaussian SSMs</a></li>
 <li class="toctree-l2"><a class="reference internal" href="chapters/ssm/nlds.html">Nonlinear Gaussian SSMs</a></li>
-<li class="toctree-l2"><a class="reference internal" href="chapters/ssm/inference.html">Inferential goals</a></li>
+<li class="toctree-l2"><a class="reference internal" href="chapters/ssm/inference.html">States estimation (inference)</a></li>
+<li class="toctree-l2"><a class="reference internal" href="chapters/ssm/learning.html">Parameter estimation (learning)</a></li>
 </ul>
 </li>
 <li class="toctree-l1"><a class="reference internal" href="chapters/hmm/hmm_index.html">Hidden Markov Models</a><ul>
@@ -525,10 +525,10 @@ in automatic differentiation and parallel computing.</p>
             
                 <!-- Previous / next buttons -->
 <div class='prev-next-area'>
-    <a class='right-next' id="next-link" href="chapters/scratch.html" title="next page">
+    <a class='right-next' id="next-link" href="chapters/ssm/ssm_index.html" title="next page">
     <div class="prev-next-info">
         <p class="prev-next-subtitle">next</p>
-        <p class="prev-next-title">Scratchpad</p>
+        <p class="prev-next-title">State Space Models</p>
     </div>
     <i class="fas fa-angle-right"></i>
     </a>

+ 6 - 6
_build/html/search.html

@@ -98,11 +98,6 @@ const thebe_selector_output = ".output, .cell_output"
  </li>
 </ul>
 <ul class="nav bd-sidenav">
- <li class="toctree-l1">
-  <a class="reference internal" href="chapters/scratch.html">
-   Scratchpad
-  </a>
- </li>
  <li class="toctree-l1 has-children">
   <a class="reference internal" href="chapters/ssm/ssm_index.html">
    State Space Models
@@ -135,7 +130,12 @@ const thebe_selector_output = ".output, .cell_output"
    </li>
    <li class="toctree-l2">
     <a class="reference internal" href="chapters/ssm/inference.html">
-     Inferential goals
+     States estimation (inference)
+    </a>
+   </li>
+   <li class="toctree-l2">
+    <a class="reference internal" href="chapters/ssm/learning.html">
+     Parameter estimation (learning)
     </a>
    </li>
   </ul>

文件差異過大導致無法顯示
+ 1 - 1
_build/html/searchindex.js


+ 252 - 346
_build/jupyter_execute/chapters/hmm/hmm_filter.ipynb

@@ -4,171 +4,8 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "```{math}\n",
-    "\n",
-    "\\newcommand{\\defeq}{\\triangleq}\n",
-    "\\newcommand{\\trans}{{\\mkern-1.5mu\\mathsf{T}}}\n",
-    "\\newcommand{\\transpose}[1]{{#1}^{\\trans}}\n",
-    "\n",
-    "\\newcommand{\\inv}[1]{{#1}^{-1}}\n",
-    "\\DeclareMathOperator{\\dotstar}{\\odot}\n",
-    "\n",
-    "\n",
-    "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n",
-    "\n",
-    "\\newcommand{\\real}{\\mathbb{R}}\n",
-    "\n",
-    "% Numbers\n",
-    "\\newcommand{\\vzero}{\\boldsymbol{0}}\n",
-    "\\newcommand{\\vone}{\\boldsymbol{1}}\n",
-    "\n",
-    "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n",
-    "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n",
-    "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n",
-    "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n",
-    "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n",
-    "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n",
-    "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n",
-    "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n",
-    "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n",
-    "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n",
-    "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n",
-    "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n",
-    "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n",
-    "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n",
-    "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n",
-    "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n",
-    "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n",
-    "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n",
-    "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n",
-    "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n",
-    "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n",
-    "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n",
-    "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n",
-    "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n",
-    "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n",
-    "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n",
-    "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n",
-    "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n",
-    "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n",
-    "\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n",
-    "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n",
-    "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n",
-    "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n",
-    "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n",
-    "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n",
-    "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n",
-    "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n",
-    "\\newcommand{\\vsigmoid}{\\vsigma}\n",
-    "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n",
-    "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n",
-    "\n",
-    "\n",
-    "% Lower Roman (Vectors)\n",
-    "\\newcommand{\\va}{\\mathbf{a}}\n",
-    "\\newcommand{\\vb}{\\mathbf{b}}\n",
-    "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n",
-    "\\newcommand{\\vc}{\\mathbf{c}}\n",
-    "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n",
-    "\\newcommand{\\vd}{\\mathbf{d}}\n",
-    "\\newcommand{\\ve}{\\mathbf{e}}\n",
-    "\\newcommand{\\vf}{\\mathbf{f}}\n",
-    "\\newcommand{\\vg}{\\mathbf{g}}\n",
-    "\\newcommand{\\vh}{\\mathbf{h}}\n",
-    "%\\newcommand{\\myvh}{\\mathbf{h}}\n",
-    "\\newcommand{\\vi}{\\mathbf{i}}\n",
-    "\\newcommand{\\vj}{\\mathbf{j}}\n",
-    "\\newcommand{\\vk}{\\mathbf{k}}\n",
-    "\\newcommand{\\vl}{\\mathbf{l}}\n",
-    "\\newcommand{\\vm}{\\mathbf{m}}\n",
-    "\\newcommand{\\vn}{\\mathbf{n}}\n",
-    "\\newcommand{\\vo}{\\mathbf{o}}\n",
-    "\\newcommand{\\vp}{\\mathbf{p}}\n",
-    "\\newcommand{\\vq}{\\mathbf{q}}\n",
-    "\\newcommand{\\vr}{\\mathbf{r}}\n",
-    "\\newcommand{\\vs}{\\mathbf{s}}\n",
-    "\\newcommand{\\vt}{\\mathbf{t}}\n",
-    "\\newcommand{\\vu}{\\mathbf{u}}\n",
-    "\\newcommand{\\vv}{\\mathbf{v}}\n",
-    "\\newcommand{\\vw}{\\mathbf{w}}\n",
-    "\\newcommand{\\vws}{\\vw_s}\n",
-    "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n",
-    "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n",
-    "\\newcommand{\\vwh}{\\hat{\\vw}}\n",
-    "\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "%\\newcommand{\\vx}{\\mathbf{x}}\n",
-    "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n",
-    "\\newcommand{\\vy}{\\mathbf{y}}\n",
-    "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n",
-    "\\newcommand{\\vz}{\\mathbf{z}}\n",
-    "%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "% Upper Roman (Matrices)\n",
-    "\\newcommand{\\vA}{\\mathbf{A}}\n",
-    "\\newcommand{\\vB}{\\mathbf{B}}\n",
-    "\\newcommand{\\vC}{\\mathbf{C}}\n",
-    "\\newcommand{\\vD}{\\mathbf{D}}\n",
-    "\\newcommand{\\vE}{\\mathbf{E}}\n",
-    "\\newcommand{\\vF}{\\mathbf{F}}\n",
-    "\\newcommand{\\vG}{\\mathbf{G}}\n",
-    "\\newcommand{\\vH}{\\mathbf{H}}\n",
-    "\\newcommand{\\vI}{\\mathbf{I}}\n",
-    "\\newcommand{\\vJ}{\\mathbf{J}}\n",
-    "\\newcommand{\\vK}{\\mathbf{K}}\n",
-    "\\newcommand{\\vL}{\\mathbf{L}}\n",
-    "\\newcommand{\\vM}{\\mathbf{M}}\n",
-    "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n",
-    "\\newcommand{\\vN}{\\mathbf{N}}\n",
-    "\\newcommand{\\vO}{\\mathbf{O}}\n",
-    "\\newcommand{\\vP}{\\mathbf{P}}\n",
-    "\\newcommand{\\vQ}{\\mathbf{Q}}\n",
-    "\\newcommand{\\vR}{\\mathbf{R}}\n",
-    "\\newcommand{\\vS}{\\mathbf{S}}\n",
-    "\\newcommand{\\vT}{\\mathbf{T}}\n",
-    "\\newcommand{\\vU}{\\mathbf{U}}\n",
-    "\\newcommand{\\vV}{\\mathbf{V}}\n",
-    "\\newcommand{\\vW}{\\mathbf{W}}\n",
-    "\\newcommand{\\vX}{\\mathbf{X}}\n",
-    "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n",
-    "\\newcommand{\\vXs}{\\vX_{s}}\n",
-    "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n",
-    "\\newcommand{\\vY}{\\mathbf{Y}}\n",
-    "\\newcommand{\\vZ}{\\mathbf{Z}}\n",
-    "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n",
-    "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n",
-    "\n",
-    "\n",
-    "%%%%\n",
-    "\\newcommand{\\hidden}{\\vz}\n",
-    "\\newcommand{\\hid}{\\hidden}\n",
-    "\\newcommand{\\observed}{\\vy}\n",
-    "\\newcommand{\\obs}{\\observed}\n",
-    "\\newcommand{\\inputs}{\\vu}\n",
-    "\\newcommand{\\input}{\\inputs}\n",
-    "\n",
-    "\\newcommand{\\hmmTrans}{\\vA}\n",
-    "\\newcommand{\\hmmObs}{\\vB}\n",
-    "\\newcommand{\\hmmInit}{\\vpi}\n",
-    "\n",
-    "\n",
-    "\\newcommand{\\ldsDyn}{\\vA}\n",
-    "\\newcommand{\\ldsObs}{\\vC}\n",
-    "\\newcommand{\\ldsDynIn}{\\vB}\n",
-    "\\newcommand{\\ldsObsIn}{\\vD}\n",
-    "\\newcommand{\\ldsDynNoise}{\\vQ}\n",
-    "\\newcommand{\\ldsObsNoise}{\\vR}\n",
-    "\n",
-    "\\newcommand{\\ssmDynFn}{f}\n",
-    "\\newcommand{\\ssmObsFn}{h}\n",
-    "\n",
-    "\n",
-    "%%%\n",
-    "\\newcommand{\\gauss}{\\mathcal{N}}\n",
-    "\n",
-    "\\newcommand{\\diag}{\\mathrm{diag}}\n",
-    "```\n"
+    "(sec:forwards)=\n",
+    "# HMM filtering (forwards algorithm)"
    ]
   },
   {
@@ -177,141 +14,85 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# meta-data does not work yet in VScode\n",
-    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
-    "\n",
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
-    "### Install necessary libraries\n",
-    "\n",
-    "try:\n",
-    "    import jax\n",
-    "except:\n",
-    "    # For cuda version, see https://github.com/google/jax#installation\n",
-    "    %pip install --upgrade \"jax[cpu]\" \n",
-    "    import jax\n",
-    "\n",
-    "try:\n",
-    "    import distrax\n",
-    "except:\n",
-    "    %pip install --upgrade  distrax\n",
-    "    import distrax\n",
-    "\n",
-    "try:\n",
-    "    import jsl\n",
-    "except:\n",
-    "    %pip install git+https://github.com/probml/jsl\n",
-    "    import jsl\n",
-    "\n",
-    "#try:\n",
-    "#    import ssm_jax\n",
-    "##except:\n",
-    "#    %pip install git+https://github.com/probml/ssm-jax\n",
-    "#    import ssm_jax\n",
-    "\n",
-    "try:\n",
-    "    import rich\n",
-    "except:\n",
-    "    %pip install rich\n",
-    "    import rich\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "{\n",
-    "    \"tags\": [\n",
-    "        \"hide-cell\"\n",
-    "    ]\n",
-    "}\n",
-    "\n",
-    "\n",
     "### Import standard libraries\n",
     "\n",
     "import abc\n",
     "from dataclasses import dataclass\n",
     "import functools\n",
+    "from functools import partial\n",
     "import itertools\n",
-    "\n",
-    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
     "\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "from jax import lax, vmap, jit, grad\n",
-    "from jax.scipy.special import logit\n",
-    "from jax.nn import softmax\n",
-    "from functools import partial\n",
-    "from jax.random import PRNGKey, split\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
+    "\n",
     "\n",
-    "import inspect\n",
-    "import inspect as py_inspect\n",
-    "import rich\n",
-    "from rich import inspect as r_inspect\n",
-    "from rich import print as r_print\n",
     "\n",
-    "def print_source(fname):\n",
-    "    r_print(py_inspect.getsource(fname))"
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "(sec:forwards)=\n",
-    "# HMM filtering (forwards algorithm)\n",
     "\n",
+    "## Introduction\n",
     "\n",
-    "The  **Bayes filter** is an algorithm for recursively computing\n",
+    "\n",
+    "The  $\\keyword{Bayes filter}$ is an algorithm for recursively computing\n",
     "the belief state\n",
     "$p(\\hidden_t|\\obs_{1:t})$ given\n",
     "the prior belief from the previous step,\n",
     "$p(\\hidden_{t-1}|\\obs_{1:t-1})$,\n",
     "the new observation $\\obs_t$,\n",
     "and the model.\n",
-    "This can be done using **sequential Bayesian updating**.\n",
+    "This can be done using $\\keyword{sequential Bayesian updating}$.\n",
     "For a dynamical model, this reduces to the\n",
-    "**predict-update** cycle described below.\n",
-    "\n",
+    "$\\keyword{predict-update}$ cycle described below.\n",
     "\n",
-    "The **prediction step** is just the **Chapman-Kolmogorov equation**:\n",
-    "```{math}\n",
+    "The $\\keyword{prediction step}$ is just the $\\keyword{Chapman-Kolmogorov equation}$:\n",
+    "\\begin{align}\n",
     "p(\\hidden_t|\\obs_{1:t-1})\n",
     "= \\int p(\\hidden_t|\\hidden_{t-1}) p(\\hidden_{t-1}|\\obs_{1:t-1}) d\\hidden_{t-1}\n",
-    "```\n",
+    "\\end{align}\n",
     "The prediction step computes\n",
-    "the one-step-ahead predictive distribution\n",
-    "for the latent state, which updates\n",
-    "the posterior from the previous time step into the prior\n",
+    "the $\\keyword{one-step-ahead predictive distribution}$\n",
+    "for the latent state, which converts\n",
+    "the posterior from the previous time step to become the prior\n",
     "for the current step.\n",
     "\n",
     "\n",
-    "The **update step**\n",
+    "The $\\keyword{update step}$\n",
     "is just Bayes rule:\n",
-    "```{math}\n",
+    "\\begin{align}\n",
     "p(\\hidden_t|\\obs_{1:t}) = \\frac{1}{Z_t}\n",
     "p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1})\n",
-    "```\n",
+    "\\end{align}\n",
     "where the normalization constant is\n",
-    "```{math}\n",
+    "\\begin{align}\n",
     "Z_t = \\int p(\\obs_t|\\hidden_t) p(\\hidden_t|\\obs_{1:t-1}) d\\hidden_{t}\n",
     "= p(\\obs_t|\\obs_{1:t-1})\n",
-    "```\n",
+    "\\end{align}\n",
     "\n",
+    "Note that we can derive the log marginal likelihood from these normalization constants\n",
+    "as follows:\n",
+    "```{math}\n",
+    ":label: eqn:logZ\n",
     "\n",
+    "\\log p(\\obs_{1:T})\n",
+    "= \\sum_{t=1}^{T} \\log p(\\obs_t|\\obs_{1:t-1})\n",
+    "= \\sum_{t=1}^{T} \\log Z_t\n",
+    "```\n",
     "\n"
    ]
   },
@@ -323,10 +104,11 @@
     "When the latent states $\\hidden_t$ are discrete, as in HMM,\n",
     "the above integrals become sums.\n",
     "In particular, suppose we define\n",
-    "the belief state as $\\alpha_t(j) \\defeq p(\\hidden_t=j|\\obs_{1:t})$,\n",
-    "the local evidence as $\\lambda_t(j) \\defeq p(\\obs_t|\\hidden_t=j)$,\n",
-    "and the transition matrix\n",
-    "$A(i,j)  = p(\\hidden_t=j|\\hidden_{t-1}=i)$.\n",
+    "the $\\keyword{belief state}$ as $\\alpha_t(j) \\defeq p(\\hidden_t=j|\\obs_{1:t})$,\n",
+    "the  $\\keyword{local evidence}$ (or $\\keyword{local likelihood}$)\n",
+    "as $\\lambda_t(j) \\defeq p(\\obs_t|\\hidden_t=j)$,\n",
+    "and the transition matrix as\n",
+    "$\\hmmTrans(i,j)  = p(\\hidden_t=j|\\hidden_{t-1}=i)$.\n",
     "Then the predict step becomes\n",
     "```{math}\n",
     ":label: eqn:predictiveHMM\n",
@@ -338,7 +120,7 @@
     ":label: eqn:fwdsEqn\n",
     "\\alpha_t(j)\n",
     "= \\frac{1}{Z_t} \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
-    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) A(i,j)  \\right]\n",
+    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) \\hmmTrans(i,j)  \\right]\n",
     "```\n",
     "where\n",
     "the  normalization constant for each time step is given by\n",
@@ -350,20 +132,33 @@
     "&=  \\sum_{j=1}^K \\lambda_t(j) \\alpha_{t|t-1}(j)\n",
     "\\end{align}\n",
     "```\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
     "\n",
     "Since all the quantities are finite length vectors and matrices,\n",
-    "we can write the update equation\n",
-    "in matrix-vector notation as follows:\n",
+    "we can implement the whole procedure using matrix vector multoplication:\n",
     "```{math}\n",
+    ":label: eqn:fwdsAlgoMatrixForm\n",
     "\\valpha_t =\\text{normalize}\\left(\n",
-    "\\vlambda_t \\dotstar  (\\vA^{\\trans} \\valpha_{t-1}) \\right)\n",
-    "\\label{eqn:fwdsAlgoMatrixForm}\n",
+    "\\vlambda_t \\dotstar  (\\hmmTrans^{\\trans} \\valpha_{t-1}) \\right)\n",
     "```\n",
     "where $\\dotstar$ represents\n",
     "elementwise vector multiplication,\n",
-    "and the $\\text{normalize}$ function just ensures its argument sums to one.\n",
+    "and the $\\text{normalize}$ function just ensures its argument sums to one.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Example\n",
     "\n",
-    "In {ref}(sec:casino-inference)\n",
+    "In {ref}`sec:casino-inference`\n",
     "we illustrate\n",
     "filtering for the casino HMM,\n",
     "applied to a random sequence $\\obs_{1:T}$ of length $T=300$.\n",
@@ -371,98 +166,209 @@
     "based on the evidence seen so far.\n",
     "The gray bars indicate time intervals during which the generative\n",
     "process actually switched to the loaded dice.\n",
-    "We see that the probability generally increases in the right places.\n"
+    "We see that the probability generally increases in the right places."
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Here is a JAX implementation of the forwards algorithm."
+    "## Normalization constants\n",
+    "\n",
+    "In most publications on HMMs,\n",
+    "such as {cite}`Rabiner89`,\n",
+    "the forwards message is defined\n",
+    "as the following unnormalized joint probability:\n",
+    "```{math}\n",
+    "\\alpha'_t(j) = p(\\hidden_t=j,\\obs_{1:t}) \n",
+    "= \\lambda_t(j) \\left[\\sum_i \\alpha'_{t-1}(i) A(i,j)  \\right]\n",
+    "```\n",
+    "In this book we define the forwards message   as the normalized\n",
+    "conditional probability\n",
+    "```{math}\n",
+    "\\alpha_t(j) = p(\\hidden_t=j|\\obs_{1:t}) \n",
+    "= \\frac{1}{Z_t} \\lambda_t(j) \\left[\\sum_i \\alpha_{t-1}(i) A(i,j)  \\right]\n",
+    "```\n",
+    "where $Z_t = p(\\obs_t|\\obs_{1:t-1})$.\n",
+    "\n",
+    "The \"traditional\" unnormalized form has several problems.\n",
+    "First, it rapidly suffers from numerical underflow,\n",
+    "since the probability of\n",
+    "the joint event that $(\\hidden_t=j,\\obs_{1:t})$\n",
+    "is vanishingly small. \n",
+    "To see why, suppose the observations are independent of the states.\n",
+    "In this case, the unnormalized joint has the form\n",
+    "\\begin{align}\n",
+    "p(\\hidden_t=j,\\obs_{1:t}) = p(\\hidden_t=j)\\prod_{i=1}^t p(\\obs_i)\n",
+    "\\end{align}\n",
+    "which becomes exponentially small with $t$, because we multiply\n",
+    "many probabilities which are less than one.\n",
+    "Second, the unnormalized probability is less interpretable,\n",
+    "since it is a joint distribution over states and observations,\n",
+    "rather than a conditional probability of states given observations.\n",
+    "Third, the unnormalized joint form is harder to approximate\n",
+    "than the normalized form.\n",
+    "Of course,\n",
+    "the two definitions only differ by a\n",
+    "multiplicative constant\n",
+    "{cite}`Devijver85`,\n",
+    "so the algorithmic difference is just\n",
+    "one line of code (namely the presence or absence of a call to the `normalize` function).\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Naive implementation\n",
+    "\n",
+    "Below we give a simple numpy implementation of the forwards algorithm.\n",
+    "We assume the HMM uses categorical observations, for simplicity.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "\n",
+    "def normalize_np(u, axis=0, eps=1e-15):\n",
+    "    u = np.where(u == 0, 0, np.where(u < eps, eps, u))\n",
+    "    c = u.sum(axis=axis)\n",
+    "    c = np.where(c == 0, 1, c)\n",
+    "    return u / c, c\n",
+    "\n",
+    "def hmm_forwards_np(trans_mat, obs_mat, init_dist, obs_seq):\n",
+    "    n_states, n_obs = obs_mat.shape\n",
+    "    seq_len = len(obs_seq)\n",
+    "\n",
+    "    alpha_hist = np.zeros((seq_len, n_states))\n",
+    "    ll_hist = np.zeros(seq_len)  # loglikelihood history\n",
+    "\n",
+    "    alpha_n = init_dist * obs_mat[:, obs_seq[0]]\n",
+    "    alpha_n, cn = normalize_np(alpha_n)\n",
+    "\n",
+    "    alpha_hist[0] = alpha_n\n",
+    "    log_normalizer = np.log(cn)\n",
+    "\n",
+    "    for t in range(1, seq_len):\n",
+    "        alpha_n = obs_mat[:, obs_seq[t]] * (alpha_n[:, None] * trans_mat).sum(axis=0)\n",
+    "        alpha_n, zn = normalize_np(alpha_n)\n",
+    "\n",
+    "        alpha_hist[t] = alpha_n\n",
+    "        log_normalizer = np.log(zn) + log_normalizer\n",
+    "\n",
+    "    return  log_normalizer, alpha_hist"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Numerically stable implementation \n",
+    "\n",
+    "\n",
+    "\n",
+    "In practice it is more numerically stable to compute\n",
+    "the log likelihoods $\\ell_t(j) = \\log p(\\obs_t|\\hidden_t=j)$,\n",
+    "rather than the likelioods $\\lambda_t(j) = p(\\obs_t|\\hidden_t=j)$.\n",
+    "In this case, we can perform the posterior updating in a numerically stable way as follows.\n",
+    "Define $L_t = \\max_j \\ell_t(j)$ and\n",
+    "\\begin{align}\n",
+    "\\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n",
+    "&\\defeq p(\\hidden_t=j|\\obs_{1:t-1}) p(\\obs_t|\\hidden_t=j) e^{-L_t} \\\\\n",
+    " &= p(\\hidden_t=j|\\obs_{1:t-1}) e^{\\ell_t(j) - L_t}\n",
+    "\\end{align}\n",
+    "Then we have\n",
+    "\\begin{align}\n",
+    "p(\\hidden_t=j|\\obs_t,\\obs_{1:t-1})\n",
+    "  &= \\frac{1}{\\tilde{Z}_t} \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1}) \\\\\n",
+    "\\tilde{Z}_t &= \\sum_j \\tilde{p}(\\hidden_t=j,\\obs_t|\\obs_{1:t-1})\n",
+    "= p(\\obs_t|\\obs_{1:t-1}) e^{-L_t} \\\\\n",
+    "\\log Z_t &= \\log p(\\obs_t|\\obs_{1:t-1}) = \\log \\tilde{Z}_t + L_t\n",
+    "\\end{align}\n",
+    "\n",
+    "Below we show some JAX code that implements this core operation.\n"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">@jit\n",
-       "def hmm_forwards_jax<span style=\"font-weight: bold\">(</span>params, obs_seq, <span style=\"color: #808000; text-decoration-color: #808000\">length</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span><span style=\"font-weight: bold\">)</span>:\n",
-       "    <span style=\"color: #008000; text-decoration-color: #008000\">''</span>'\n",
-       "    Calculates a belief state\n",
-       "\n",
-       "    Parameters\n",
-       "    ----------\n",
-       "    params : HMMJax\n",
-       "        Hidden Markov Model\n",
-       "\n",
-       "    obs_seq: array<span style=\"font-weight: bold\">(</span>seq_len<span style=\"font-weight: bold\">)</span>\n",
-       "        History of observable events\n",
-       "\n",
-       "    Returns\n",
-       "    -------\n",
-       "    * float\n",
-       "        The loglikelihood giving log<span style=\"font-weight: bold\">(</span>p<span style=\"font-weight: bold\">(</span>x|model<span style=\"font-weight: bold\">))</span>\n",
-       "\n",
-       "    * array<span style=\"font-weight: bold\">(</span>seq_len, n_hidden<span style=\"font-weight: bold\">)</span> :\n",
-       "        All alpha values found for each sample\n",
-       "    <span style=\"color: #008000; text-decoration-color: #008000\">''</span>'\n",
-       "    seq_len = len<span style=\"font-weight: bold\">(</span>obs_seq<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    if length is <span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>:\n",
-       "        length = seq_len\n",
-       "\n",
-       "    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist\n",
-       "\n",
-       "    trans_mat = jnp.array<span style=\"font-weight: bold\">(</span>trans_mat<span style=\"font-weight: bold\">)</span>\n",
-       "    obs_mat = jnp.array<span style=\"font-weight: bold\">(</span>obs_mat<span style=\"font-weight: bold\">)</span>\n",
-       "    init_dist = jnp.array<span style=\"font-weight: bold\">(</span>init_dist<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    n_states, n_obs = obs_mat.shape\n",
-       "\n",
-       "    def scan_fn<span style=\"font-weight: bold\">(</span>carry, t<span style=\"font-weight: bold\">)</span>:\n",
-       "        <span style=\"font-weight: bold\">(</span>alpha_prev, log_ll_prev<span style=\"font-weight: bold\">)</span> = carry\n",
-       "        alpha_n = jnp.where<span style=\"font-weight: bold\">(</span>t &lt; length,\n",
-       "                            obs_mat<span style=\"font-weight: bold\">[</span>:, obs_seq<span style=\"font-weight: bold\">]</span> * <span style=\"font-weight: bold\">(</span>alpha_prev<span style=\"font-weight: bold\">[</span>:, <span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span><span style=\"font-weight: bold\">]</span> * \n",
-       "trans_mat<span style=\"font-weight: bold\">)</span>.sum<span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">axis</span>=<span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">0</span><span style=\"font-weight: bold\">)</span>,\n",
-       "                            jnp.zeros_like<span style=\"font-weight: bold\">(</span>alpha_prev<span style=\"font-weight: bold\">))</span>\n",
-       "\n",
-       "        alpha_n, cn = normalize<span style=\"font-weight: bold\">(</span>alpha_n<span style=\"font-weight: bold\">)</span>\n",
-       "        carry = <span style=\"font-weight: bold\">(</span>alpha_n, jnp.log<span style=\"font-weight: bold\">(</span>cn<span style=\"font-weight: bold\">)</span> + log_ll_prev<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "        return carry, alpha_n\n",
-       "\n",
-       "    # initial belief state\n",
-       "    alpha_0, c0 = normalize<span style=\"font-weight: bold\">(</span>init_dist * obs_mat<span style=\"font-weight: bold\">[</span>:, obs_seq<span style=\"font-weight: bold\">[</span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">0</span><span style=\"font-weight: bold\">]])</span>\n",
-       "\n",
-       "    # setup scan loop\n",
-       "    init_state = <span style=\"font-weight: bold\">(</span>alpha_0, jnp.log<span style=\"font-weight: bold\">(</span>c0<span style=\"font-weight: bold\">))</span>\n",
-       "    ts = jnp.arange<span style=\"font-weight: bold\">(</span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">1</span>, seq_len<span style=\"font-weight: bold\">)</span>\n",
-       "    carry, alpha_hist = lax.scan<span style=\"font-weight: bold\">(</span>scan_fn, init_state, ts<span style=\"font-weight: bold\">)</span>\n",
-       "\n",
-       "    # post-process\n",
-       "    alpha_hist = jnp.vstack<span style=\"font-weight: bold\">()</span>\n",
-       "    <span style=\"font-weight: bold\">(</span>alpha_final, log_ll<span style=\"font-weight: bold\">)</span> = carry\n",
-       "    return log_ll, alpha_hist\n",
-       "\n",
-       "</pre>\n"
-      ],
-      "text/plain": [
-       "<rich.jupyter.JupyterRenderable at 0x7fe75997dfa0>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "import jsl.hmm.hmm_lib as hmm_lib\n",
-    "print_source(hmm_lib.hmm_forwards_jax)\n",
-    "#https://github.com/probml/JSL/blob/main/jsl/hmm/hmm_lib.py#L189\n",
-    "\n"
+    "\n",
+    "def _condition_on(probs, ll):\n",
+    "    ll_max = ll.max()\n",
+    "    new_probs = probs * jnp.exp(ll - ll_max)\n",
+    "    norm = new_probs.sum()\n",
+    "    new_probs /= norm\n",
+    "    log_norm = jnp.log(norm) + ll_max\n",
+    "    return new_probs, log_norm"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "With the above function, we can implement a more numerically stable version of the forwards filter,\n",
+    "that works for any likelihood function, as shown below. It takes in the prior predictive distribution,\n",
+    "$\\alpha_{t|t-1}$,\n",
+    "stored in `predicted_probs`, and conditions them on the log-likelihood for each time step $\\ell_t$ to get the\n",
+    "posterior, $\\alpha_t$, stored in `filtered_probs`,\n",
+    "which is then converted to the prediction for the next state, $\\alpha_{t+1|t}$."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def _predict(probs, A):\n",
+    "    return A.T @ probs\n",
+    "\n",
+    "\n",
+    "def hmm_filter(initial_distribution,\n",
+    "               transition_matrix,\n",
+    "               log_likelihoods):\n",
+    "    def _step(carry, t):\n",
+    "        log_normalizer, predicted_probs = carry\n",
+    "\n",
+    "        # Get parameters for time t\n",
+    "        get = lambda x: x[t] if x.ndim == 3 else x\n",
+    "        A = get(transition_matrix)\n",
+    "        ll = log_likelihoods[t]\n",
+    "\n",
+    "        # Condition on emissions at time t, being careful not to overflow\n",
+    "        filtered_probs, log_norm = _condition_on(predicted_probs, ll)\n",
+    "        # Update the log normalizer\n",
+    "        log_normalizer += log_norm\n",
+    "        # Predict the next state\n",
+    "        predicted_probs = _predict(filtered_probs, A)\n",
+    "\n",
+    "        return (log_normalizer, predicted_probs), (filtered_probs, predicted_probs)\n",
+    "\n",
+    "    num_timesteps = len(log_likelihoods)\n",
+    "    carry = (0.0, initial_distribution)\n",
+    "    (log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(\n",
+    "        _step, carry, jnp.arange(num_timesteps))\n",
+    "    return log_normalizer, filtered_probs, predicted_probs"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "TODO: check equivalence of these two implementations!"
    ]
   }
  ],

+ 217 - 267
_build/jupyter_execute/chapters/hmm/hmm_filter.py

@@ -1,303 +1,87 @@
 #!/usr/bin/env python
 # coding: utf-8
 
-# ```{math}
-# 
-# \newcommand{\defeq}{\triangleq}
-# \newcommand{\trans}{{\mkern-1.5mu\mathsf{T}}}
-# \newcommand{\transpose}[1]{{#1}^{\trans}}
-# 
-# \newcommand{\inv}[1]{{#1}^{-1}}
-# \DeclareMathOperator{\dotstar}{\odot}
-# 
-# 
-# \newcommand\floor[1]{\lfloor#1\rfloor}
-# 
-# \newcommand{\real}{\mathbb{R}}
-# 
-# % Numbers
-# \newcommand{\vzero}{\boldsymbol{0}}
-# \newcommand{\vone}{\boldsymbol{1}}
-# 
-# % Greek https://www.latex-tutorial.com/symbols/greek-alphabet/
-# \newcommand{\valpha}{\boldsymbol{\alpha}}
-# \newcommand{\vbeta}{\boldsymbol{\beta}}
-# \newcommand{\vchi}{\boldsymbol{\chi}}
-# \newcommand{\vdelta}{\boldsymbol{\delta}}
-# \newcommand{\vDelta}{\boldsymbol{\Delta}}
-# \newcommand{\vepsilon}{\boldsymbol{\epsilon}}
-# \newcommand{\vzeta}{\boldsymbol{\zeta}}
-# \newcommand{\vXi}{\boldsymbol{\Xi}}
-# \newcommand{\vell}{\boldsymbol{\ell}}
-# \newcommand{\veta}{\boldsymbol{\eta}}
-# %\newcommand{\vEta}{\boldsymbol{\Eta}}
-# \newcommand{\vgamma}{\boldsymbol{\gamma}}
-# \newcommand{\vGamma}{\boldsymbol{\Gamma}}
-# \newcommand{\vmu}{\boldsymbol{\mu}}
-# \newcommand{\vmut}{\boldsymbol{\tilde{\mu}}}
-# \newcommand{\vnu}{\boldsymbol{\nu}}
-# \newcommand{\vkappa}{\boldsymbol{\kappa}}
-# \newcommand{\vlambda}{\boldsymbol{\lambda}}
-# \newcommand{\vLambda}{\boldsymbol{\Lambda}}
-# \newcommand{\vLambdaBar}{\overline{\vLambda}}
-# %\newcommand{\vnu}{\boldsymbol{\nu}}
-# \newcommand{\vomega}{\boldsymbol{\omega}}
-# \newcommand{\vOmega}{\boldsymbol{\Omega}}
-# \newcommand{\vphi}{\boldsymbol{\phi}}
-# \newcommand{\vvarphi}{\boldsymbol{\varphi}}
-# \newcommand{\vPhi}{\boldsymbol{\Phi}}
-# \newcommand{\vpi}{\boldsymbol{\pi}}
-# \newcommand{\vPi}{\boldsymbol{\Pi}}
-# \newcommand{\vpsi}{\boldsymbol{\psi}}
-# \newcommand{\vPsi}{\boldsymbol{\Psi}}
-# \newcommand{\vrho}{\boldsymbol{\rho}}
-# \newcommand{\vtheta}{\boldsymbol{\theta}}
-# \newcommand{\vthetat}{\boldsymbol{\tilde{\theta}}}
-# \newcommand{\vTheta}{\boldsymbol{\Theta}}
-# \newcommand{\vsigma}{\boldsymbol{\sigma}}
-# \newcommand{\vSigma}{\boldsymbol{\Sigma}}
-# \newcommand{\vSigmat}{\boldsymbol{\tilde{\Sigma}}}
-# \newcommand{\vsigmoid}{\vsigma}
-# \newcommand{\vtau}{\boldsymbol{\tau}}
-# \newcommand{\vxi}{\boldsymbol{\xi}}
-# 
-# 
-# % Lower Roman (Vectors)
-# \newcommand{\va}{\mathbf{a}}
-# \newcommand{\vb}{\mathbf{b}}
-# \newcommand{\vBt}{\mathbf{\tilde{B}}}
-# \newcommand{\vc}{\mathbf{c}}
-# \newcommand{\vct}{\mathbf{\tilde{c}}}
-# \newcommand{\vd}{\mathbf{d}}
-# \newcommand{\ve}{\mathbf{e}}
-# \newcommand{\vf}{\mathbf{f}}
-# \newcommand{\vg}{\mathbf{g}}
-# \newcommand{\vh}{\mathbf{h}}
-# %\newcommand{\myvh}{\mathbf{h}}
-# \newcommand{\vi}{\mathbf{i}}
-# \newcommand{\vj}{\mathbf{j}}
-# \newcommand{\vk}{\mathbf{k}}
-# \newcommand{\vl}{\mathbf{l}}
-# \newcommand{\vm}{\mathbf{m}}
-# \newcommand{\vn}{\mathbf{n}}
-# \newcommand{\vo}{\mathbf{o}}
-# \newcommand{\vp}{\mathbf{p}}
-# \newcommand{\vq}{\mathbf{q}}
-# \newcommand{\vr}{\mathbf{r}}
-# \newcommand{\vs}{\mathbf{s}}
-# \newcommand{\vt}{\mathbf{t}}
-# \newcommand{\vu}{\mathbf{u}}
-# \newcommand{\vv}{\mathbf{v}}
-# \newcommand{\vw}{\mathbf{w}}
-# \newcommand{\vws}{\vw_s}
-# \newcommand{\vwt}{\mathbf{\tilde{w}}}
-# \newcommand{\vWt}{\mathbf{\tilde{W}}}
-# \newcommand{\vwh}{\hat{\vw}}
-# \newcommand{\vx}{\mathbf{x}}
-# %\newcommand{\vx}{\mathbf{x}}
-# \newcommand{\vxt}{\mathbf{\tilde{x}}}
-# \newcommand{\vy}{\mathbf{y}}
-# \newcommand{\vyt}{\mathbf{\tilde{y}}}
-# \newcommand{\vz}{\mathbf{z}}
-# %\newcommand{\vzt}{\mathbf{\tilde{z}}}
-# 
-# 
-# % Upper Roman (Matrices)
-# \newcommand{\vA}{\mathbf{A}}
-# \newcommand{\vB}{\mathbf{B}}
-# \newcommand{\vC}{\mathbf{C}}
-# \newcommand{\vD}{\mathbf{D}}
-# \newcommand{\vE}{\mathbf{E}}
-# \newcommand{\vF}{\mathbf{F}}
-# \newcommand{\vG}{\mathbf{G}}
-# \newcommand{\vH}{\mathbf{H}}
-# \newcommand{\vI}{\mathbf{I}}
-# \newcommand{\vJ}{\mathbf{J}}
-# \newcommand{\vK}{\mathbf{K}}
-# \newcommand{\vL}{\mathbf{L}}
-# \newcommand{\vM}{\mathbf{M}}
-# \newcommand{\vMt}{\mathbf{\tilde{M}}}
-# \newcommand{\vN}{\mathbf{N}}
-# \newcommand{\vO}{\mathbf{O}}
-# \newcommand{\vP}{\mathbf{P}}
-# \newcommand{\vQ}{\mathbf{Q}}
-# \newcommand{\vR}{\mathbf{R}}
-# \newcommand{\vS}{\mathbf{S}}
-# \newcommand{\vT}{\mathbf{T}}
-# \newcommand{\vU}{\mathbf{U}}
-# \newcommand{\vV}{\mathbf{V}}
-# \newcommand{\vW}{\mathbf{W}}
-# \newcommand{\vX}{\mathbf{X}}
-# %\newcommand{\vXs}{\vX_{\vs}}
-# \newcommand{\vXs}{\vX_{s}}
-# \newcommand{\vXt}{\mathbf{\tilde{X}}}
-# \newcommand{\vY}{\mathbf{Y}}
-# \newcommand{\vZ}{\mathbf{Z}}
-# \newcommand{\vZt}{\mathbf{\tilde{Z}}}
-# \newcommand{\vzt}{\mathbf{\tilde{z}}}
-# 
-# 
-# %%%%
-# \newcommand{\hidden}{\vz}
-# \newcommand{\hid}{\hidden}
-# \newcommand{\observed}{\vy}
-# \newcommand{\obs}{\observed}
-# \newcommand{\inputs}{\vu}
-# \newcommand{\input}{\inputs}
-# 
-# \newcommand{\hmmTrans}{\vA}
-# \newcommand{\hmmObs}{\vB}
-# \newcommand{\hmmInit}{\vpi}
-# 
-# 
-# \newcommand{\ldsDyn}{\vA}
-# \newcommand{\ldsObs}{\vC}
-# \newcommand{\ldsDynIn}{\vB}
-# \newcommand{\ldsObsIn}{\vD}
-# \newcommand{\ldsDynNoise}{\vQ}
-# \newcommand{\ldsObsNoise}{\vR}
-# 
-# \newcommand{\ssmDynFn}{f}
-# \newcommand{\ssmObsFn}{h}
-# 
-# 
-# %%%
-# \newcommand{\gauss}{\mathcal{N}}
-# 
-# \newcommand{\diag}{\mathrm{diag}}
-# ```
-# 
+# (sec:forwards)=
+# # HMM filtering (forwards algorithm)
 
 # In[1]:
 
 
-# meta-data does not work yet in VScode
-# https://github.com/microsoft/vscode-jupyter/issues/1121
-
-{
-    "tags": [
-        "hide-cell"
-    ]
-}
-
-
-### Install necessary libraries
-
-try:
-    import jax
-except:
-    # For cuda version, see https://github.com/google/jax#installation
-    get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
-    import jax
-
-try:
-    import distrax
-except:
-    get_ipython().run_line_magic('pip', 'install --upgrade  distrax')
-    import distrax
-
-try:
-    import jsl
-except:
-    get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
-    import jsl
-
-#try:
-#    import ssm_jax
-##except:
-#    %pip install git+https://github.com/probml/ssm-jax
-#    import ssm_jax
-
-try:
-    import rich
-except:
-    get_ipython().run_line_magic('pip', 'install rich')
-    import rich
-
-
-
-# In[2]:
-
-
-{
-    "tags": [
-        "hide-cell"
-    ]
-}
-
-
 ### Import standard libraries
 
 import abc
 from dataclasses import dataclass
 import functools
+from functools import partial
 import itertools
-
-from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
-
 import matplotlib.pyplot as plt
 import numpy as np
-
+from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
 
 import jax
 import jax.numpy as jnp
 from jax import lax, vmap, jit, grad
-from jax.scipy.special import logit
-from jax.nn import softmax
-from functools import partial
-from jax.random import PRNGKey, split
+#from jax.scipy.special import logit
+#from jax.nn import softmax
+import jax.random as jr
 
-import inspect
-import inspect as py_inspect
-import rich
-from rich import inspect as r_inspect
-from rich import print as r_print
 
-def print_source(fname):
-    r_print(py_inspect.getsource(fname))
 
+import distrax
+import optax
 
-# (sec:forwards)=
-# # HMM filtering (forwards algorithm)
+import jsl
+import ssm_jax
+
+
+# 
+# ## Introduction
 # 
 # 
-# The  **Bayes filter** is an algorithm for recursively computing
+# The  $\keyword{Bayes filter}$ is an algorithm for recursively computing
 # the belief state
 # $p(\hidden_t|\obs_{1:t})$ given
 # the prior belief from the previous step,
 # $p(\hidden_{t-1}|\obs_{1:t-1})$,
 # the new observation $\obs_t$,
 # and the model.
-# This can be done using **sequential Bayesian updating**.
+# This can be done using $\keyword{sequential Bayesian updating}$.
 # For a dynamical model, this reduces to the
-# **predict-update** cycle described below.
-# 
+# $\keyword{predict-update}$ cycle described below.
 # 
-# The **prediction step** is just the **Chapman-Kolmogorov equation**:
-# ```{math}
+# The $\keyword{prediction step}$ is just the $\keyword{Chapman-Kolmogorov equation}$:
+# \begin{align}
 # p(\hidden_t|\obs_{1:t-1})
 # = \int p(\hidden_t|\hidden_{t-1}) p(\hidden_{t-1}|\obs_{1:t-1}) d\hidden_{t-1}
-# ```
+# \end{align}
 # The prediction step computes
-# the one-step-ahead predictive distribution
-# for the latent state, which updates
-# the posterior from the previous time step into the prior
+# the $\keyword{one-step-ahead predictive distribution}$
+# for the latent state, which converts
+# the posterior from the previous time step to become the prior
 # for the current step.
 # 
 # 
-# The **update step**
+# The $\keyword{update step}$
 # is just Bayes rule:
-# ```{math}
+# \begin{align}
 # p(\hidden_t|\obs_{1:t}) = \frac{1}{Z_t}
 # p(\obs_t|\hidden_t) p(\hidden_t|\obs_{1:t-1})
-# ```
+# \end{align}
 # where the normalization constant is
-# ```{math}
+# \begin{align}
 # Z_t = \int p(\obs_t|\hidden_t) p(\hidden_t|\obs_{1:t-1}) d\hidden_{t}
 # = p(\obs_t|\obs_{1:t-1})
-# ```
+# \end{align}
 # 
+# Note that we can derive the log marginal likelihood from these normalization constants
+# as follows:
+# ```{math}
+# :label: eqn:logZ
 # 
+# \log p(\obs_{1:T})
+# = \sum_{t=1}^{T} \log p(\obs_t|\obs_{1:t-1})
+# = \sum_{t=1}^{T} \log Z_t
+# ```
 # 
 # 
 
@@ -305,10 +89,11 @@ def print_source(fname):
 # When the latent states $\hidden_t$ are discrete, as in HMM,
 # the above integrals become sums.
 # In particular, suppose we define
-# the belief state as $\alpha_t(j) \defeq p(\hidden_t=j|\obs_{1:t})$,
-# the local evidence as $\lambda_t(j) \defeq p(\obs_t|\hidden_t=j)$,
-# and the transition matrix
-# $A(i,j)  = p(\hidden_t=j|\hidden_{t-1}=i)$.
+# the $\keyword{belief state}$ as $\alpha_t(j) \defeq p(\hidden_t=j|\obs_{1:t})$,
+# the  $\keyword{local evidence}$ (or $\keyword{local likelihood}$)
+# as $\lambda_t(j) \defeq p(\obs_t|\hidden_t=j)$,
+# and the transition matrix as
+# $\hmmTrans(i,j)  = p(\hidden_t=j|\hidden_{t-1}=i)$.
 # Then the predict step becomes
 # ```{math}
 # :label: eqn:predictiveHMM
@@ -320,7 +105,7 @@ def print_source(fname):
 # :label: eqn:fwdsEqn
 # \alpha_t(j)
 # = \frac{1}{Z_t} \lambda_t(j) \alpha_{t|t-1}(j)
-# = \frac{1}{Z_t} \lambda_t(j) \left[\sum_i \alpha_{t-1}(i) A(i,j)  \right]
+# = \frac{1}{Z_t} \lambda_t(j) \left[\sum_i \alpha_{t-1}(i) \hmmTrans(i,j)  \right]
 # ```
 # where
 # the  normalization constant for each time step is given by
@@ -333,19 +118,24 @@ def print_source(fname):
 # \end{align}
 # ```
 # 
+# 
+
+# 
 # Since all the quantities are finite length vectors and matrices,
-# we can write the update equation
-# in matrix-vector notation as follows:
+# we can implement the whole procedure using matrix-vector multiplication:
 # ```{math}
+# :label: eqn:fwdsAlgoMatrixForm
 # \valpha_t =\text{normalize}\left(
-# \vlambda_t \dotstar  (\vA^{\trans} \valpha_{t-1}) \right)
-# \label{eqn:fwdsAlgoMatrixForm}
+# \vlambda_t \dotstar  (\hmmTrans^{\trans} \valpha_{t-1}) \right)
 # ```
 # where $\dotstar$ represents
 # elementwise vector multiplication,
 # and the $\text{normalize}$ function just ensures its argument sums to one.
 # 
-# In {ref}(sec:casino-inference)
+
+# ## Example
+# 
+# In {ref}`sec:casino-inference`
 # we illustrate
 # filtering for the casino HMM,
 # applied to a random sequence $\obs_{1:T}$ of length $T=300$.
@@ -354,14 +144,174 @@ def print_source(fname):
 # The gray bars indicate time intervals during which the generative
 # process actually switched to the loaded dice.
 # We see that the probability generally increases in the right places.
+
+# ## Normalization constants
+# 
+# In most publications on HMMs,
+# such as {cite}`Rabiner89`,
+# the forwards message is defined
+# as the following unnormalized joint probability:
+# ```{math}
+# \alpha'_t(j) = p(\hidden_t=j,\obs_{1:t}) 
+# = \lambda_t(j) \left[\sum_i \alpha'_{t-1}(i) \hmmTrans(i,j)  \right]
+# ```
+# In this book we define the forwards message   as the normalized
+# conditional probability
+# ```{math}
+# \alpha_t(j) = p(\hidden_t=j|\obs_{1:t}) 
+# = \frac{1}{Z_t} \lambda_t(j) \left[\sum_i \alpha_{t-1}(i) \hmmTrans(i,j)  \right]
+# ```
+# where $Z_t = p(\obs_t|\obs_{1:t-1})$.
+# 
+# The "traditional" unnormalized form has several problems.
+# First, it rapidly suffers from numerical underflow,
+# since the probability of
+# the joint event that $(\hidden_t=j,\obs_{1:t})$
+# is vanishingly small. 
+# To see why, suppose the observations are independent of the states.
+# In this case, the unnormalized joint has the form
+# \begin{align}
+# p(\hidden_t=j,\obs_{1:t}) = p(\hidden_t=j)\prod_{i=1}^t p(\obs_i)
+# \end{align}
+# which becomes exponentially small with $t$, because we multiply
+# many probabilities which are less than one.
+# Second, the unnormalized probability is less interpretable,
+# since it is a joint distribution over states and observations,
+# rather than a conditional probability of states given observations.
+# Third, the unnormalized joint form is harder to approximate
+# than the normalized form.
+# Of course,
+# the two definitions only differ by a
+# multiplicative constant
+# {cite}`Devijver85`,
+# so the algorithmic difference is just
+# one line of code (namely the presence or absence of a call to the `normalize` function).
+# 
+# 
 # 
+# 
+# 
+
+# ## Naive implementation
+# 
+# Below we give a simple numpy implementation of the forwards algorithm.
+# We assume the HMM uses categorical observations, for simplicity.
+# 
+# 
+
+# In[2]:
+
+
+
+
+def normalize_np(u, axis=0, eps=1e-15):
+    # Normalize `u` along `axis` so that it sums to one, returning both the
+    # normalized array and the normalization constant(s).
+    # Entries equal to exactly 0 are kept at 0, while tiny positive entries
+    # are clamped up to `eps` to avoid numerical underflow in the division.
+    u = np.where(u == 0, 0, np.where(u < eps, eps, u))
+    c = u.sum(axis=axis)
+    # Guard against an all-zero slice: dividing by 1 leaves it unchanged
+    # instead of producing NaNs.
+    c = np.where(c == 0, 1, c)
+    return u / c, c
+
+def hmm_forwards_np(trans_mat, obs_mat, init_dist, obs_seq):
+    # Forwards filtering for an HMM with categorical observations.
+    #
+    # Args:
+    #   trans_mat: (n_states, n_states) transition matrix, A[i, j] = p(z_t=j | z_{t-1}=i).
+    #   obs_mat:   (n_states, n_obs) emission matrix, B[j, k] = p(x_t=k | z_t=j).
+    #   init_dist: (n_states,) initial state distribution.
+    #   obs_seq:   length-T sequence of integer observation indices.
+    #
+    # Returns:
+    #   log_normalizer: scalar log marginal likelihood, log p(x_{1:T}).
+    #   alpha_hist:     (T, n_states) filtered beliefs alpha_t(j) = p(z_t=j | x_{1:t}).
+    n_states, n_obs = obs_mat.shape
+    seq_len = len(obs_seq)
+
+    alpha_hist = np.zeros((seq_len, n_states))
+    # NOTE(review): ll_hist is allocated but never filled or returned — looks
+    # like dead code; confirm before removing.
+    ll_hist = np.zeros(seq_len)  # loglikelihood history
+
+    # Base case: combine the prior with the local evidence for the first
+    # observation, then normalize to get alpha_1.
+    alpha_n = init_dist * obs_mat[:, obs_seq[0]]
+    alpha_n, cn = normalize_np(alpha_n)
+
+    alpha_hist[0] = alpha_n
+    log_normalizer = np.log(cn)
 
-# Here is a JAX implementation of the forwards algorithm.
+    for t in range(1, seq_len):
+        # Predict step (sum_i alpha_{t-1}(i) A[i, j]) fused with the update
+        # step (elementwise product with the local evidence for obs_seq[t]).
+        alpha_n = obs_mat[:, obs_seq[t]] * (alpha_n[:, None] * trans_mat).sum(axis=0)
+        alpha_n, zn = normalize_np(alpha_n)
+
+        alpha_hist[t] = alpha_n
+        # Accumulate log Z_t so the sum equals the log marginal likelihood.
+        log_normalizer = np.log(zn) + log_normalizer
+
+    return  log_normalizer, alpha_hist
+
+
+# ## Numerically stable implementation 
+# 
+# 
+# 
+# In practice it is more numerically stable to compute
+# the log likelihoods $\ell_t(j) = \log p(\obs_t|\hidden_t=j)$,
+# rather than the likelihoods $\lambda_t(j) = p(\obs_t|\hidden_t=j)$.
+# In this case, we can perform the posterior updating in a numerically stable way as follows.
+# Define $L_t = \max_j \ell_t(j)$ and
+# \begin{align}
+# \tilde{p}(\hidden_t=j,\obs_t|\obs_{1:t-1})
+# &\defeq p(\hidden_t=j|\obs_{1:t-1}) p(\obs_t|\hidden_t=j) e^{-L_t} \\
+#  &= p(\hidden_t=j|\obs_{1:t-1}) e^{\ell_t(j) - L_t}
+# \end{align}
+# Then we have
+# \begin{align}
+# p(\hidden_t=j|\obs_t,\obs_{1:t-1})
+#   &= \frac{1}{\tilde{Z}_t} \tilde{p}(\hidden_t=j,\obs_t|\obs_{1:t-1}) \\
+# \tilde{Z}_t &= \sum_j \tilde{p}(\hidden_t=j,\obs_t|\obs_{1:t-1})
+# = p(\obs_t|\obs_{1:t-1}) e^{-L_t} \\
+# \log Z_t &= \log p(\obs_t|\obs_{1:t-1}) = \log \tilde{Z}_t + L_t
+# \end{align}
+# 
+# Below we show some JAX code that implements this core operation.
+# 
 
 # In[3]:
 
 
-import jsl.hmm.hmm_lib as hmm_lib
-print_source(hmm_lib.hmm_forwards_jax)
-#https://github.com/probml/JSL/blob/main/jsl/hmm/hmm_lib.py#L189
 
+def _condition_on(probs, ll):
+    # Bayesian update of a discrete belief `probs` given per-state
+    # log-likelihoods `ll`, done stably by shifting the log-likelihoods by
+    # their max before exponentiating (log-sum-exp trick).
+    # Returns the normalized posterior and log Z_t = log p(obs_t | obs_{1:t-1}).
+    ll_max = ll.max()
+    new_probs = probs * jnp.exp(ll - ll_max)
+    norm = new_probs.sum()
+    new_probs /= norm
+    # Add the shift back so log_norm is the true log normalization constant.
+    log_norm = jnp.log(norm) + ll_max
+    return new_probs, log_norm
+
+
+# With the above function, we can implement a more numerically stable version of the forwards filter,
+# that works for any likelihood function, as shown below. It takes in the prior predictive distribution,
+# $\alpha_{t|t-1}$,
+# stored in `predicted_probs`, and conditions them on the log-likelihood for each time step $\ell_t$ to get the
+# posterior, $\alpha_t$, stored in `filtered_probs`,
+# which is then converted to the prediction for the next state, $\alpha_{t+1|t}$.
+
+# In[4]:
+
+
+def _predict(probs, A):
+    # One-step-ahead prediction: alpha_{t+1|t}(j) = sum_i alpha_t(i) A[i, j].
+    return A.T @ probs
+
+
+def hmm_filter(initial_distribution,
+               transition_matrix,
+               log_likelihoods):
+    # Numerically stable HMM forwards filter implemented with lax.scan.
+    #
+    # Args:
+    #   initial_distribution: (n_states,) prior over the initial state.
+    #   transition_matrix: (n_states, n_states) matrix, or a (T, n_states,
+    #       n_states) stack for time-varying transitions (selected per step
+    #       via the ndim == 3 check below).
+    #   log_likelihoods: (T, n_states) per-step log-likelihoods ell_t(j).
+    #
+    # Returns:
+    #   log_normalizer: scalar log marginal likelihood, log p(obs_{1:T}).
+    #   filtered_probs: (T, n_states) posteriors alpha_t.
+    #   predicted_probs: (T, n_states) one-step-ahead predictions alpha_{t+1|t}.
+    def _step(carry, t):
+        # carry = (running log normalizer, predicted belief alpha_{t|t-1}).
+        log_normalizer, predicted_probs = carry
+
+        # Get parameters for time t
+        get = lambda x: x[t] if x.ndim == 3 else x
+        A = get(transition_matrix)
+        ll = log_likelihoods[t]
+
+        # Condition on emissions at time t, being careful not to overflow
+        filtered_probs, log_norm = _condition_on(predicted_probs, ll)
+        # Update the log normalizer
+        log_normalizer += log_norm
+        # Predict the next state
+        predicted_probs = _predict(filtered_probs, A)
+
+        return (log_normalizer, predicted_probs), (filtered_probs, predicted_probs)
+
+    num_timesteps = len(log_likelihoods)
+    # The initial "prediction" for t=0 is just the prior over states.
+    carry = (0.0, initial_distribution)
+    (log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(
+        _step, carry, jnp.arange(num_timesteps))
+    return log_normalizer, filtered_probs, predicted_probs
+
+
+# 
+# TODO: check equivalence of these two implementations!

+ 6 - 6
_build/jupyter_execute/chapters/scratch.ipynb

@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "d79c3788",
+   "id": "9d5991e9",
    "metadata": {},
    "source": [
     "(ch:intro)=\n",
@@ -28,7 +28,7 @@
   {
    "cell_type": "code",
    "execution_count": 1,
-   "id": "28c55864",
+   "id": "08375669",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -41,7 +41,7 @@
   {
    "cell_type": "code",
    "execution_count": 2,
-   "id": "b0bc0f28",
+   "id": "b5b2f4ec",
    "metadata": {},
    "outputs": [
     {
@@ -84,7 +84,7 @@
   {
    "cell_type": "code",
    "execution_count": 3,
-   "id": "902c33c0",
+   "id": "06fcf1f1",
    "metadata": {},
    "outputs": [
     {
@@ -113,7 +113,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "945c9164",
+   "id": "8688e6c2",
    "metadata": {},
    "source": [
     "## Images\n",
@@ -151,7 +151,7 @@
     "\n",
     "## Math\n",
     "\n",
-    "Here is $\\N=10$ and blah. $\\floor{42.3}= 42$.\n",
+    "Here is $\\N=10$ and blah. $\\floor{42.3}= 42$. Let's try again.\n",
     "\n",
     "We have $E= mc^2$, and also\n",
     "\n",

+ 1 - 1
_build/jupyter_execute/chapters/scratch.py

@@ -98,7 +98,7 @@ print(jax.devices())
 # 
 # ## Math
 # 
-# Here is $\N=10$ and blah. $\floor{42.3}= 42$.
+# Here is $\N=10$ and blah. $\floor{42.3}= 42$. Let's try again.
 # 
 # We have $E= mc^2$, and also
 # 

+ 434 - 0
_build/jupyter_execute/chapters/scratchpad.ipynb

@@ -0,0 +1,434 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Import standard libraries\n",
+    "\n",
+    "import abc\n",
+    "from dataclasses import dataclass\n",
+    "import functools\n",
+    "from functools import partial\n",
+    "import itertools\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
+    "\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "from jax import lax, vmap, jit, grad\n",
+    "#from jax.scipy.special import logit\n",
+    "#from jax.nn import softmax\n",
+    "import jax.random as jr\n",
+    "\n",
+    "\n",
+    "\n",
+    "import distrax\n",
+    "import optax\n",
+    "\n",
+    "import jsl\n",
+    "import ssm_jax\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import inspect\n",
+    "import inspect as py_inspect\n",
+    "import rich\n",
+    "from rich import inspect as r_inspect\n",
+    "from rich import print as r_print\n",
+    "\n",
+    "def print_source(fname):\n",
+    "    r_print(py_inspect.getsource(fname))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# meta-data does not work yet in VScode\n",
+    "# https://github.com/microsoft/vscode-jupyter/issues/1121\n",
+    "\n",
+    "{\n",
+    "    \"tags\": [\n",
+    "        \"hide-cell\"\n",
+    "    ]\n",
+    "}\n",
+    "\n",
+    "\n",
+    "### Install necessary libraries\n",
+    "\n",
+    "try:\n",
+    "    import jax\n",
+    "except:\n",
+    "    # For cuda version, see https://github.com/google/jax#installation\n",
+    "    %pip install --upgrade \"jax[cpu]\" \n",
+    "    import jax\n",
+    "\n",
+    "try:\n",
+    "    import distrax\n",
+    "except:\n",
+    "    %pip install --upgrade  distrax\n",
+    "    import distrax\n",
+    "\n",
+    "try:\n",
+    "    import jsl\n",
+    "except:\n",
+    "    %pip install git+https://github.com/probml/jsl\n",
+    "    import jsl\n",
+    "\n",
+    "try:\n",
+    "    import rich\n",
+    "except:\n",
+    "    %pip install rich\n",
+    "    import rich\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "{\n",
+    "    \"tags\": [\n",
+    "        \"hide-cell\"\n",
+    "    ]\n",
+    "}\n",
+    "\n",
+    "\n",
+    "### Import standard libraries\n",
+    "\n",
+    "import abc\n",
+    "from dataclasses import dataclass\n",
+    "import functools\n",
+    "import itertools\n",
+    "\n",
+    "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "\n",
+    "\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "from jax import lax, vmap, jit, grad\n",
+    "from jax.scipy.special import logit\n",
+    "from jax.nn import softmax\n",
+    "from functools import partial\n",
+    "from jax.random import PRNGKey, split\n",
+    "\n",
+    "import inspect\n",
+    "import inspect as py_inspect\n",
+    "import rich\n",
+    "from rich import inspect as r_inspect\n",
+    "from rich import print as r_print\n",
+    "\n",
+    "def print_source(fname):\n",
+    "    r_print(py_inspect.getsource(fname))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">class GaussianHMM<span style=\"font-weight: bold\">(</span>BaseHMM<span style=\"font-weight: bold\">)</span>:\n",
+       "    def __init__<span style=\"font-weight: bold\">(</span>self,\n",
+       "                 initial_probabilities,\n",
+       "                 transition_matrix,\n",
+       "                 emission_means,\n",
+       "                 emission_covariance_matrices<span style=\"font-weight: bold\">)</span>:\n",
+       "        <span style=\"color: #008000; text-decoration-color: #008000\">\"\"</span>\"_summary_\n",
+       "\n",
+       "        Args:\n",
+       "            initial_probabilities <span style=\"font-weight: bold\">(</span>_type_<span style=\"font-weight: bold\">)</span>: _description_\n",
+       "            transition_matrix <span style=\"font-weight: bold\">(</span>_type_<span style=\"font-weight: bold\">)</span>: _description_\n",
+       "            emission_means <span style=\"font-weight: bold\">(</span>_type_<span style=\"font-weight: bold\">)</span>: _description_\n",
+       "            emission_covariance_matrices <span style=\"font-weight: bold\">(</span>_type_<span style=\"font-weight: bold\">)</span>: _description_\n",
+       "        <span style=\"color: #008000; text-decoration-color: #008000\">\"\"</span>\"\n",
+       "        super<span style=\"font-weight: bold\">()</span>.__init__<span style=\"font-weight: bold\">(</span>initial_probabilities,\n",
+       "                         transition_matrix<span style=\"font-weight: bold\">)</span>\n",
+       "\n",
+       "        self._emission_distribution = tfd.MultivariateNormalFullCovariance<span style=\"font-weight: bold\">(</span>\n",
+       "            emission_means, emission_covariance_matrices<span style=\"font-weight: bold\">)</span>\n",
+       "\n",
+       "    @classmethod\n",
+       "    def random_initialization<span style=\"font-weight: bold\">(</span>cls, key, num_states, emission_dim<span style=\"font-weight: bold\">)</span>:\n",
+       "        key1, key2, key3 = jr.split<span style=\"font-weight: bold\">(</span>key, <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">3</span><span style=\"font-weight: bold\">)</span>\n",
+       "        initial_probs = jr.dirichlet<span style=\"font-weight: bold\">(</span>key1, jnp.ones<span style=\"font-weight: bold\">(</span>num_states<span style=\"font-weight: bold\">))</span>\n",
+       "        transition_matrix = jr.dirichlet<span style=\"font-weight: bold\">(</span>key2, jnp.ones<span style=\"font-weight: bold\">(</span>num_states<span style=\"font-weight: bold\">)</span>, <span style=\"font-weight: bold\">(</span>num_states,<span style=\"font-weight: bold\">))</span>\n",
+       "        emission_means = jr.normal<span style=\"font-weight: bold\">(</span>key3, <span style=\"font-weight: bold\">(</span>num_states, emission_dim<span style=\"font-weight: bold\">))</span>\n",
+       "        emission_covs = jnp.tile<span style=\"font-weight: bold\">(</span>jnp.eye<span style=\"font-weight: bold\">(</span>emission_dim<span style=\"font-weight: bold\">)</span>, <span style=\"font-weight: bold\">(</span>num_states, <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">1</span>, <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">1</span><span style=\"font-weight: bold\">))</span>\n",
+       "        return cls<span style=\"font-weight: bold\">(</span>initial_probs, transition_matrix, emission_means, emission_covs<span style=\"font-weight: bold\">)</span>\n",
+       "\n",
+       "    # Properties to get various parameters of the model\n",
+       "    @property\n",
+       "    def emission_distribution<span style=\"font-weight: bold\">(</span>self<span style=\"font-weight: bold\">)</span>:\n",
+       "        return self._emission_distribution\n",
+       "\n",
+       "    @property\n",
+       "    def emission_means<span style=\"font-weight: bold\">(</span>self<span style=\"font-weight: bold\">)</span>:\n",
+       "        return self.emission_distribution.mean<span style=\"font-weight: bold\">()</span>\n",
+       "\n",
+       "    @property\n",
+       "    def emission_covariance_matrices<span style=\"font-weight: bold\">(</span>self<span style=\"font-weight: bold\">)</span>:\n",
+       "        return self.emission_distribution.covariance<span style=\"font-weight: bold\">()</span>\n",
+       "\n",
+       "    @property\n",
+       "    def unconstrained_params<span style=\"font-weight: bold\">(</span>self<span style=\"font-weight: bold\">)</span>:\n",
+       "        <span style=\"color: #008000; text-decoration-color: #008000\">\"\"</span>\"Helper property to get a PyTree of unconstrained parameters.\n",
+       "        <span style=\"color: #008000; text-decoration-color: #008000\">\"\"</span>\"\n",
+       "        return tfb.SoftmaxCentered<span style=\"font-weight: bold\">()</span>.inverse<span style=\"font-weight: bold\">(</span>self.initial_probabilities<span style=\"font-weight: bold\">)</span>, \\\n",
+       "               tfb.SoftmaxCentered<span style=\"font-weight: bold\">()</span>.inverse<span style=\"font-weight: bold\">(</span>self.transition_matrix<span style=\"font-weight: bold\">)</span>, \\\n",
+       "               self.emission_means, \\\n",
+       "               PSDToRealBijector.forward<span style=\"font-weight: bold\">(</span>self.emission_covariance_matrices<span style=\"font-weight: bold\">)</span>\n",
+       "\n",
+       "    @classmethod\n",
+       "    def from_unconstrained_params<span style=\"font-weight: bold\">(</span>cls, unconstrained_params, hypers<span style=\"font-weight: bold\">)</span>:\n",
+       "        initial_probabilities = tfb.SoftmaxCentered<span style=\"font-weight: bold\">()</span>.forward<span style=\"font-weight: bold\">(</span>unconstrained_params<span style=\"font-weight: bold\">[</span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">0</span><span style=\"font-weight: bold\">])</span>\n",
+       "        transition_matrix = tfb.SoftmaxCentered<span style=\"font-weight: bold\">()</span>.forward<span style=\"font-weight: bold\">(</span>unconstrained_params<span style=\"font-weight: bold\">[</span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">1</span><span style=\"font-weight: bold\">])</span>\n",
+       "        emission_means = unconstrained_params<span style=\"font-weight: bold\">[</span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">2</span><span style=\"font-weight: bold\">]</span>\n",
+       "        emission_covs = PSDToRealBijector.inverse<span style=\"font-weight: bold\">(</span>unconstrained_params<span style=\"font-weight: bold\">[</span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">3</span><span style=\"font-weight: bold\">])</span>\n",
+       "        return cls<span style=\"font-weight: bold\">(</span>initial_probabilities, transition_matrix, emission_means, emission_covs, \n",
+       "*hypers<span style=\"font-weight: bold\">)</span>\n",
+       "\n",
+       "</pre>\n"
+      ],
+      "text/plain": [
+       "<rich.jupyter.JupyterRenderable at 0x7fdea0b54460>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "import ssm_jax\n",
+    "from ssm_jax.hmm.models import GaussianHMM\n",
+    "\n",
+    "print_source(GaussianHMM)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">    def sample<span style=\"font-weight: bold\">(</span>self, key, num_timesteps<span style=\"font-weight: bold\">)</span>:\n",
+       "        <span style=\"color: #008000; text-decoration-color: #008000\">\"\"</span>\"Sample a sequence of latent states and emissions.\n",
+       "\n",
+       "        Args:\n",
+       "            key <span style=\"font-weight: bold\">(</span>_type_<span style=\"font-weight: bold\">)</span>: _description_\n",
+       "            num_timesteps <span style=\"font-weight: bold\">(</span>_type_<span style=\"font-weight: bold\">)</span>: _description_\n",
+       "        <span style=\"color: #008000; text-decoration-color: #008000\">\"\"</span>\"\n",
+       "        def _step<span style=\"font-weight: bold\">(</span>state, key<span style=\"font-weight: bold\">)</span>:\n",
+       "            key1, key2 = jr.split<span style=\"font-weight: bold\">(</span>key, <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">2</span><span style=\"font-weight: bold\">)</span>\n",
+       "            emission = self.emission_distribution.sample<span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">seed</span>=<span style=\"color: #800080; text-decoration-color: #800080\">key1</span><span style=\"font-weight: bold\">)</span>\n",
+       "            next_state = self.transition_distribution.sample<span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">seed</span>=<span style=\"color: #800080; text-decoration-color: #800080\">key2</span><span style=\"font-weight: bold\">)</span>\n",
+       "            return next_state, <span style=\"font-weight: bold\">(</span>state, emission<span style=\"font-weight: bold\">)</span>\n",
+       "\n",
+       "        # Sample the initial state\n",
+       "        key1, key = jr.split<span style=\"font-weight: bold\">(</span>key, <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">2</span><span style=\"font-weight: bold\">)</span>\n",
+       "        initial_state = self.initial_distribution.sample<span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">seed</span>=<span style=\"color: #800080; text-decoration-color: #800080\">key1</span><span style=\"font-weight: bold\">)</span>\n",
+       "\n",
+       "        # Sample the remaining emissions and states\n",
+       "        keys = jr.split<span style=\"font-weight: bold\">(</span>key, num_timesteps<span style=\"font-weight: bold\">)</span>\n",
+       "        _, <span style=\"font-weight: bold\">(</span>states, emissions<span style=\"font-weight: bold\">)</span> = lax.scan<span style=\"font-weight: bold\">(</span>_step, initial_state, keys<span style=\"font-weight: bold\">)</span>\n",
+       "        return states, emissions\n",
+       "\n",
+       "</pre>\n"
+      ],
+      "text/plain": [
+       "<rich.jupyter.JupyterRenderable at 0x7fde91c661c0>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Set dimensions\n",
+    "num_states = 5\n",
+    "emission_dim = 2\n",
+    "\n",
+    "# Specify parameters of the HMM\n",
+    "initial_probs = jnp.ones(num_states) / num_states\n",
+    "transition_matrix = 0.95 * jnp.eye(num_states) + 0.05 * jnp.roll(jnp.eye(num_states), 1, axis=1)\n",
+    "emission_means = jnp.column_stack([\n",
+    "    jnp.cos(jnp.linspace(0, 2 * jnp.pi, num_states+1))[:-1],\n",
+    "    jnp.sin(jnp.linspace(0, 2 * jnp.pi, num_states+1))[:-1]\n",
+    "])\n",
+    "emission_covs = jnp.tile(0.1**2 * jnp.eye(emission_dim), (num_states, 1, 1))\n",
+    "\n",
+    "hmm = GaussianHMM(initial_probs,\n",
+    "                       transition_matrix,\n",
+    "                       emission_means, \n",
+    "                       emission_covs)\n",
+    "\n",
+    "print_source(hmm.sample)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<distrax._src.utils.hmm.HMM object at 0x7fde82c856d0>\n"
+     ]
+    }
+   ],
+   "source": [
+    "import distrax\n",
+    "from distrax import HMM\n",
+    "\n",
+    "A = np.array([\n",
+    "    [0.95, 0.05],\n",
+    "    [0.10, 0.90]\n",
+    "])\n",
+    "\n",
+    "# observation matrix\n",
+    "B = np.array([\n",
+    "    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die\n",
+    "    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die\n",
+    "])\n",
+    "\n",
+    "pi = np.array([0.5, 0.5])\n",
+    "\n",
+    "(nstates, nobs) = np.shape(B)\n",
+    "\n",
+    "hmm = HMM(trans_dist=distrax.Categorical(probs=A),\n",
+    "            init_dist=distrax.Categorical(probs=pi),\n",
+    "            obs_dist=distrax.Categorical(probs=B))\n",
+    "\n",
+    "print(hmm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">  def sample<span style=\"font-weight: bold\">(</span>self,\n",
+       "             *,\n",
+       "             seed: chex.PRNGKey,\n",
+       "             seq_len: chex.Array<span style=\"font-weight: bold\">)</span> -&gt; Tuple:\n",
+       "    <span style=\"color: #008000; text-decoration-color: #008000\">\"\"</span>\"Sample from this HMM.\n",
+       "\n",
+       "    Samples an observation of given length according to this\n",
+       "    Hidden Markov Model and gives the sequence of the hidden states\n",
+       "    as well as the observation.\n",
+       "\n",
+       "    Args:\n",
+       "      seed: Random key of shape <span style=\"font-weight: bold\">(</span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">2</span>,<span style=\"font-weight: bold\">)</span> and dtype uint32.\n",
+       "      seq_len: The length of the observation sequence.\n",
+       "\n",
+       "    Returns:\n",
+       "      Tuple of hidden state sequence, and observation sequence.\n",
+       "    <span style=\"color: #008000; text-decoration-color: #008000\">\"\"</span>\"\n",
+       "    rng_key, rng_init = jax.random.split<span style=\"font-weight: bold\">(</span>seed<span style=\"font-weight: bold\">)</span>\n",
+       "    initial_state = self._init_dist.sample<span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">seed</span>=<span style=\"color: #800080; text-decoration-color: #800080\">rng_init</span><span style=\"font-weight: bold\">)</span>\n",
+       "\n",
+       "    def draw_state<span style=\"font-weight: bold\">(</span>prev_state, key<span style=\"font-weight: bold\">)</span>:\n",
+       "      state = self._trans_dist.sample<span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">seed</span>=<span style=\"color: #800080; text-decoration-color: #800080\">key</span><span style=\"font-weight: bold\">)</span>\n",
+       "      return state, state\n",
+       "\n",
+       "    rng_state, rng_obs = jax.random.split<span style=\"font-weight: bold\">(</span>rng_key<span style=\"font-weight: bold\">)</span>\n",
+       "    keys = jax.random.split<span style=\"font-weight: bold\">(</span>rng_state, seq_len - <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">1</span><span style=\"font-weight: bold\">)</span>\n",
+       "    _, states = jax.lax.scan<span style=\"font-weight: bold\">(</span>draw_state, initial_state, keys<span style=\"font-weight: bold\">)</span>\n",
+       "    states = jnp.append<span style=\"font-weight: bold\">(</span>initial_state, states<span style=\"font-weight: bold\">)</span>\n",
+       "\n",
+       "    def draw_obs<span style=\"font-weight: bold\">(</span>state, key<span style=\"font-weight: bold\">)</span>:\n",
+       "      return self._obs_dist.sample<span style=\"font-weight: bold\">(</span><span style=\"color: #808000; text-decoration-color: #808000\">seed</span>=<span style=\"color: #800080; text-decoration-color: #800080\">key</span><span style=\"font-weight: bold\">)</span>\n",
+       "\n",
+       "    keys = jax.random.split<span style=\"font-weight: bold\">(</span>rng_obs, seq_len<span style=\"font-weight: bold\">)</span>\n",
+       "    obs_seq = jax.vmap<span style=\"font-weight: bold\">(</span>draw_obs, <span style=\"color: #808000; text-decoration-color: #808000\">in_axes</span>=<span style=\"font-weight: bold\">(</span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">0</span>, <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">0</span><span style=\"font-weight: bold\">))(</span>states, keys<span style=\"font-weight: bold\">)</span>\n",
+       "\n",
+       "    return states, obs_seq\n",
+       "\n",
+       "</pre>\n"
+      ],
+      "text/plain": [
+       "<rich.jupyter.JupyterRenderable at 0x7fdea03643d0>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "print_source(hmm.sample)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "6407c60499271029b671b4ff687c4ed4626355c45fd34c44476827f4be42c4d7"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.9.2 ('spyder-dev')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

+ 0 - 0
_build/jupyter_execute/chapters/scratchpad.py


部分文件因文件數量過多而無法顯示