#!/usr/bin/env python
# coding: utf-8

# In[1]:


{
    "tags": [
        "hide-cell"
    ]
}

### Import standard libraries

import abc
from dataclasses import dataclass
import functools
import itertools
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.scipy.special import logit
from jax.nn import softmax
from functools import partial
from jax.random import PRNGKey, split

import jsl
import ssm_jax
# (sec:inference)=
# # State estimation (inference)
#
# Given the sequence of observations and a known model,
# one of the main tasks with SSMs
# is to perform posterior inference
# about the hidden states; this is also called
# state estimation.
# At each time step $t$,
# there are multiple forms of posterior we may be interested in computing,
# including the following:
# - the filtering distribution
# $p(\hidden_t|\obs_{1:t})$
# - the smoothing distribution
# $p(\hidden_t|\obs_{1:T})$ (note that this conditions on future data, since $T>t$)
# - the fixed-lag smoothing distribution
# $p(\hidden_{t-\ell}|\obs_{1:t})$ (note that this
# infers $\ell$ steps in the past given data up to the present).
#
# We may also want to compute the
# predictive distribution $h$ steps into the future:
# \begin{align}
# p(\obs_{t+h}|\obs_{1:t})
# = \sum_{\hidden_{t+h}} p(\obs_{t+h}|\hidden_{t+h}) p(\hidden_{t+h}|\obs_{1:t})
# \end{align}
# where the hidden state predictive distribution is
# \begin{align}
# p(\hidden_{t+h}|\obs_{1:t})
# &= \sum_{\hidden_{t:t+h-1}}
# p(\hidden_t|\obs_{1:t})
# p(\hidden_{t+1}|\hidden_{t})
# p(\hidden_{t+2}|\hidden_{t+1})
# \cdots
# p(\hidden_{t+h}|\hidden_{t+h-1})
# \end{align}
# See
# {numref}`fig:dbn-inference` for a summary of these distributions;
# a small numerical sketch of the predictive computation is given after the figure.
#
# ```{figure} /figures/inference-problems-tikz.png
# :scale: 30%
# :name: fig:dbn-inference
#
# The main kinds of inference for state-space models.
# The shaded region is the interval for which we have data.
# The arrow represents the time step at which we want to perform inference.
# $t$ is the current time, $T$ is the sequence length,
# $\ell$ is the lag and $h$ is the prediction horizon.
# ```
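#
# The predictive computation above can be sketched numerically as follows.
# The transition matrix, observation matrix and filtered distribution in this
# cell are made-up placeholders for illustration (they are not taken from any
# model defined in this chapter): we push the filtered state distribution
# through the transition matrix $h$ times, then map the result through the
# observation model.

# In[ ]:


import numpy as np

# hypothetical model: 2 hidden states, 3 observation symbols
trans_mat = np.array([[0.9, 0.1],
                      [0.2, 0.8]])          # p(z_{t+1} = j | z_t = i)
obs_mat = np.array([[0.5, 0.3, 0.2],
                    [0.1, 0.2, 0.7]])       # p(x_t = k | z_t = j)
filtered_t = np.array([0.7, 0.3])           # p(z_t | x_{1:t}), assumed given

h = 3
state_pred = filtered_t
for _ in range(h):
    state_pred = state_pred @ trans_mat     # p(z_{t+h} | x_{1:t})
obs_pred = state_pred @ obs_mat             # p(x_{t+h} | x_{1:t})
print(state_pred, obs_pred)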
#
# In addition to computing posterior marginals,
# we may want to compute the most probable hidden sequence,
# i.e., the joint MAP estimate
# ```{math}
# \arg \max_{\hidden_{1:T}} p(\hidden_{1:T}|\obs_{1:T})
# ```
# or sample sequences from the posterior
# ```{math}
# \hidden_{1:T} \sim p(\hidden_{1:T}|\obs_{1:T})
# ```
#
# Algorithms for all these tasks are discussed in the following chapters,
# since the details depend on the form of the SSM.
#
# (sec:casino-inference)=
# ## Example: inference in the casino HMM
#
# We now illustrate filtering, smoothing and MAP decoding applied
# to the casino HMM from {ref}`sec:casino`.
#
# In[2]:


# state transition matrix
A = np.array([
    [0.95, 0.05],
    [0.10, 0.90]
])
# observation matrix
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6],        # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]   # loaded die
])
# initial state distribution
pi = np.array([0.5, 0.5])

(nstates, nobs) = np.shape(B)

import distrax
from distrax import HMM

hmm = HMM(trans_dist=distrax.Categorical(probs=A),
          init_dist=distrax.Categorical(probs=pi),
          obs_dist=distrax.Categorical(probs=B))

seed = 314
n_samples = 300
z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)
# In[3]:


# Call inference engine
# filtered_dist[t] = p(z_t | x_{1:t}), smoothed_dist[t] = p(z_t | x_{1:T})
filtered_dist, _, smoothed_dist, loglik = hmm.forward_backward(x_hist)
# most probable hidden path (Viterbi decoding)
map_path = hmm.viterbi(x_hist)
# In[4]:


# Find the spans of time steps during which the simulated system is in state 1
# (i.e., the dealer is using the loaded die)
def find_dishonest_intervals(z_hist):
    spans = []
    x_init = 0
    for t, _ in enumerate(z_hist[:-1]):
        if z_hist[t + 1] == 0 and z_hist[t] == 1:
            x_end = t
            spans.append((x_init, x_end))
        elif z_hist[t + 1] == 1 and z_hist[t] == 0:
            x_init = t + 1
    return spans
# In[5]:


# Plot posterior
def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False):
    n_samples = len(inference_values)
    xspan = np.arange(1, n_samples + 1)
    spans = find_dishonest_intervals(z_hist)
    if map_estimate:
        ax.step(xspan, inference_values, where="post")
    else:
        ax.plot(xspan, inference_values[:, state])

    for span in spans:
        ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none")
    ax.set_xlim(1, n_samples)
    # ax.set_ylim(0, 1)
    ax.set_ylim(-0.1, 1.1)
    ax.set_xlabel("Observation number")
# In[6]:


# Filtering
fig, ax = plt.subplots()
plot_inference(filtered_dist, z_hist, ax)
ax.set_ylabel("p(loaded)")
ax.set_title("Filtered")


# In[7]:


# Smoothing
fig, ax = plt.subplots()
plot_inference(smoothed_dist, z_hist, ax)
ax.set_ylabel("p(loaded)")
ax.set_title("Smoothed")


# In[8]:


# MAP estimation
fig, ax = plt.subplots()
plot_inference(map_path, z_hist, ax, map_estimate=True)
ax.set_ylabel("MAP state")
ax.set_title("Viterbi")


# In[9]:


# TODO: posterior samples
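# One way to draw posterior samples is forward-filtering backward-sampling
# (FFBS). The sketch below is our own illustration (it is not a distrax API):
# it reuses the filtered distributions and the transition matrix A defined
# above to draw a single trajectory from p(z_{1:T} | x_{1:T}).
def sample_posterior_path(key, filtered, trans_mat):
    num_steps, num_states = filtered.shape
    keys = jax.random.split(key, num_steps)
    # sample the final state from the filtering distribution at time T
    probs = filtered[-1] / filtered[-1].sum()
    states = [jax.random.choice(keys[-1], num_states, p=probs)]
    # walk backwards: p(z_t | z_{t+1}, x_{1:t}) is proportional to
    # filtered[t][z_t] * trans_mat[z_t, z_{t+1}]
    for t in range(num_steps - 2, -1, -1):
        probs = filtered[t] * trans_mat[:, states[-1]]
        probs = probs / probs.sum()
        states.append(jax.random.choice(keys[t], num_states, p=probs))
    return jnp.array(states[::-1])

posterior_sample = sample_posterior_path(PRNGKey(0), jnp.array(filtered_dist), jnp.array(A))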
# ## Example: inference in the tracking LG-SSM
#
# We now illustrate filtering and smoothing applied
# to the 2d tracking LG-SSM from {ref}`sec:tracking-lds`.
# In[10]:


key = jax.random.PRNGKey(314)
timesteps = 15
delta = 1.0

# state transition matrix: constant-velocity dynamics on
# (x position, y position, x velocity, y velocity)
A = jnp.array([
    [1, 0, delta, 0],
    [0, 1, 0, delta],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
])

# observation matrix: we observe the 2d position only
C = jnp.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0]
])

state_size, _ = A.shape
observation_size, _ = C.shape

Q = jnp.eye(state_size) * 0.001      # dynamics (process) noise covariance
R = jnp.eye(observation_size) * 1.0  # observation noise covariance

# prior on the initial state
mu0 = jnp.array([8, 10, 1, 0]).astype(float)
Sigma0 = jnp.eye(state_size) * 1.0

from jsl.lds.kalman_filter import LDS, smooth, filter

lds = LDS(A, C, Q, R, mu0, Sigma0)
z_hist, x_hist = lds.sample(key, timesteps)
# In[11]:


from jsl.demos.plot_utils import plot_ellipse


def plot_tracking_values(observed, filtered, cov_hist, signal_label, ax):
    timesteps, _ = observed.shape
    ax.plot(observed[:, 0], observed[:, 1], marker="o", linewidth=0,
            markerfacecolor="none", markeredgewidth=2, markersize=8,
            label="observed", c="tab:green")
    ax.plot(*filtered[:, :2].T, label=signal_label, c="tab:red", marker="x", linewidth=2)
    for t in range(0, timesteps, 1):
        covn = cov_hist[t][:2, :2]
        plot_ellipse(covn, filtered[t, :2], ax, n_std=2.0, plot_center=False)
    ax.axis("equal")
    ax.legend()
# In[12]:


# Filtering
mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist = filter(lds, x_hist)

l2_filter = jnp.linalg.norm(z_hist[:, :2] - mu_hist[:, :2], 2)
print(f"L2-filter: {l2_filter:0.4f}")

fig_filtered, axs = plt.subplots()
plot_tracking_values(x_hist, mu_hist, Sigma_hist, "filtered", axs)
# In[13]:


# Smoothing
mu_hist_smooth, Sigma_hist_smooth = smooth(lds, mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist)

l2_smooth = jnp.linalg.norm(z_hist[:, :2] - mu_hist_smooth[:, :2], 2)
print(f"L2-smooth: {l2_smooth:0.4f}")

fig_smoothed, axs = plt.subplots()
plot_tracking_values(x_hist, mu_hist_smooth, Sigma_hist_smooth, "smoothed", axs)