nlds.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # In[1]:
  4. # meta-data does not work yet in VScode
  5. # https://github.com/microsoft/vscode-jupyter/issues/1121
  6. {
  7. "tags": [
  8. "hide-cell"
  9. ]
  10. }
  11. ### Install necessary libraries
  12. try:
  13. import jax
  14. except:
  15. # For cuda version, see https://github.com/google/jax#installation
  16. get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
  17. import jax
  18. try:
  19. import distrax
  20. except:
  21. get_ipython().run_line_magic('pip', 'install --upgrade distrax')
  22. import distrax
  23. try:
  24. import jsl
  25. except:
  26. get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
  27. import jsl
  28. try:
  29. import rich
  30. except:
  31. get_ipython().run_line_magic('pip', 'install rich')
  32. import rich
  33. # In[2]:
  34. {
  35. "tags": [
  36. "hide-cell"
  37. ]
  38. }
  39. ### Import standard libraries
  40. import abc
  41. from dataclasses import dataclass
  42. import functools
  43. import itertools
  44. from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
  45. import matplotlib.pyplot as plt
  46. import numpy as np
  47. import jax
  48. import jax.numpy as jnp
  49. from jax import lax, vmap, jit, grad
  50. from jax.scipy.special import logit
  51. from jax.nn import softmax
  52. from functools import partial
  53. from jax.random import PRNGKey, split
  54. import inspect
  55. import inspect as py_inspect
  56. import rich
  57. from rich import inspect as r_inspect
  58. from rich import print as r_print
  59. def print_source(fname):
  60. r_print(py_inspect.getsource(fname))
  61. # ```{math}
  62. #
  63. # \newcommand\floor[1]{\lfloor#1\rfloor}
  64. #
  65. # \newcommand{\real}{\mathbb{R}}
  66. #
  67. # % Numbers
  68. # \newcommand{\vzero}{\boldsymbol{0}}
  69. # \newcommand{\vone}{\boldsymbol{1}}
  70. #
  71. # % Greek https://www.latex-tutorial.com/symbols/greek-alphabet/
  72. # \newcommand{\valpha}{\boldsymbol{\alpha}}
  73. # \newcommand{\vbeta}{\boldsymbol{\beta}}
  74. # \newcommand{\vchi}{\boldsymbol{\chi}}
  75. # \newcommand{\vdelta}{\boldsymbol{\delta}}
  76. # \newcommand{\vDelta}{\boldsymbol{\Delta}}
  77. # \newcommand{\vepsilon}{\boldsymbol{\epsilon}}
  78. # \newcommand{\vzeta}{\boldsymbol{\zeta}}
  79. # \newcommand{\vXi}{\boldsymbol{\Xi}}
  80. # \newcommand{\vell}{\boldsymbol{\ell}}
  81. # \newcommand{\veta}{\boldsymbol{\eta}}
  82. # %\newcommand{\vEta}{\boldsymbol{\Eta}}
  83. # \newcommand{\vgamma}{\boldsymbol{\gamma}}
  84. # \newcommand{\vGamma}{\boldsymbol{\Gamma}}
  85. # \newcommand{\vmu}{\boldsymbol{\mu}}
  86. # \newcommand{\vmut}{\boldsymbol{\tilde{\mu}}}
  87. # \newcommand{\vnu}{\boldsymbol{\nu}}
  88. # \newcommand{\vkappa}{\boldsymbol{\kappa}}
  89. # \newcommand{\vlambda}{\boldsymbol{\lambda}}
  90. # \newcommand{\vLambda}{\boldsymbol{\Lambda}}
  91. # \newcommand{\vLambdaBar}{\overline{\vLambda}}
  92. # %\newcommand{\vnu}{\boldsymbol{\nu}}
  93. # \newcommand{\vomega}{\boldsymbol{\omega}}
  94. # \newcommand{\vOmega}{\boldsymbol{\Omega}}
  95. # \newcommand{\vphi}{\boldsymbol{\phi}}
  96. # \newcommand{\vvarphi}{\boldsymbol{\varphi}}
  97. # \newcommand{\vPhi}{\boldsymbol{\Phi}}
  98. # \newcommand{\vpi}{\boldsymbol{\pi}}
  99. # \newcommand{\vPi}{\boldsymbol{\Pi}}
  100. # \newcommand{\vpsi}{\boldsymbol{\psi}}
  101. # \newcommand{\vPsi}{\boldsymbol{\Psi}}
  102. # \newcommand{\vrho}{\boldsymbol{\rho}}
  103. # \newcommand{\vtheta}{\boldsymbol{\theta}}
  104. # \newcommand{\vthetat}{\boldsymbol{\tilde{\theta}}}
  105. # \newcommand{\vTheta}{\boldsymbol{\Theta}}
  106. # \newcommand{\vsigma}{\boldsymbol{\sigma}}
  107. # \newcommand{\vSigma}{\boldsymbol{\Sigma}}
  108. # \newcommand{\vSigmat}{\boldsymbol{\tilde{\Sigma}}}
  109. # \newcommand{\vsigmoid}{\vsigma}
  110. # \newcommand{\vtau}{\boldsymbol{\tau}}
  111. # \newcommand{\vxi}{\boldsymbol{\xi}}
  112. #
  113. #
  114. # % Lower Roman (Vectors)
  115. # \newcommand{\va}{\mathbf{a}}
  116. # \newcommand{\vb}{\mathbf{b}}
  117. # \newcommand{\vBt}{\mathbf{\tilde{B}}}
  118. # \newcommand{\vc}{\mathbf{c}}
  119. # \newcommand{\vct}{\mathbf{\tilde{c}}}
  120. # \newcommand{\vd}{\mathbf{d}}
  121. # \newcommand{\ve}{\mathbf{e}}
  122. # \newcommand{\vf}{\mathbf{f}}
  123. # \newcommand{\vg}{\mathbf{g}}
  124. # \newcommand{\vh}{\mathbf{h}}
  125. # %\newcommand{\myvh}{\mathbf{h}}
  126. # \newcommand{\vi}{\mathbf{i}}
  127. # \newcommand{\vj}{\mathbf{j}}
  128. # \newcommand{\vk}{\mathbf{k}}
  129. # \newcommand{\vl}{\mathbf{l}}
  130. # \newcommand{\vm}{\mathbf{m}}
  131. # \newcommand{\vn}{\mathbf{n}}
  132. # \newcommand{\vo}{\mathbf{o}}
  133. # \newcommand{\vp}{\mathbf{p}}
  134. # \newcommand{\vq}{\mathbf{q}}
  135. # \newcommand{\vr}{\mathbf{r}}
  136. # \newcommand{\vs}{\mathbf{s}}
  137. # \newcommand{\vt}{\mathbf{t}}
  138. # \newcommand{\vu}{\mathbf{u}}
  139. # \newcommand{\vv}{\mathbf{v}}
  140. # \newcommand{\vw}{\mathbf{w}}
  141. # \newcommand{\vws}{\vw_s}
  142. # \newcommand{\vwt}{\mathbf{\tilde{w}}}
  143. # \newcommand{\vWt}{\mathbf{\tilde{W}}}
  144. # \newcommand{\vwh}{\hat{\vw}}
  145. # \newcommand{\vx}{\mathbf{x}}
  146. # %\newcommand{\vx}{\mathbf{x}}
  147. # \newcommand{\vxt}{\mathbf{\tilde{x}}}
  148. # \newcommand{\vy}{\mathbf{y}}
  149. # \newcommand{\vyt}{\mathbf{\tilde{y}}}
  150. # \newcommand{\vz}{\mathbf{z}}
  151. # %\newcommand{\vzt}{\mathbf{\tilde{z}}}
  152. #
  153. #
  154. # % Upper Roman (Matrices)
  155. # \newcommand{\vA}{\mathbf{A}}
  156. # \newcommand{\vB}{\mathbf{B}}
  157. # \newcommand{\vC}{\mathbf{C}}
  158. # \newcommand{\vD}{\mathbf{D}}
  159. # \newcommand{\vE}{\mathbf{E}}
  160. # \newcommand{\vF}{\mathbf{F}}
  161. # \newcommand{\vG}{\mathbf{G}}
  162. # \newcommand{\vH}{\mathbf{H}}
  163. # \newcommand{\vI}{\mathbf{I}}
  164. # \newcommand{\vJ}{\mathbf{J}}
  165. # \newcommand{\vK}{\mathbf{K}}
  166. # \newcommand{\vL}{\mathbf{L}}
  167. # \newcommand{\vM}{\mathbf{M}}
  168. # \newcommand{\vMt}{\mathbf{\tilde{M}}}
  169. # \newcommand{\vN}{\mathbf{N}}
  170. # \newcommand{\vO}{\mathbf{O}}
  171. # \newcommand{\vP}{\mathbf{P}}
  172. # \newcommand{\vQ}{\mathbf{Q}}
  173. # \newcommand{\vR}{\mathbf{R}}
  174. # \newcommand{\vS}{\mathbf{S}}
  175. # \newcommand{\vT}{\mathbf{T}}
  176. # \newcommand{\vU}{\mathbf{U}}
  177. # \newcommand{\vV}{\mathbf{V}}
  178. # \newcommand{\vW}{\mathbf{W}}
  179. # \newcommand{\vX}{\mathbf{X}}
  180. # %\newcommand{\vXs}{\vX_{\vs}}
  181. # \newcommand{\vXs}{\vX_{s}}
  182. # \newcommand{\vXt}{\mathbf{\tilde{X}}}
  183. # \newcommand{\vY}{\mathbf{Y}}
  184. # \newcommand{\vZ}{\mathbf{Z}}
  185. # \newcommand{\vZt}{\mathbf{\tilde{Z}}}
  186. # \newcommand{\vzt}{\mathbf{\tilde{z}}}
  187. #
  188. #
  189. # %%%%
  190. # \newcommand{\hidden}{\vz}
  191. # \newcommand{\hid}{\hidden}
  192. # \newcommand{\observed}{\vy}
  193. # \newcommand{\obs}{\observed}
  194. # \newcommand{\inputs}{\vu}
  195. # \newcommand{\input}{\inputs}
  196. #
  197. # \newcommand{\hmmTrans}{\vA}
  198. # \newcommand{\hmmObs}{\vB}
  199. # \newcommand{\hmmInit}{\vpi}
  200. # \newcommand{\hmmhid}{\hidden}
  201. # \newcommand{\hmmobs}{\obs}
  202. #
  203. # \newcommand{\ldsDyn}{\vA}
  204. # \newcommand{\ldsObs}{\vC}
  205. # \newcommand{\ldsDynIn}{\vB}
  206. # \newcommand{\ldsObsIn}{\vD}
  207. # \newcommand{\ldsDynNoise}{\vQ}
  208. # \newcommand{\ldsObsNoise}{\vR}
  209. #
  210. # \newcommand{\ssmDynFn}{f}
  211. # \newcommand{\ssmObsFn}{h}
  212. #
  213. #
  214. # %%%
  215. # \newcommand{\gauss}{\mathcal{N}}
  216. #
  217. # \newcommand{\diag}{\mathrm{diag}}
  218. # ```
  219. #
  220. # (sec:nlds-intro)=
  221. # # Nonlinear Gaussian SSMs
  222. #
  223. # In this section, we consider SSMs in which the dynamics and/or observation models are nonlinear,
  224. # but the process noise and observation noise are Gaussian.
  225. # That is,
  226. # \begin{align}
  227. # \hmmhid_t &= \ssmDynFn(\hmmhid_{t-1}, \inputs_t) + \vepsilon_t \\
  228. # \hmmobs_t &= \ssmObsFn(\hmmhid_{t}, \inputs_t) + \veta_t
  229. # \end{align}
  230. # where $\vepsilon_t \sim \gauss(\vzero,\vQ_t)$
  231. # and $\veta_t \sim \gauss(\vzero,\vR_t)$.
  232. # This is a very widely used model class. We give some examples below.
  233. # (sec:pendulum)=
  234. # ## Example: tracking a 1d pendulum
  235. #
  236. # ```{figure} /figures/pendulum.png
  237. # :scale: 50%
  238. # :name: fig:pendulum
  239. #
  240. # Illustration of a pendulum swinging.
  241. # $g$ is the force of gravity,
  242. # $w(t)$ is a random external force,
  243. # and $\alpha$ is the angle with respect to the vertical.
  244. # Based on {cite}`Sarkka13` fig 3.10.
  245. # ```
  246. #
  247. #
  248. # % Sarkka p45, p74
  249. # Consider a simple pendulum of unit mass and length swinging from
  250. # a fixed attachment, as in
  251. # {numref}`Figure %s <fig:pendulum>`.
  252. # Such an object is in principle entirely deterministic in its behavior.
  253. # However, in the real world, there are often unknown forces at work
  254. # (e.g., air turbulence, friction).
  255. # We will model these by a continuous-time random Gaussian noise process $w(t)$.
  256. # This gives rise to the following differential equation:
  257. # \begin{align}
  258. # \frac{d^2 \alpha}{d t^2}
  259. # = -g \sin(\alpha) + w(t)
  260. # \end{align}
  261. # We can write this as a nonlinear SSM by defining the state to be
  262. # $z_1(t) = \alpha(t)$ and $z_2(t) = d\alpha(t)/dt$.
  263. # Thus
  264. # \begin{align}
  265. # \frac{d \vz}{dt}
  266. # = \begin{pmatrix} z_2 \\ -g \sin(z_1) \end{pmatrix}
  267. # + \begin{pmatrix} 0 \\ 1 \end{pmatrix} w(t)
  268. # \end{align}
  269. # If we discretize this with step size $\Delta$,
  270. # we get the following
  271. # formulation {cite}`Sarkka13` p74:
  272. # \begin{align}
  273. # \underbrace{
  274. # \begin{pmatrix} z_{1,t} \\ z_{2,t} \end{pmatrix}
  275. # }_{\hmmhid_t}
  276. # =
  277. # \underbrace{
  278. # \begin{pmatrix} z_{1,t-1} + z_{2,t-1} \Delta \\
  279. # z_{2,t-1} -g \sin(z_{1,t-1}) \Delta \end{pmatrix}
  280. # }_{\vf(\hmmhid_{t-1})}
  281. # +\vq_{t-1}
  282. # \end{align}
  283. # where $\vq_{t-1} \sim \gauss(\vzero,\vQ)$ with
  284. # \begin{align}
  285. # \vQ = q^c \begin{pmatrix}
  286. # \frac{\Delta^3}{3} & \frac{\Delta^2}{2} \\
  287. # \frac{\Delta^2}{2} & \Delta
  288. # \end{pmatrix}
  289. # \end{align}
  290. # where $q^c$ is the spectral density (continuous-time variance)
  291. # of the continuous-time noise process.
  292. #
  293. #
  294. # If we observe the angular position, we
  295. # get the linear observation model
  296. # \begin{align}
  297. # y_t = \alpha_t + r_t = h(\hmmhid_t) + r_t
  298. # \end{align}
  299. # where $h(\hmmhid_t) = z_{1,t}$
  300. # and $r_t$ is the observation noise.
  301. # If we only observe the horizontal position,
  302. # we get the nonlinear observation model
  303. # \begin{align}
  304. # y_t = \sin(\alpha_t) + r_t = h(\hmmhid_t) + r_t
  305. # \end{align}
  306. # where $h(\hmmhid_t) = \sin(z_{1,t})$.
  307. #
  308. #
  309. #
  310. #
  311. #
  312. #