nlds.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # In[1]:
  4. # meta-data does not work yet in VScode
  5. # https://github.com/microsoft/vscode-jupyter/issues/1121
  6. {
  7. "tags": [
  8. "hide-cell"
  9. ]
  10. }
  11. ### Install necessary libraries
  12. try:
  13. import jax
  14. except:
  15. # For cuda version, see https://github.com/google/jax#installation
  16. get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
  17. import jax
  18. try:
  19. import distrax
  20. except:
  21. get_ipython().run_line_magic('pip', 'install --upgrade distrax')
  22. import distrax
  23. try:
  24. import jsl
  25. except:
  26. get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
  27. import jsl
  28. try:
  29. import rich
  30. except:
  31. get_ipython().run_line_magic('pip', 'install rich')
  32. import rich
  33. # In[2]:
  34. {
  35. "tags": [
  36. "hide-cell"
  37. ]
  38. }
  39. ### Import standard libraries
  40. import abc
  41. from dataclasses import dataclass
  42. import functools
  43. import itertools
  44. from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
  45. import matplotlib.pyplot as plt
  46. import numpy as np
  47. import jax
  48. import jax.numpy as jnp
  49. from jax import lax, vmap, jit, grad
  50. from jax.scipy.special import logit
  51. from jax.nn import softmax
  52. from functools import partial
  53. from jax.random import PRNGKey, split
  54. import inspect
  55. import inspect as py_inspect
  56. import rich
  57. from rich import inspect as r_inspect
  58. from rich import print as r_print
  59. def print_source(fname):
  60. r_print(py_inspect.getsource(fname))
  61. # (sec:nlds-intro)=
  62. # # Nonlinear Gaussian SSMs
  63. #
  64. # In this section, we consider SSMs in which the dynamics and/or observation models are nonlinear,
  65. # but the process noise and observation noise are Gaussian.
  66. # That is,
  67. # \begin{align}
  68. # \hidden_t &= \dynamicsFn(\hidden_{t-1}, \inputs_t) + \transNoise_t \\
  69. # \obs_t &= \obsFn(\hidden_{t}, \inputs_t) + \obsNoise_t
  70. # \end{align}
  71. # where $\transNoise_t \sim \gauss(\vzero,\transCov)$
  72. # and $\obsNoise_t \sim \gauss(\vzero,\obsCov)$.
  73. # This is a very widely used model class. We give some examples below.
  74. # (sec:pendulum)=
  75. # ## Example: tracking a 1d pendulum
  76. #
  77. # ```{figure} /figures/pendulum.png
  78. # :scale: 50%
  79. # :name: fig:pendulum
  80. #
  81. # Illustration of a pendulum swinging.
  82. # $g$ is the gravitational acceleration,
  83. # $w(t)$ is a random external force,
  84. # and $\alpha$ is the angle with respect to the vertical.
  85. # Based on {cite}`Sarkka13` fig 3.10.
  86. # ```
  87. #
  88. #
  89. # % Sarka p45, p74
  90. # Consider a simple pendulum of unit mass and length swinging from
  91. # a fixed attachment, as in
  92. # {numref}`fig:pendulum`.
  93. # Such an object is in principle entirely deterministic in its behavior.
  94. # However, in the real world, there are often unknown forces at work
  95. # (e.g., air turbulence, friction).
  96. # We will model these by a continuous-time random Gaussian noise process $w(t)$.
  97. # This gives rise to the following differential equation:
  98. # \begin{align}
  99. # \frac{d^2 \alpha}{d t^2}
  100. # = -g \sin(\alpha) + w(t)
  101. # \end{align}
  102. # We can write this as a nonlinear SSM by defining the state to be
  103. # $\hiddenScalar_1(t) = \alpha(t)$ and $\hiddenScalar_2(t) = d\alpha(t)/dt$.
  104. # Thus
  105. # \begin{align}
  106. # \frac{d \hidden}{dt}
  107. # = \begin{pmatrix} \hiddenScalar_2 \\ -g \sin(\hiddenScalar_1) \end{pmatrix}
  108. # + \begin{pmatrix} 0 \\ 1 \end{pmatrix} w(t)
  109. # \end{align}
  110. # If we discretize this with step size $\Delta$,
  111. # we get the following
  112. # formulation {cite}`Sarkka13` p74:
  113. # \begin{align}
  114. # \underbrace{
  115. # \begin{pmatrix} \hiddenScalar_{1,t} \\ \hiddenScalar_{2,t} \end{pmatrix}
  116. # }_{\hidden_t}
  117. # =
  118. # \underbrace{
  119. # \begin{pmatrix} \hiddenScalar_{1,t-1} + \hiddenScalar_{2,t-1} \Delta \\
  120. # \hiddenScalar_{2,t-1} -g \sin(\hiddenScalar_{1,t-1}) \Delta \end{pmatrix}
  121. # }_{\dynamicsFn(\hidden_{t-1})}
  122. # +\transNoise_{t-1}
  123. # \end{align}
  124. # where $\transNoise_{t-1} \sim \gauss(\vzero,\transCov)$ with
  125. # \begin{align}
  126. # \transCov = q^c \begin{pmatrix}
  127. # \frac{\Delta^3}{3} & \frac{\Delta^2}{2} \\
  128. # \frac{\Delta^2}{2} & \Delta
  129. # \end{pmatrix}
  130. # \end{align}
  131. # where $q^c$ is the spectral density (continuous-time variance)
  132. # of the continuous-time noise process $w(t)$.
  133. #
  134. #
  135. # If we observe the angular position, we
  136. # get the linear observation model
  137. # $\obsFn(\hidden_t) = \alpha_t = \hiddenScalar_{1,t}$.
  138. # If we only observe the horizontal position (which, for a pendulum of unit length, is $\sin(\alpha_t)$),
  139. # we get the nonlinear observation model
  140. # $\obsFn(\hidden_t) = \sin(\alpha_t) = \sin(\hiddenScalar_{1,t})$.
  141. #
  142. #
  143. #
  144. #
  145. #
  146. #