#!/usr/bin/env python
# coding: utf-8

# In[1]:

# meta-data does not work yet in VScode
# https://github.com/microsoft/vscode-jupyter/issues/1121

{
    "tags": [
        "hide-cell"
    ]
}

### Install necessary libraries

try:
    import jax
except:
    # For cuda version, see https://github.com/google/jax#installation
    get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
    import jax

try:
    import distrax
except:
    get_ipython().run_line_magic('pip', 'install --upgrade distrax')
    import distrax

try:
    import jsl
except:
    get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
    import jsl

try:
    import rich
except:
    get_ipython().run_line_magic('pip', 'install rich')
    import rich
# In[2]:

{
    "tags": [
        "hide-cell"
    ]
}

### Import standard libraries

import abc
from dataclasses import dataclass
import functools
from functools import partial
import inspect as py_inspect
import itertools
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.nn import softmax
from jax.random import PRNGKey, split
from jax.scipy.special import logit

import rich
from rich import inspect as r_inspect
from rich import print as r_print


def print_source(fname):
    r_print(py_inspect.getsource(fname))
# ```{math}
#
# \newcommand\floor[1]{\lfloor#1\rfloor}
#
# \newcommand{\real}{\mathbb{R}}
#
# % Numbers
# \newcommand{\vzero}{\boldsymbol{0}}
# \newcommand{\vone}{\boldsymbol{1}}
#
# % Greek https://www.latex-tutorial.com/symbols/greek-alphabet/
# \newcommand{\valpha}{\boldsymbol{\alpha}}
# \newcommand{\vbeta}{\boldsymbol{\beta}}
# \newcommand{\vchi}{\boldsymbol{\chi}}
# \newcommand{\vdelta}{\boldsymbol{\delta}}
# \newcommand{\vDelta}{\boldsymbol{\Delta}}
# \newcommand{\vepsilon}{\boldsymbol{\epsilon}}
# \newcommand{\vzeta}{\boldsymbol{\zeta}}
# \newcommand{\vXi}{\boldsymbol{\Xi}}
# \newcommand{\vell}{\boldsymbol{\ell}}
# \newcommand{\veta}{\boldsymbol{\eta}}
# %\newcommand{\vEta}{\boldsymbol{\Eta}}
# \newcommand{\vgamma}{\boldsymbol{\gamma}}
# \newcommand{\vGamma}{\boldsymbol{\Gamma}}
# \newcommand{\vmu}{\boldsymbol{\mu}}
# \newcommand{\vmut}{\boldsymbol{\tilde{\mu}}}
# \newcommand{\vnu}{\boldsymbol{\nu}}
# \newcommand{\vkappa}{\boldsymbol{\kappa}}
# \newcommand{\vlambda}{\boldsymbol{\lambda}}
# \newcommand{\vLambda}{\boldsymbol{\Lambda}}
# \newcommand{\vLambdaBar}{\overline{\vLambda}}
# %\newcommand{\vnu}{\boldsymbol{\nu}}
# \newcommand{\vomega}{\boldsymbol{\omega}}
# \newcommand{\vOmega}{\boldsymbol{\Omega}}
# \newcommand{\vphi}{\boldsymbol{\phi}}
# \newcommand{\vvarphi}{\boldsymbol{\varphi}}
# \newcommand{\vPhi}{\boldsymbol{\Phi}}
# \newcommand{\vpi}{\boldsymbol{\pi}}
# \newcommand{\vPi}{\boldsymbol{\Pi}}
# \newcommand{\vpsi}{\boldsymbol{\psi}}
# \newcommand{\vPsi}{\boldsymbol{\Psi}}
# \newcommand{\vrho}{\boldsymbol{\rho}}
# \newcommand{\vtheta}{\boldsymbol{\theta}}
# \newcommand{\vthetat}{\boldsymbol{\tilde{\theta}}}
# \newcommand{\vTheta}{\boldsymbol{\Theta}}
# \newcommand{\vsigma}{\boldsymbol{\sigma}}
# \newcommand{\vSigma}{\boldsymbol{\Sigma}}
# \newcommand{\vSigmat}{\boldsymbol{\tilde{\Sigma}}}
# \newcommand{\vsigmoid}{\vsigma}
# \newcommand{\vtau}{\boldsymbol{\tau}}
# \newcommand{\vxi}{\boldsymbol{\xi}}
#
#
# % Lower Roman (Vectors)
# \newcommand{\va}{\mathbf{a}}
# \newcommand{\vb}{\mathbf{b}}
# \newcommand{\vBt}{\mathbf{\tilde{B}}}
# \newcommand{\vc}{\mathbf{c}}
# \newcommand{\vct}{\mathbf{\tilde{c}}}
# \newcommand{\vd}{\mathbf{d}}
# \newcommand{\ve}{\mathbf{e}}
# \newcommand{\vf}{\mathbf{f}}
# \newcommand{\vg}{\mathbf{g}}
# \newcommand{\vh}{\mathbf{h}}
# %\newcommand{\myvh}{\mathbf{h}}
# \newcommand{\vi}{\mathbf{i}}
# \newcommand{\vj}{\mathbf{j}}
# \newcommand{\vk}{\mathbf{k}}
# \newcommand{\vl}{\mathbf{l}}
# \newcommand{\vm}{\mathbf{m}}
# \newcommand{\vn}{\mathbf{n}}
# \newcommand{\vo}{\mathbf{o}}
# \newcommand{\vp}{\mathbf{p}}
# \newcommand{\vq}{\mathbf{q}}
# \newcommand{\vr}{\mathbf{r}}
# \newcommand{\vs}{\mathbf{s}}
# \newcommand{\vt}{\mathbf{t}}
# \newcommand{\vu}{\mathbf{u}}
# \newcommand{\vv}{\mathbf{v}}
# \newcommand{\vw}{\mathbf{w}}
# \newcommand{\vws}{\vw_s}
# \newcommand{\vwt}{\mathbf{\tilde{w}}}
# \newcommand{\vWt}{\mathbf{\tilde{W}}}
# \newcommand{\vwh}{\hat{\vw}}
# \newcommand{\vx}{\mathbf{x}}
# %\newcommand{\vx}{\mathbf{x}}
# \newcommand{\vxt}{\mathbf{\tilde{x}}}
# \newcommand{\vy}{\mathbf{y}}
# \newcommand{\vyt}{\mathbf{\tilde{y}}}
# \newcommand{\vz}{\mathbf{z}}
# %\newcommand{\vzt}{\mathbf{\tilde{z}}}
#
#
# % Upper Roman (Matrices)
# \newcommand{\vA}{\mathbf{A}}
# \newcommand{\vB}{\mathbf{B}}
# \newcommand{\vC}{\mathbf{C}}
# \newcommand{\vD}{\mathbf{D}}
# \newcommand{\vE}{\mathbf{E}}
# \newcommand{\vF}{\mathbf{F}}
# \newcommand{\vG}{\mathbf{G}}
# \newcommand{\vH}{\mathbf{H}}
# \newcommand{\vI}{\mathbf{I}}
# \newcommand{\vJ}{\mathbf{J}}
# \newcommand{\vK}{\mathbf{K}}
# \newcommand{\vL}{\mathbf{L}}
# \newcommand{\vM}{\mathbf{M}}
# \newcommand{\vMt}{\mathbf{\tilde{M}}}
# \newcommand{\vN}{\mathbf{N}}
# \newcommand{\vO}{\mathbf{O}}
# \newcommand{\vP}{\mathbf{P}}
# \newcommand{\vQ}{\mathbf{Q}}
# \newcommand{\vR}{\mathbf{R}}
# \newcommand{\vS}{\mathbf{S}}
# \newcommand{\vT}{\mathbf{T}}
# \newcommand{\vU}{\mathbf{U}}
# \newcommand{\vV}{\mathbf{V}}
# \newcommand{\vW}{\mathbf{W}}
# \newcommand{\vX}{\mathbf{X}}
# %\newcommand{\vXs}{\vX_{\vs}}
# \newcommand{\vXs}{\vX_{s}}
# \newcommand{\vXt}{\mathbf{\tilde{X}}}
# \newcommand{\vY}{\mathbf{Y}}
# \newcommand{\vZ}{\mathbf{Z}}
# \newcommand{\vZt}{\mathbf{\tilde{Z}}}
# \newcommand{\vzt}{\mathbf{\tilde{z}}}
#
#
# %%%%
# \newcommand{\hidden}{\vz}
# \newcommand{\hid}{\hidden}
# \newcommand{\observed}{\vy}
# \newcommand{\obs}{\observed}
# \newcommand{\inputs}{\vu}
# \newcommand{\input}{\inputs}
#
# \newcommand{\hmmTrans}{\vA}
# \newcommand{\hmmObs}{\vB}
# \newcommand{\hmmInit}{\vpi}
# \newcommand{\hmmhid}{\hidden}
# \newcommand{\hmmobs}{\obs}
#
# \newcommand{\ldsDyn}{\vA}
# \newcommand{\ldsObs}{\vC}
# \newcommand{\ldsDynIn}{\vB}
# \newcommand{\ldsObsIn}{\vD}
# \newcommand{\ldsDynNoise}{\vQ}
# \newcommand{\ldsObsNoise}{\vR}
#
# \newcommand{\ssmDynFn}{f}
# \newcommand{\ssmObsFn}{h}
#
#
# %%%
# \newcommand{\gauss}{\mathcal{N}}
#
# \newcommand{\diag}{\mathrm{diag}}
# ```
#
# (sec:inference)=
# # Inferential goals
#
# ```{figure} /figures/inference-problems-tikz.png
# :scale: 30%
# :name: fig:dbn-inference
#
# Illustration of the main kinds of inference for state-space models.
# The shaded region is the interval for which we have data.
# The arrow represents the time step at which we want to perform inference.
# $t$ is the current time, $T$ is the sequence length,
# $\ell$ is the lag and $h$ is the prediction horizon.
# ```
#
#
# Given a sequence of observations and a known model,
# one of the main tasks with SSMs
# is to perform posterior inference
# about the hidden states; this is also called
# state estimation.
# At each time step $t$,
# there are multiple forms of posterior we may be interested in computing,
# including the following:
# - the filtering distribution
# $p(\hmmhid_t|\hmmobs_{1:t})$
# - the smoothing distribution
# $p(\hmmhid_t|\hmmobs_{1:T})$ (note that this conditions on future data, since $T>t$)
# - the fixed-lag smoothing distribution
# $p(\hmmhid_{t-\ell}|\hmmobs_{1:t})$ (note that this
# infers $\ell$ steps in the past given data up to the present).
#
# We may also want to compute the
# predictive distribution $h$ steps into the future:
# ```{math}
# p(\hmmobs_{t+h}|\hmmobs_{1:t})
# = \sum_{\hmmhid_{t+h}} p(\hmmobs_{t+h}|\hmmhid_{t+h}) p(\hmmhid_{t+h}|\hmmobs_{1:t})
# ```
# where the hidden state predictive distribution is
# \begin{align}
# p(\hmmhid_{t+h}|\hmmobs_{1:t})
# &= \sum_{\hmmhid_{t:t+h-1}}
# p(\hmmhid_t|\hmmobs_{1:t})
# p(\hmmhid_{t+1}|\hmmhid_{t})
# p(\hmmhid_{t+2}|\hmmhid_{t+1})
# \cdots
# p(\hmmhid_{t+h}|\hmmhid_{t+h-1})
# \end{align}
# See
# {numref}`fig:dbn-inference` for a summary of these distributions.
# A code sketch of this computation for a discrete-state model is shown below.
#
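# In[ ]:

# For a discrete-state HMM, the h-step-ahead predictive reduces to matrix
# powers of the transition matrix. Below is a minimal sketch, assuming a
# transition matrix `A`, an observation matrix `B`, and a filtered posterior
# vector `alpha_t`; these are illustrative names, not part of any library API.
def predictive_dist(alpha_t, A, B, h):
    # push the filtered posterior h steps through the dynamics:
    # p(z_{t+h} | y_{1:t}) = alpha_t @ A^h
    state_pred = alpha_t @ np.linalg.matrix_power(A, h)
    # marginalize over the hidden state to get p(y_{t+h} | y_{1:t})
    obs_pred = state_pred @ B
    return state_pred, obs_pred

# e.g., a 2-state chain with a uniform filtered posterior
A_demo = np.array([[0.9, 0.1], [0.2, 0.8]])
B_demo = np.array([[0.5, 0.5], [0.1, 0.9]])
print(predictive_dist(np.array([0.5, 0.5]), A_demo, B_demo, h=3))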
# In addition to computing posterior marginals,
# we may want to compute the most probable hidden sequence,
# i.e., the joint MAP estimate
# ```{math}
# \arg \max_{\hmmhid_{1:T}} p(\hmmhid_{1:T}|\hmmobs_{1:T})
# ```
# or sample sequences from the posterior
# ```{math}
# \hmmhid_{1:T} \sim p(\hmmhid_{1:T}|\hmmobs_{1:T})
# ```
#
# Algorithms for all these tasks are discussed in the following chapters,
# since the details depend on the form of the SSM.
#
# (sec:casino-inference)=
# ## Example: inference in the casino HMM
#
# We now illustrate filtering, smoothing and MAP decoding applied
# to the casino HMM from {ref}`sec:casino`.
#
# In[3]:

# state transition matrix
A = np.array([
    [0.95, 0.05],
    [0.10, 0.90]
])

# observation matrix
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6],       # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]  # loaded die
])

# initial state distribution
pi = np.array([0.5, 0.5])

(nstates, nobs) = np.shape(B)

import distrax
from distrax import HMM

hmm = HMM(trans_dist=distrax.Categorical(probs=A),
          init_dist=distrax.Categorical(probs=pi),
          obs_dist=distrax.Categorical(probs=B))

seed = 314
n_samples = 300
z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)
# In[4]:

# Call inference engine: forward_backward returns the filtered posteriors,
# the backward messages (unused here), the smoothed posteriors,
# and the log-likelihood of the observations
filtered_dist, _, smoothed_dist, loglik = hmm.forward_backward(x_hist)
map_path = hmm.viterbi(x_hist)
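# A quick sanity check (an illustrative sketch, not required by the API):
# each array holds one distribution over the two states per time step.
print("filtered:", filtered_dist.shape, "smoothed:", smoothed_dist.shape)
print(f"log-likelihood: {loglik:0.2f}")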
# In[5]:

# Find the spans of timesteps during which the simulated system
# was in state 1 (the loaded die)
def find_dishonest_intervals(z_hist):
    spans = []
    x_init = 0
    for t, _ in enumerate(z_hist[:-1]):
        if z_hist[t + 1] == 0 and z_hist[t] == 1:
            # transition loaded -> fair: close the current span
            x_end = t
            spans.append((x_init, x_end))
        elif z_hist[t + 1] == 1 and z_hist[t] == 0:
            # transition fair -> loaded: open a new span
            x_init = t + 1
    return spans
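# For example (illustrative only), the first few loaded-die spans
# in this simulation:
print(find_dishonest_intervals(z_hist)[:3])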
# In[6]:

# Plot posterior
def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False):
    n_samples = len(inference_values)
    xspan = np.arange(1, n_samples + 1)
    spans = find_dishonest_intervals(z_hist)
    if map_estimate:
        ax.step(xspan, inference_values, where="post")
    else:
        ax.plot(xspan, inference_values[:, state])
    for span in spans:
        ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none")
    ax.set_xlim(1, n_samples)
    # ax.set_ylim(0, 1)
    ax.set_ylim(-0.1, 1.1)
    ax.set_xlabel("Observation number")
# In[7]:

# Filtering
fig, ax = plt.subplots()
plot_inference(filtered_dist, z_hist, ax)
ax.set_ylabel("p(loaded)")
ax.set_title("Filtered")

# In[8]:

# Smoothing
fig, ax = plt.subplots()
plot_inference(smoothed_dist, z_hist, ax)
ax.set_ylabel("p(loaded)")
ax.set_title("Smoothed")

# In[9]:

# MAP estimation
fig, ax = plt.subplots()
plot_inference(map_path, z_hist, ax, map_estimate=True)
ax.set_ylabel("MAP state")
ax.set_title("Viterbi")
# In[10]:

# TODO: posterior samples
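# As a stopgap for the TODO above, here is a minimal sketch of
# forward-filter backward-sample (FFBS), which draws a state sequence from
# p(z_{1:T} | y_{1:T}) using the filtered posteriors and the transition
# matrix A computed above; this is illustrative, not part of the distrax API.
def ffbs_sample(filtered, A, rng):
    T, K = filtered.shape
    zs = np.zeros(T, dtype=int)
    # sample the final state from the filtering distribution at time T
    p_last = np.asarray(filtered[-1])
    zs[-1] = rng.choice(K, p=p_last / p_last.sum())
    # walk backwards, conditioning each state on its sampled successor:
    # p(z_t | z_{t+1}, y_{1:t}) is proportional to filtered[t] * A[:, z_{t+1}]
    for t in range(T - 2, -1, -1):
        probs = np.asarray(filtered[t]) * A[:, zs[t + 1]]
        zs[t] = rng.choice(K, p=probs / probs.sum())
    return zs

z_post_sample = ffbs_sample(filtered_dist, A, np.random.default_rng(0))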
# ## Example: inference in the tracking LG-SSM
#
# We now illustrate filtering and smoothing applied
# to the 2d tracking LG-SSM from {ref}`sec:tracking-lds`.

# In[11]:
key = jax.random.PRNGKey(314)
timesteps = 15
delta = 1.0

# constant-velocity dynamics: position is advanced by velocity * delta
A = jnp.array([
    [1, 0, delta, 0],
    [0, 1, 0, delta],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
])

# we observe only the 2d position, not the velocity
C = jnp.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0]
])

state_size, _ = A.shape
observation_size, _ = C.shape

Q = jnp.eye(state_size) * 0.001        # system noise covariance
R = jnp.eye(observation_size) * 1.0    # observation noise covariance
mu0 = jnp.array([8, 10, 1, 0]).astype(float)  # prior mean
Sigma0 = jnp.eye(state_size) * 1.0            # prior covariance

from jsl.lds.kalman_filter import LDS, smooth, filter

lds = LDS(A, C, Q, R, mu0, Sigma0)
z_hist, x_hist = lds.sample(key, timesteps)
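# To make explicit what filter() computes below, here is a minimal sketch
# of a single Kalman predict/update step under the model above (the
# initialization convention may differ from jsl's internals).
def kalman_step(mu, Sigma, y):
    # predict: push the posterior through the linear dynamics
    mu_pred = A @ mu
    Sigma_pred = A @ Sigma @ A.T + Q
    # update: correct the prediction with the observation y
    S = C @ Sigma_pred @ C.T + R              # innovation covariance
    K = Sigma_pred @ C.T @ jnp.linalg.inv(S)  # Kalman gain
    mu_new = mu_pred + K @ (y - C @ mu_pred)
    Sigma_new = (jnp.eye(state_size) - K @ C) @ Sigma_pred
    return mu_new, Sigma_new

mu1, Sigma1 = kalman_step(mu0, Sigma0, x_hist[0])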
# In[12]:

from jsl.demos.plot_utils import plot_ellipse

def plot_tracking_values(observed, filtered, cov_hist, signal_label, ax):
    timesteps, _ = observed.shape
    ax.plot(observed[:, 0], observed[:, 1], marker="o", linewidth=0,
            markerfacecolor="none", markeredgewidth=2, markersize=8,
            label="observed", c="tab:green")
    ax.plot(*filtered[:, :2].T, label=signal_label, c="tab:red", marker="x", linewidth=2)
    for t in range(0, timesteps, 1):
        covn = cov_hist[t][:2, :2]
        # 2-sigma credible ellipse for the position estimate at time t
        plot_ellipse(covn, filtered[t, :2], ax, n_std=2.0, plot_center=False)
    ax.axis("equal")
    ax.legend()
# In[13]:

# Filtering
mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist = filter(lds, x_hist)
l2_filter = jnp.linalg.norm(z_hist[:, :2] - mu_hist[:, :2], 2)
print(f"L2-filter: {l2_filter:0.4f}")
fig_filtered, axs = plt.subplots()
plot_tracking_values(x_hist, mu_hist, Sigma_hist, "filtered", axs)

# In[14]:

# Smoothing
mu_hist_smooth, Sigma_hist_smooth = smooth(lds, mu_hist, Sigma_hist, mu_cond_hist, Sigma_cond_hist)
l2_smooth = jnp.linalg.norm(z_hist[:, :2] - mu_hist_smooth[:, :2], 2)
print(f"L2-smooth: {l2_smooth:0.4f}")
fig_smoothed, axs = plt.subplots()
plot_tracking_values(x_hist, mu_hist_smooth, Sigma_hist_smooth, "smoothed", axs)