#!/usr/bin/env python
# coding: utf-8

# In[1]:

# Cell metadata does not work yet in VS Code:
# https://github.com/microsoft/vscode-jupyter/issues/1121
{
    "tags": [
        "hide-cell"
    ]
}

### Install necessary libraries

try:
    import jax
except ModuleNotFoundError:
    # For the CUDA version, see https://github.com/google/jax#installation
    get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
    import jax

try:
    import distrax
except ModuleNotFoundError:
    get_ipython().run_line_magic('pip', 'install --upgrade distrax')
    import distrax

try:
    import jsl
except ModuleNotFoundError:
    get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
    import jsl

try:
    import rich
except ModuleNotFoundError:
    get_ipython().run_line_magic('pip', 'install rich')
    import rich
# In[2]:

{
    "tags": [
        "hide-cell"
    ]
}

### Import standard libraries

import abc
from dataclasses import dataclass
import functools
import itertools
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.scipy.special import logit
from jax.nn import softmax
from functools import partial
from jax.random import PRNGKey, split

import inspect as py_inspect
import rich
from rich import inspect as r_inspect
from rich import print as r_print

def print_source(fname):
    """Pretty-print the source code of a function using rich."""
    r_print(py_inspect.getsource(fname))
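
# For example (illustrative usage; any importable function works here):
print_source(softmax)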
# ```{math}
#
# \newcommand\floor[1]{\lfloor#1\rfloor}
#
# \newcommand{\real}{\mathbb{R}}
#
# % Numbers
# \newcommand{\vzero}{\boldsymbol{0}}
# \newcommand{\vone}{\boldsymbol{1}}
#
# % Greek https://www.latex-tutorial.com/symbols/greek-alphabet/
# \newcommand{\valpha}{\boldsymbol{\alpha}}
# \newcommand{\vbeta}{\boldsymbol{\beta}}
# \newcommand{\vchi}{\boldsymbol{\chi}}
# \newcommand{\vdelta}{\boldsymbol{\delta}}
# \newcommand{\vDelta}{\boldsymbol{\Delta}}
# \newcommand{\vepsilon}{\boldsymbol{\epsilon}}
# \newcommand{\vzeta}{\boldsymbol{\zeta}}
# \newcommand{\vXi}{\boldsymbol{\Xi}}
# \newcommand{\vell}{\boldsymbol{\ell}}
# \newcommand{\veta}{\boldsymbol{\eta}}
# %\newcommand{\vEta}{\boldsymbol{\Eta}}
# \newcommand{\vgamma}{\boldsymbol{\gamma}}
# \newcommand{\vGamma}{\boldsymbol{\Gamma}}
# \newcommand{\vmu}{\boldsymbol{\mu}}
# \newcommand{\vmut}{\boldsymbol{\tilde{\mu}}}
# \newcommand{\vnu}{\boldsymbol{\nu}}
# \newcommand{\vkappa}{\boldsymbol{\kappa}}
# \newcommand{\vlambda}{\boldsymbol{\lambda}}
# \newcommand{\vLambda}{\boldsymbol{\Lambda}}
# \newcommand{\vLambdaBar}{\overline{\vLambda}}
# %\newcommand{\vnu}{\boldsymbol{\nu}}
# \newcommand{\vomega}{\boldsymbol{\omega}}
# \newcommand{\vOmega}{\boldsymbol{\Omega}}
# \newcommand{\vphi}{\boldsymbol{\phi}}
# \newcommand{\vvarphi}{\boldsymbol{\varphi}}
# \newcommand{\vPhi}{\boldsymbol{\Phi}}
# \newcommand{\vpi}{\boldsymbol{\pi}}
# \newcommand{\vPi}{\boldsymbol{\Pi}}
# \newcommand{\vpsi}{\boldsymbol{\psi}}
# \newcommand{\vPsi}{\boldsymbol{\Psi}}
# \newcommand{\vrho}{\boldsymbol{\rho}}
# \newcommand{\vtheta}{\boldsymbol{\theta}}
# \newcommand{\vthetat}{\boldsymbol{\tilde{\theta}}}
# \newcommand{\vTheta}{\boldsymbol{\Theta}}
# \newcommand{\vsigma}{\boldsymbol{\sigma}}
# \newcommand{\vSigma}{\boldsymbol{\Sigma}}
# \newcommand{\vSigmat}{\boldsymbol{\tilde{\Sigma}}}
# \newcommand{\vsigmoid}{\vsigma}
# \newcommand{\vtau}{\boldsymbol{\tau}}
# \newcommand{\vxi}{\boldsymbol{\xi}}
#
#
# % Lower Roman (Vectors)
# \newcommand{\va}{\mathbf{a}}
# \newcommand{\vb}{\mathbf{b}}
# \newcommand{\vBt}{\mathbf{\tilde{B}}}
# \newcommand{\vc}{\mathbf{c}}
# \newcommand{\vct}{\mathbf{\tilde{c}}}
# \newcommand{\vd}{\mathbf{d}}
# \newcommand{\ve}{\mathbf{e}}
# \newcommand{\vf}{\mathbf{f}}
# \newcommand{\vg}{\mathbf{g}}
# \newcommand{\vh}{\mathbf{h}}
# %\newcommand{\myvh}{\mathbf{h}}
# \newcommand{\vi}{\mathbf{i}}
# \newcommand{\vj}{\mathbf{j}}
# \newcommand{\vk}{\mathbf{k}}
# \newcommand{\vl}{\mathbf{l}}
# \newcommand{\vm}{\mathbf{m}}
# \newcommand{\vn}{\mathbf{n}}
# \newcommand{\vo}{\mathbf{o}}
# \newcommand{\vp}{\mathbf{p}}
# \newcommand{\vq}{\mathbf{q}}
# \newcommand{\vr}{\mathbf{r}}
# \newcommand{\vs}{\mathbf{s}}
# \newcommand{\vt}{\mathbf{t}}
# \newcommand{\vu}{\mathbf{u}}
# \newcommand{\vv}{\mathbf{v}}
# \newcommand{\vw}{\mathbf{w}}
# \newcommand{\vws}{\vw_s}
# \newcommand{\vwt}{\mathbf{\tilde{w}}}
# \newcommand{\vWt}{\mathbf{\tilde{W}}}
# \newcommand{\vwh}{\hat{\vw}}
# \newcommand{\vx}{\mathbf{x}}
# %\newcommand{\vx}{\mathbf{x}}
# \newcommand{\vxt}{\mathbf{\tilde{x}}}
# \newcommand{\vy}{\mathbf{y}}
# \newcommand{\vyt}{\mathbf{\tilde{y}}}
# \newcommand{\vz}{\mathbf{z}}
# %\newcommand{\vzt}{\mathbf{\tilde{z}}}
#
#
# % Upper Roman (Matrices)
# \newcommand{\vA}{\mathbf{A}}
# \newcommand{\vB}{\mathbf{B}}
# \newcommand{\vC}{\mathbf{C}}
# \newcommand{\vD}{\mathbf{D}}
# \newcommand{\vE}{\mathbf{E}}
# \newcommand{\vF}{\mathbf{F}}
# \newcommand{\vG}{\mathbf{G}}
# \newcommand{\vH}{\mathbf{H}}
# \newcommand{\vI}{\mathbf{I}}
# \newcommand{\vJ}{\mathbf{J}}
# \newcommand{\vK}{\mathbf{K}}
# \newcommand{\vL}{\mathbf{L}}
# \newcommand{\vM}{\mathbf{M}}
# \newcommand{\vMt}{\mathbf{\tilde{M}}}
# \newcommand{\vN}{\mathbf{N}}
# \newcommand{\vO}{\mathbf{O}}
# \newcommand{\vP}{\mathbf{P}}
# \newcommand{\vQ}{\mathbf{Q}}
# \newcommand{\vR}{\mathbf{R}}
# \newcommand{\vS}{\mathbf{S}}
# \newcommand{\vT}{\mathbf{T}}
# \newcommand{\vU}{\mathbf{U}}
# \newcommand{\vV}{\mathbf{V}}
# \newcommand{\vW}{\mathbf{W}}
# \newcommand{\vX}{\mathbf{X}}
# %\newcommand{\vXs}{\vX_{\vs}}
# \newcommand{\vXs}{\vX_{s}}
# \newcommand{\vXt}{\mathbf{\tilde{X}}}
# \newcommand{\vY}{\mathbf{Y}}
# \newcommand{\vZ}{\mathbf{Z}}
# \newcommand{\vZt}{\mathbf{\tilde{Z}}}
# \newcommand{\vzt}{\mathbf{\tilde{z}}}
#
#
# %%%%
# \newcommand{\hidden}{\vz}
# \newcommand{\hid}{\hidden}
# \newcommand{\observed}{\vy}
# \newcommand{\obs}{\observed}
# \newcommand{\inputs}{\vu}
# \newcommand{\input}{\inputs}
#
# \newcommand{\hmmTrans}{\vA}
# \newcommand{\hmmObs}{\vB}
# \newcommand{\hmmInit}{\vpi}
# \newcommand{\hmmhid}{\hidden}
# \newcommand{\hmmobs}{\obs}
#
# \newcommand{\ldsDyn}{\vA}
# \newcommand{\ldsObs}{\vC}
# \newcommand{\ldsDynIn}{\vB}
# \newcommand{\ldsObsIn}{\vD}
# \newcommand{\ldsDynNoise}{\vQ}
# \newcommand{\ldsObsNoise}{\vR}
#
# \newcommand{\ssmDynFn}{f}
# \newcommand{\ssmObsFn}{h}
#
#
# %%%
# \newcommand{\gauss}{\mathcal{N}}
#
# \newcommand{\diag}{\mathrm{diag}}
# ```
#
# (sec:lds-intro)=
# # Linear Gaussian SSMs
#
#
# Consider the state space model in
# {eq}`eq:SSM-ar`
# where we assume the observations are conditionally iid given the
# hidden states and inputs (i.e., there are no auto-regressive dependencies
# between the observables).
# We can rewrite this model as
# a stochastic nonlinear dynamical system (NLDS)
# by defining the distribution of the next hidden state
# as a deterministic function of the past state
# plus random process noise $\vepsilon_t$:
# \begin{align}
# \hmmhid_t &= \ssmDynFn(\hmmhid_{t-1}, \inputs_t, \vepsilon_t)
# \end{align}
# where $\vepsilon_t$ is drawn from a distribution chosen such
# that the induced distribution
# on $\hmmhid_t$ matches $p(\hmmhid_t|\hmmhid_{t-1}, \inputs_t)$.
# Similarly, we can rewrite the observation distribution
# as a deterministic function of the hidden state
# plus observation noise $\veta_t$:
# \begin{align}
# \hmmobs_t &= \ssmObsFn(\hmmhid_{t}, \inputs_t, \veta_t)
# \end{align}
#
#
# If we assume additive Gaussian noise,
# the model becomes
# \begin{align}
# \hmmhid_t &= \ssmDynFn(\hmmhid_{t-1}, \inputs_t) + \vepsilon_t \\
# \hmmobs_t &= \ssmObsFn(\hmmhid_{t}, \inputs_t) + \veta_t
# \end{align}
# where $\vepsilon_t \sim \gauss(\vzero,\vQ_t)$
# and $\veta_t \sim \gauss(\vzero,\vR_t)$.
# We will call these Gaussian SSMs.
#
# If we additionally assume that
# the transition function $\ssmDynFn$
# and the observation function $\ssmObsFn$ are both linear,
# then we can rewrite the model as follows:
# \begin{align}
# p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) &= \gauss(\hmmhid_t|\ldsDyn_t \hmmhid_{t-1}
# + \ldsDynIn_t \inputs_t, \vQ_t)
# \\
# p(\hmmobs_t|\hmmhid_t,\inputs_t) &= \gauss(\hmmobs_t|\ldsObs_t \hmmhid_{t}
# + \ldsObsIn_t \inputs_t, \vR_t)
# \end{align}
# This is called a
# linear-Gaussian state space model
# (LG-SSM),
# or a
# linear dynamical system (LDS).
# We usually assume the parameters are independent of time, in which case
# the model is said to be time-invariant or homogeneous.
#
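# To make the generative process concrete, here is a minimal sketch of
# ancestral sampling from a time-invariant LG-SSM with no inputs, written with
# `lax.scan`. This helper is illustrative only (it is not part of the jsl API);
# its argument names mirror the parameters we define in the example below.

def sample_lgssm(key, A, C, Q, R, mu0, Sigma0, num_steps):
    """Sample (z_{1:T}, y_{1:T}) from a time-invariant LG-SSM (illustrative sketch)."""
    state_dim, obs_dim = A.shape[0], C.shape[0]
    keys = jax.random.split(key, num_steps + 1)
    # Sample the initial latent state from the prior.
    z0 = jax.random.multivariate_normal(keys[0], mu0, Sigma0)

    def step(z_prev, key_t):
        dyn_key, obs_key = jax.random.split(key_t)
        # z_t = A z_{t-1} + eps_t, with eps_t ~ N(0, Q)
        z = A @ z_prev + jax.random.multivariate_normal(dyn_key, jnp.zeros(state_dim), Q)
        # y_t = C z_t + eta_t, with eta_t ~ N(0, R)
        y = C @ z + jax.random.multivariate_normal(obs_key, jnp.zeros(obs_dim), R)
        return z, (z, y)

    _, (z_hist, y_hist) = lax.scan(step, z0, keys[1:])
    return z_hist, y_hist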
# (sec:tracking-lds)=
# (sec:kalman-tracking)=
# ## Example: tracking a 2d point
#
#
#
# % Sarkka p43
# Consider an object moving in $\real^2$.
# Let the state be
# the position and velocity of the object,
# $\vz_t =\begin{pmatrix} u_t & v_t & \dot{u}_t & \dot{v}_t \end{pmatrix}$.
# (We use $u$ and $v$ for the two coordinates,
# to avoid confusion with the state and observation variables.)
# If we use Euler discretization with step size $\Delta$,
# the dynamics become
# \begin{align}
# \underbrace{\begin{pmatrix} u_t \\ v_t \\ \dot{u}_t \\ \dot{v}_t \end{pmatrix}}_{\vz_t}
# =
# \underbrace{
# \begin{pmatrix}
# 1 & 0 & \Delta & 0 \\
# 0 & 1 & 0 & \Delta \\
# 0 & 0 & 1 & 0 \\
# 0 & 0 & 0 & 1
# \end{pmatrix}
# }_{\ldsDyn}
# \
# \underbrace{\begin{pmatrix} u_{t-1} \\ v_{t-1} \\ \dot{u}_{t-1} \\ \dot{v}_{t-1} \end{pmatrix}}_{\vz_{t-1}}
# + \vepsilon_t
# \end{align}
# where $\vepsilon_t \sim \gauss(\vzero,\vQ)$ is
# the process noise.
#
# Let us assume
# that the process noise is
# a white noise process added to the velocity components
# of the state, but not to the locations.
# (This is known as a random accelerations model.)
# We can approximate the resulting process in discrete time by assuming
# $\vQ = \diag(0, 0, q, q)$.
# (See {cite}`Sarkka13` p60 for a more accurate way
# to convert the continuous time process to discrete time;
# a sketch is given below.)
#
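# For reference, the exact discretization (the Wiener velocity model, derived
# in {cite}`Sarkka13`) couples each position to its velocity. Here is a sketch
# of the resulting covariance for our state ordering $(u, v, \dot{u}, \dot{v})$;
# `q` and `dt` are illustrative values for the noise spectral density and the
# step size.

q, dt = 0.1, 1.0  # illustrative values; dt plays the role of delta below
Q_exact = jnp.zeros((4, 4))
for pos, vel in [(0, 2), (1, 3)]:  # (position index, velocity index) per axis
    Q_exact = Q_exact.at[pos, pos].set(q * dt**3 / 3)
    Q_exact = Q_exact.at[pos, vel].set(q * dt**2 / 2)
    Q_exact = Q_exact.at[vel, pos].set(q * dt**2 / 2)
    Q_exact = Q_exact.at[vel, vel].set(q * dt)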
#
# Now suppose that at each discrete time point we
# observe the location,
# corrupted by Gaussian noise.
# Thus the observation model becomes
# \begin{align}
# \underbrace{\begin{pmatrix} y_{1,t} \\ y_{2,t} \end{pmatrix}}_{\vy_t}
# &=
# \underbrace{
# \begin{pmatrix}
# 1 & 0 & 0 & 0 \\
# 0 & 1 & 0 & 0
# \end{pmatrix}
# }_{\ldsObs}
# \
# \underbrace{\begin{pmatrix} u_t \\ v_t \\ \dot{u}_t \\ \dot{v}_t \end{pmatrix}}_{\vz_t}
# + \veta_t
# \end{align}
# where $\veta_t \sim \gauss(\vzero,\vR)$ is the **observation noise**.
# We see that the observation matrix $\ldsObs$ simply "extracts" the
# relevant parts of the state vector.
#
# Suppose we sample a trajectory and corresponding set
# of noisy observations from this model,
# $(\vz_{1:T}, \vy_{1:T}) \sim p(\vz,\vy|\vtheta)$.
# (We use diagonal observation noise,
# $\vR = \diag(\sigma_1^2, \sigma_2^2)$.)
# The results are shown below.
#
# In[3]:

key = jax.random.PRNGKey(314)
timesteps = 15
delta = 1.0

# State transition matrix for the constant-velocity model:
# positions are advanced by delta times the velocities.
A = jnp.array([
    [1, 0, delta, 0],
    [0, 1, 0, delta],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
])

# Observation matrix: extract the location (u, v) from the state.
C = jnp.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0]
])

state_size, _ = A.shape
observation_size, _ = C.shape

# Process and observation noise covariances.
Q = jnp.eye(state_size) * 0.001
R = jnp.eye(observation_size) * 1.0

# Prior distribution over the initial state.
mu0 = jnp.array([8, 10, 1, 0]).astype(float)
Sigma0 = jnp.eye(state_size) * 1.0

from jsl.lds.kalman_filter import LDS, smooth, filter

lds = LDS(A, C, Q, R, mu0, Sigma0)
print(lds)
# In[4]:

from jsl.demos.plot_utils import plot_ellipse

def plot_tracking_values(observed, filtered, cov_hist, signal_label, ax):
    """Plot observations, an estimated trajectory, and 2-sigma covariance ellipses."""
    timesteps, _ = observed.shape
    ax.plot(observed[:, 0], observed[:, 1], marker="o", linewidth=0,
            markerfacecolor="none", markeredgewidth=2, markersize=8,
            label="observed", c="tab:green")
    ax.plot(*filtered[:, :2].T, label=signal_label, c="tab:red", marker="x", linewidth=2)
    for t in range(timesteps):
        covn = cov_hist[t][:2, :2]
        plot_ellipse(covn, filtered[t, :2], ax, n_std=2.0, plot_center=False)
    ax.axis("equal")
    ax.legend()
# In[5]:

z_hist, x_hist = lds.sample(key, timesteps)

fig_truth, axs = plt.subplots()
axs.plot(x_hist[:, 0], x_hist[:, 1],
         marker="o", linewidth=0, markerfacecolor="none",
         markeredgewidth=2, markersize=8,
         label="observed", c="tab:green")
axs.plot(z_hist[:, 0], z_hist[:, 1],
         linewidth=2, label="truth",
         marker="s", markersize=8)
axs.legend()
axs.axis("equal")
# The main task is to infer the hidden states given the noisy
# observations, i.e., to compute $p(\vz|\vy,\vtheta)$. We discuss the topic of inference in {ref}`sec:inference`.
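# As a preview, we can run the Kalman filter on the sampled observations and
# plot the result with the helper defined above. This is only a sketch: we
# assume jsl's `filter(lds, x_hist)` returns the filtered means and covariances
# as its first two outputs, following the convention in the jsl tracking demos.

mu_hist, Sigma_hist, *_ = filter(lds, x_hist)  # assumed jsl output convention
fig_filtered, axs_filtered = plt.subplots()
plot_tracking_values(x_hist, mu_hist, Sigma_hist, "filtered", axs_filtered)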