#!/usr/bin/env python
# coding: utf-8

# In[1]:

### Import standard libraries

import abc
from dataclasses import dataclass
import functools
from functools import partial
import itertools

import matplotlib.pyplot as plt
import numpy as np
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
#from jax.scipy.special import logit
#from jax.nn import softmax
import jax.random as jr

import distrax
import optax

import jsl
import ssm_jax

import inspect
import inspect as py_inspect
import rich
from rich import inspect as r_inspect
from rich import print as r_print

def print_source(fname):
    r_print(py_inspect.getsource(fname))
# (sec:lds-intro)=
# # Linear Gaussian SSMs
#
#
# Consider the state space model in
# {eq}`eq:SSM-ar`
# where we assume the observations are conditionally iid given the
# hidden states and inputs (i.e. there are no auto-regressive dependencies
# between the observables).
# We can rewrite this model as
# a stochastic $\keyword{nonlinear dynamical system}$ or $\keyword{NLDS}$
# by defining the next hidden state
# $\hidden_t \in \real^{\nhidden}$
# as a deterministic function of the past state
# $\hidden_{t-1}$,
# the input $\inputs_t \in \real^{\ninputs}$,
# and some random $\keyword{process noise}$ $\transNoise_t \in \real^{\nhidden}$:
# \begin{align}
# \hidden_t &= \dynamicsFn(\hidden_{t-1}, \inputs_t, \transNoise_t)
# \end{align}
# where $\transNoise_t$ is drawn from a distribution such
# that the induced distribution
# on $\hidden_t$ matches $p(\hidden_t|\hidden_{t-1}, \inputs_t)$.
# Similarly, we can rewrite each observation as
# a deterministic function of the hidden state and input,
# together with some $\keyword{observation noise}$ $\obsNoise_t \in \real^{\nobs}$:
# \begin{align}
# \obs_t &= \measurementFn(\hidden_{t}, \inputs_t, \obsNoise_t)
# \end{align}
#
#
# If we assume additive Gaussian noise,
# the model becomes
# \begin{align}
# \hidden_t &= \dynamicsFn(\hidden_{t-1}, \inputs_t) + \transNoise_t \\
# \obs_t &= \measurementFn(\hidden_{t}, \inputs_t) + \obsNoise_t
# \end{align}
# where $\transNoise_t \sim \gauss(\vzero,\transCov_t)$
# and $\obsNoise_t \sim \gauss(\vzero,\obsCov_t)$.
# We will call these $\keyword{Gaussian SSMs}$.
#
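# The following is a minimal sketch (not part of the original notebook) showing how one
# step of such a Gaussian SSM can be sampled; `dynamics_fn`, `measurement_fn`, `Q`, and `R`
# are placeholder names for $\dynamicsFn$, $\measurementFn$, $\transCov$, and $\obsCov$.

# In[ ]:

def gaussian_ssm_step(key, z_prev, u, dynamics_fn, measurement_fn, Q, R):
    """Sample z_t = f(z_{t-1}, u_t) + q_t and y_t = h(z_t, u_t) + r_t with Gaussian noise."""
    key_z, key_y = jr.split(key)
    z = dynamics_fn(z_prev, u) + jr.multivariate_normal(key_z, jnp.zeros(Q.shape[0]), Q)
    y = measurement_fn(z, u) + jr.multivariate_normal(key_y, jnp.zeros(R.shape[0]), R)
    return z, y
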
# If we additionally assume
# the transition function $\dynamicsFn$
# and the observation function $\measurementFn$ are both linear,
# then we can rewrite the model as follows:
# \begin{align}
# p(\hidden_t|\hidden_{t-1},\inputs_t) &= \gauss(\hidden_t|\ldsDyn \hidden_{t-1}
# + \ldsDynIn \inputs_t, \transCov)
# \\
# p(\obs_t|\hidden_t,\inputs_t) &= \gauss(\obs_t|\ldsObs \hidden_{t}
# + \ldsObsIn \inputs_t, \obsCov)
# \end{align}
# This is called a
# $\keyword{linear-Gaussian state space model}$
# or $\keyword{LG-SSM}$;
# it is also called
# a $\keyword{linear dynamical system}$ or $\keyword{LDS}$.
# We usually assume the parameters are independent of time, in which case
# the model is said to be time-invariant or homogeneous.
#
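# As an illustration, here is a minimal sketch (not the jsl implementation used later in
# this notebook) of sampling a full trajectory from a time-invariant LG-SSM with no inputs;
# `F`, `H`, `Q`, `R`, `mu0`, `Sigma0` are placeholder names for $\ldsDyn$, $\ldsObs$,
# $\transCov$, $\obsCov$, and the parameters of the initial state prior.

# In[ ]:

def lgssm_sample(key, F, H, Q, R, mu0, Sigma0, num_timesteps):
    """Sample (z_{1:T}, y_{1:T}) from an LG-SSM, starting from z_0 ~ N(mu0, Sigma0)."""
    key_init, key_scan = jr.split(key)
    z0 = jr.multivariate_normal(key_init, mu0, Sigma0)

    def step(z_prev, key_t):
        key_z, key_y = jr.split(key_t)
        z = jr.multivariate_normal(key_z, F @ z_prev, Q)   # z_t ~ N(F z_{t-1}, Q)
        y = jr.multivariate_normal(key_y, H @ z, R)        # y_t ~ N(H z_t, R)
        return z, (z, y)

    _, (z_hist, y_hist) = lax.scan(step, z0, jr.split(key_scan, num_timesteps))
    return z_hist, y_hist
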
# (sec:tracking-lds)=
# (sec:kalman-tracking)=
# ## Example: tracking a 2d point
#
#
#
# % Sarkka p43
# Consider an object moving in $\real^2$.
# Let the state be
# the position and velocity of the object,
# $\hidden_t =\begin{pmatrix} u_t & \dot{u}_t & v_t & \dot{v}_t \end{pmatrix}$.
# (We use $u$ and $v$ for the two coordinates,
# to avoid confusion with the state and observation variables.)
# If we use Euler discretization,
# the dynamics become
# \begin{align}
# \underbrace{\begin{pmatrix} u_t\\ \dot{u}_t \\ v_t \\ \dot{v}_t \end{pmatrix}}_{\hidden_t}
# =
# \underbrace{
# \begin{pmatrix}
# 1 & \Delta & 0 & 0 \\
# 0 & 1 & 0 & 0 \\
# 0 & 0 & 1 & \Delta \\
# 0 & 0 & 0 & 1
# \end{pmatrix}
# }_{\ldsDyn}
#
# \underbrace{\begin{pmatrix} u_{t-1} \\ \dot{u}_{t-1} \\ v_{t-1} \\ \dot{v}_{t-1} \end{pmatrix}}_{\hidden_{t-1}}
# + \transNoise_t
# \end{align}
# where $\transNoise_t \sim \gauss(\vzero,\transCov)$ is
# the process noise.
# We assume
# that the process noise is
# a white noise process added to the velocity components
# of the state, but not to the location,
# so $\transCov = \diag(0, q, 0, q)$.
# This is known as a random accelerations model.
# (See {cite}`Sarkka13` p60 for a more accurate way
# to convert the continuous-time process to discrete time.)
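#
# As a concrete sketch (not from the original notebook), the dynamics above can be built
# for an arbitrary step size `delta` and noise variance `q`, using the text's state
# ordering $(u, \dot{u}, v, \dot{v})$; note that the code cell later in this notebook uses
# a different state ordering and an isotropic process noise covariance.

# In[ ]:

def make_tracking_dynamics(delta, q):
    """Return (F, Q) for the random-accelerations model with state (u, u_dot, v, v_dot)."""
    F = jnp.array([
        [1., delta, 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., delta],
        [0., 0., 0., 1.]
    ])
    Q = jnp.diag(jnp.array([0., q, 0., q]))
    return F, Q
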
#
#
# Now suppose that at each discrete time point we
# observe the location,
# corrupted by Gaussian noise.
# Thus the observation model becomes
# \begin{align}
# \underbrace{\begin{pmatrix} \obs_{1,t} \\ \obs_{2,t} \end{pmatrix}}_{\obs_t}
# &=
# \underbrace{
# \begin{pmatrix}
# 1 & 0 & 0 & 0 \\
# 0 & 0 & 1 & 0
# \end{pmatrix}
# }_{\ldsObs}
#
# \underbrace{\begin{pmatrix} u_t\\ \dot{u}_t \\ v_t \\ \dot{v}_t \end{pmatrix}}_{\hidden_t}
# + \obsNoise_t
# \end{align}
# where $\obsNoise_t \sim \gauss(\vzero,\obsCov)$ is the observation noise.
# We see that the observation matrix $\ldsObs$ simply "extracts" the
# relevant parts of the state vector.
#
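# As a small sketch (not from the original notebook), this observation matrix can be
# written down directly, again using the text's state ordering $(u, \dot{u}, v, \dot{v})$:

# In[ ]:

H_text_order = jnp.array([
    [1., 0., 0., 0.],  # picks out the position u_t
    [0., 0., 1., 0.]   # picks out the position v_t
])
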
# Suppose we sample a trajectory and corresponding set
# of noisy observations from this model,
# $(\hidden_{1:T}, \obs_{1:T}) \sim p(\hidden,\obs|\params)$.
# (We use diagonal observation noise,
# $\obsCov = \diag(\sigma_1^2, \sigma_2^2)$.)
# The results are shown below.

# In[2]:
key = jax.random.PRNGKey(314)
timesteps = 15
delta = 1.0

# State transition matrix; note the state ordering here is (u, v, u_dot, v_dot),
# whereas the text above uses (u, u_dot, v, v_dot).
A = jnp.array([
    [1, 0, delta, 0],
    [0, 1, 0, delta],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
])

# Observation matrix: extracts the position components (u, v).
C = jnp.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0]
])

state_size, _ = A.shape
observation_size, _ = C.shape

# Process and observation noise covariances.
Q = jnp.eye(state_size) * 0.001
R = jnp.eye(observation_size) * 1.0

# Prior parameter distribution
mu0 = jnp.array([8, 10, 1, 0]).astype(float)
Sigma0 = jnp.eye(state_size) * 1.0

from jsl.lds.kalman_filter import LDS, smooth, filter

lds = LDS(A, C, Q, R, mu0, Sigma0)
print(lds)
# In[3]:

from jsl.demos.plot_utils import plot_ellipse

def plot_tracking_values(observed, filtered, cov_hist, signal_label, ax):
    timesteps, _ = observed.shape
    ax.plot(observed[:, 0], observed[:, 1], marker="o", linewidth=0,
            markerfacecolor="none", markeredgewidth=2, markersize=8, label="observed", c="tab:green")
    ax.plot(*filtered[:, :2].T, label=signal_label, c="tab:red", marker="x", linewidth=2)
    # Draw a 2-sigma covariance ellipse around each estimated position.
    for t in range(0, timesteps, 1):
        covn = cov_hist[t][:2, :2]
        plot_ellipse(covn, filtered[t, :2], ax, n_std=2.0, plot_center=False)
    ax.axis("equal")
    ax.legend()
# In[4]:

z_hist, x_hist = lds.sample(key, timesteps)

fig_truth, axs = plt.subplots()
axs.plot(x_hist[:, 0], x_hist[:, 1],
         marker="o", linewidth=0, markerfacecolor="none",
         markeredgewidth=2, markersize=8,
         label="observed", c="tab:green")
axs.plot(z_hist[:, 0], z_hist[:, 1],
         linewidth=2, label="truth",
         marker="s", markersize=8)
axs.legend()
axs.axis("equal")
# The main task is to infer the hidden states given the noisy
# observations, i.e., $p(\hidden_t|\obs_{1:t},\params)$ in the online case,
# or $p(\hidden_t|\obs_{1:T}, \params)$ in the offline case.
# We discuss the topic of inference in {ref}`sec:inference`.
# We will usually represent this belief state by a Gaussian distribution,
# $p(\hidden_t|\obs_{1:s},\params) = \gauss(\hidden_t| \mean_{t|s}, \covMat_{t|s})$,
# where usually $s=t$ or $s=T$.
# Sometimes we use information form,
# $p(\hidden_t|\obs_{1:s},\params) = \gaussInfo(\hidden_t|\precMean_{t|s}, \precMat_{t|s})$.
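#
# To make the moment-form belief state update concrete, here is a minimal sketch (not the
# jsl implementation) of a single Kalman filter predict-update step for a time-invariant
# LG-SSM with no inputs; `F`, `H`, `Q`, `R` are placeholder names for $\ldsDyn$, $\ldsObs$,
# $\transCov$, and $\obsCov$.

# In[ ]:

def kalman_step(mu_prev, Sigma_prev, y, F, H, Q, R):
    """One predict-update step, returning the filtered mean and covariance for time t."""
    # Predict: p(z_t | y_{1:t-1}) = N(z_t | mu_pred, Sigma_pred)
    mu_pred = F @ mu_prev
    Sigma_pred = F @ Sigma_prev @ F.T + Q
    # Update: condition on y_t to get p(z_t | y_{1:t})
    S = H @ Sigma_pred @ H.T + R                   # innovation covariance
    K = jnp.linalg.solve(S, H @ Sigma_pred).T      # Kalman gain, K = Sigma_pred H^T S^{-1}
    mu_filt = mu_pred + K @ (y - H @ mu_pred)
    Sigma_filt = (jnp.eye(mu_prev.shape[0]) - K @ H) @ Sigma_pred
    return mu_filt, Sigma_filt

# Folding this step over the sampled observations `x_hist`, starting from `(mu0, Sigma0)`,
# would produce the filtered beliefs; the `filter` function imported from
# `jsl.lds.kalman_filter` above is intended to provide this recursion for the `LDS` object.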