#!/usr/bin/env python
# coding: utf-8

# In[1]:


# meta-data does not work yet in VScode
# https://github.com/microsoft/vscode-jupyter/issues/1121
{
    "tags": [
        "hide-cell"
    ]
}

### Install necessary libraries

try:
    import jax
except:
    # For cuda version, see https://github.com/google/jax#installation
    get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
    import jax

try:
    import distrax
except:
    get_ipython().run_line_magic('pip', 'install --upgrade distrax')
    import distrax

try:
    import jsl
except:
    get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
    import jsl

try:
    import rich
except:
    get_ipython().run_line_magic('pip', 'install rich')
    import rich
# In[2]:


{
    "tags": [
        "hide-cell"
    ]
}

### Import standard libraries

import abc
from dataclasses import dataclass
import functools
import itertools
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.scipy.special import logit
from jax.nn import softmax
from functools import partial
from jax.random import PRNGKey, split

import inspect
import inspect as py_inspect
import rich
from rich import inspect as r_inspect
from rich import print as r_print

def print_source(fname):
    r_print(py_inspect.getsource(fname))
# ```{math}
#
# \newcommand\floor[1]{\lfloor#1\rfloor}
#
# \newcommand{\real}{\mathbb{R}}
#
# % Numbers
# \newcommand{\vzero}{\boldsymbol{0}}
# \newcommand{\vone}{\boldsymbol{1}}
#
# % Greek https://www.latex-tutorial.com/symbols/greek-alphabet/
# \newcommand{\valpha}{\boldsymbol{\alpha}}
# \newcommand{\vbeta}{\boldsymbol{\beta}}
# \newcommand{\vchi}{\boldsymbol{\chi}}
# \newcommand{\vdelta}{\boldsymbol{\delta}}
# \newcommand{\vDelta}{\boldsymbol{\Delta}}
# \newcommand{\vepsilon}{\boldsymbol{\epsilon}}
# \newcommand{\vzeta}{\boldsymbol{\zeta}}
# \newcommand{\vXi}{\boldsymbol{\Xi}}
# \newcommand{\vell}{\boldsymbol{\ell}}
# \newcommand{\veta}{\boldsymbol{\eta}}
# %\newcommand{\vEta}{\boldsymbol{\Eta}}
# \newcommand{\vgamma}{\boldsymbol{\gamma}}
# \newcommand{\vGamma}{\boldsymbol{\Gamma}}
# \newcommand{\vmu}{\boldsymbol{\mu}}
# \newcommand{\vmut}{\boldsymbol{\tilde{\mu}}}
# \newcommand{\vnu}{\boldsymbol{\nu}}
# \newcommand{\vkappa}{\boldsymbol{\kappa}}
# \newcommand{\vlambda}{\boldsymbol{\lambda}}
# \newcommand{\vLambda}{\boldsymbol{\Lambda}}
# \newcommand{\vLambdaBar}{\overline{\vLambda}}
# %\newcommand{\vnu}{\boldsymbol{\nu}}
# \newcommand{\vomega}{\boldsymbol{\omega}}
# \newcommand{\vOmega}{\boldsymbol{\Omega}}
# \newcommand{\vphi}{\boldsymbol{\phi}}
# \newcommand{\vvarphi}{\boldsymbol{\varphi}}
# \newcommand{\vPhi}{\boldsymbol{\Phi}}
# \newcommand{\vpi}{\boldsymbol{\pi}}
# \newcommand{\vPi}{\boldsymbol{\Pi}}
# \newcommand{\vpsi}{\boldsymbol{\psi}}
# \newcommand{\vPsi}{\boldsymbol{\Psi}}
# \newcommand{\vrho}{\boldsymbol{\rho}}
# \newcommand{\vtheta}{\boldsymbol{\theta}}
# \newcommand{\vthetat}{\boldsymbol{\tilde{\theta}}}
# \newcommand{\vTheta}{\boldsymbol{\Theta}}
# \newcommand{\vsigma}{\boldsymbol{\sigma}}
# \newcommand{\vSigma}{\boldsymbol{\Sigma}}
# \newcommand{\vSigmat}{\boldsymbol{\tilde{\Sigma}}}
# \newcommand{\vsigmoid}{\vsigma}
# \newcommand{\vtau}{\boldsymbol{\tau}}
# \newcommand{\vxi}{\boldsymbol{\xi}}
#
#
# % Lower Roman (Vectors)
# \newcommand{\va}{\mathbf{a}}
# \newcommand{\vb}{\mathbf{b}}
# \newcommand{\vBt}{\mathbf{\tilde{B}}}
# \newcommand{\vc}{\mathbf{c}}
# \newcommand{\vct}{\mathbf{\tilde{c}}}
# \newcommand{\vd}{\mathbf{d}}
# \newcommand{\ve}{\mathbf{e}}
# \newcommand{\vf}{\mathbf{f}}
# \newcommand{\vg}{\mathbf{g}}
# \newcommand{\vh}{\mathbf{h}}
# %\newcommand{\myvh}{\mathbf{h}}
# \newcommand{\vi}{\mathbf{i}}
# \newcommand{\vj}{\mathbf{j}}
# \newcommand{\vk}{\mathbf{k}}
# \newcommand{\vl}{\mathbf{l}}
# \newcommand{\vm}{\mathbf{m}}
# \newcommand{\vn}{\mathbf{n}}
# \newcommand{\vo}{\mathbf{o}}
# \newcommand{\vp}{\mathbf{p}}
# \newcommand{\vq}{\mathbf{q}}
# \newcommand{\vr}{\mathbf{r}}
# \newcommand{\vs}{\mathbf{s}}
# \newcommand{\vt}{\mathbf{t}}
# \newcommand{\vu}{\mathbf{u}}
# \newcommand{\vv}{\mathbf{v}}
# \newcommand{\vw}{\mathbf{w}}
# \newcommand{\vws}{\vw_s}
# \newcommand{\vwt}{\mathbf{\tilde{w}}}
# \newcommand{\vWt}{\mathbf{\tilde{W}}}
# \newcommand{\vwh}{\hat{\vw}}
# \newcommand{\vx}{\mathbf{x}}
# %\newcommand{\vx}{\mathbf{x}}
# \newcommand{\vxt}{\mathbf{\tilde{x}}}
# \newcommand{\vy}{\mathbf{y}}
# \newcommand{\vyt}{\mathbf{\tilde{y}}}
# \newcommand{\vz}{\mathbf{z}}
# %\newcommand{\vzt}{\mathbf{\tilde{z}}}
#
#
# % Upper Roman (Matrices)
# \newcommand{\vA}{\mathbf{A}}
# \newcommand{\vB}{\mathbf{B}}
# \newcommand{\vC}{\mathbf{C}}
# \newcommand{\vD}{\mathbf{D}}
# \newcommand{\vE}{\mathbf{E}}
# \newcommand{\vF}{\mathbf{F}}
# \newcommand{\vG}{\mathbf{G}}
# \newcommand{\vH}{\mathbf{H}}
# \newcommand{\vI}{\mathbf{I}}
# \newcommand{\vJ}{\mathbf{J}}
# \newcommand{\vK}{\mathbf{K}}
# \newcommand{\vL}{\mathbf{L}}
# \newcommand{\vM}{\mathbf{M}}
# \newcommand{\vMt}{\mathbf{\tilde{M}}}
# \newcommand{\vN}{\mathbf{N}}
# \newcommand{\vO}{\mathbf{O}}
# \newcommand{\vP}{\mathbf{P}}
# \newcommand{\vQ}{\mathbf{Q}}
# \newcommand{\vR}{\mathbf{R}}
# \newcommand{\vS}{\mathbf{S}}
# \newcommand{\vT}{\mathbf{T}}
# \newcommand{\vU}{\mathbf{U}}
# \newcommand{\vV}{\mathbf{V}}
# \newcommand{\vW}{\mathbf{W}}
# \newcommand{\vX}{\mathbf{X}}
# %\newcommand{\vXs}{\vX_{\vs}}
# \newcommand{\vXs}{\vX_{s}}
# \newcommand{\vXt}{\mathbf{\tilde{X}}}
# \newcommand{\vY}{\mathbf{Y}}
# \newcommand{\vZ}{\mathbf{Z}}
# \newcommand{\vZt}{\mathbf{\tilde{Z}}}
# \newcommand{\vzt}{\mathbf{\tilde{z}}}
#
#
# %%%%
# \newcommand{\hidden}{\vz}
# \newcommand{\hid}{\hidden}
# \newcommand{\observed}{\vy}
# \newcommand{\obs}{\observed}
# \newcommand{\inputs}{\vu}
# \newcommand{\input}{\inputs}
#
# \newcommand{\hmmTrans}{\vA}
# \newcommand{\hmmObs}{\vB}
# \newcommand{\hmmInit}{\vpi}
# \newcommand{\hmmhid}{\hidden}
# \newcommand{\hmmobs}{\obs}
#
# \newcommand{\ldsDyn}{\vA}
# \newcommand{\ldsObs}{\vC}
# \newcommand{\ldsDynIn}{\vB}
# \newcommand{\ldsObsIn}{\vD}
# \newcommand{\ldsDynNoise}{\vQ}
# \newcommand{\ldsObsNoise}{\vR}
#
# \newcommand{\ssmDynFn}{f}
# \newcommand{\ssmObsFn}{h}
#
#
# %%%
# \newcommand{\gauss}{\mathcal{N}}
#
# \newcommand{\diag}{\mathrm{diag}}
# ```
#
# (sec:ssm-intro)=
# # What are State Space Models?
#
#
# A state space model or SSM
# is a partially observed Markov model,
# in which the hidden state, $\hidden_t$,
# evolves over time according to a Markov process,
# possibly conditional on external inputs or controls $\input_t$,
# and each hidden state generates some
# observations $\obs_t$ at each time step.
# (In this book, we mostly focus on discrete time systems,
# although we consider the continuous-time case in XXX.)
# We get to see the observations, but not the hidden state.
# Our main goal is to infer the hidden state given the observations.
# However, we can also use the model to predict future observations,
# by first predicting future hidden states, and then predicting
# what observations they might generate.
# By using a hidden state $\hidden_t$
# to represent the past observations, $\obs_{1:t-1}$,
# the model can have ``infinite'' memory,
# unlike a standard Markov model.
#
# Formally we can define an SSM
# as the following joint distribution:
# ```{math}
# :label: eq:SSM-ar
# p(\hmmobs_{1:T},\hmmhid_{1:T}|\inputs_{1:T})
#  = \left[ p(\hmmhid_1|\inputs_1) \prod_{t=2}^{T}
#  p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) \right]
#  \left[ \prod_{t=1}^T p(\hmmobs_t|\hmmhid_t, \inputs_t, \hmmobs_{t-1}) \right]
# ```
# where $p(\hmmhid_t|\hmmhid_{t-1},\inputs_t)$ is the
# transition model,
# $p(\hmmobs_t|\hmmhid_t, \inputs_t, \hmmobs_{t-1})$ is the
# observation model,
# and $\inputs_{t}$ is an optional input or action.
# See {numref}`Figure %s <ssm-ar>`
# for an illustration of the corresponding graphical model.
#
#
# ```{figure} /figures/SSM-AR-inputs.png
# :scale: 100%
# :name: ssm-ar
#
# Illustration of an SSM as a graphical model.
# ```
#
#
# We often consider a simpler setting in which the
# observations are conditionally independent of each other
# (rather than having Markovian dependencies) given the hidden state.
# In this case the joint simplifies to
# ```{math}
# :label: eq:SSM-input
# p(\hmmobs_{1:T},\hmmhid_{1:T}|\inputs_{1:T})
#  = \left[ p(\hmmhid_1|\inputs_1) \prod_{t=2}^{T}
#  p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) \right]
#  \left[ \prod_{t=1}^T p(\hmmobs_t|\hmmhid_t, \inputs_t) \right]
# ```
# Sometimes there are no external inputs, so the model further
# simplifies to the following unconditional generative model:
# ```{math}
# :label: eq:SSM-no-input
# p(\hmmobs_{1:T},\hmmhid_{1:T})
#  = \left[ p(\hmmhid_1) \prod_{t=2}^{T}
#  p(\hmmhid_t|\hmmhid_{t-1}) \right]
#  \left[ \prod_{t=1}^T p(\hmmobs_t|\hmmhid_t) \right]
# ```
# See {numref}`Figure %s <ssm-simplified>`
# for an illustration of the corresponding graphical model.
#
#
# ```{figure} /figures/SSM-simplified.png
# :scale: 100%
# :name: ssm-simplified
#
# Illustration of a simplified SSM.
# ```
#
#
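# To make the generative process in {eq}`eq:SSM-no-input` concrete, below is a minimal
# ancestral-sampling sketch (not part of the original notebook): we assume three
# hypothetical callables, `sample_init(key)`, `sample_dynamics(key, z)` and
# `sample_obs(key, z)`, which draw from $p(\hmmhid_1)$, $p(\hmmhid_t|\hmmhid_{t-1})$
# and $p(\hmmobs_t|\hmmhid_t)$, and alternate between sampling the next hidden state
# and emitting an observation.

# In[ ]:


def ssm_ancestral_sample(key, sample_init, sample_dynamics, sample_obs, seq_len):
    # Illustrative sketch of ancestral sampling from an SSM with no inputs.
    # `sample_init`, `sample_dynamics`, `sample_obs` are assumed, user-supplied callables.
    keys = jax.random.split(key, seq_len + 1)
    z = sample_init(keys[0])
    zs, ys = [], []
    for t in range(seq_len):
        key_z, key_y = jax.random.split(keys[t + 1])
        if t > 0:
            z = sample_dynamics(key_z, z)   # z_t ~ p(z_t | z_{t-1})
        ys.append(sample_obs(key_y, z))     # y_t ~ p(y_t | z_t)
        zs.append(z)
    return jnp.stack(zs), jnp.stack(ys)
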
# (sec:hmm-intro)=
# # Hidden Markov Models
#
# In this section, we discuss the
# hidden Markov model or HMM,
# which is a state space model in which the hidden states
# are discrete, so $\hmmhid_t \in \{1,\ldots, K\}$.
# The observations may be discrete,
# $\hmmobs_t \in \{1,\ldots, C\}$,
# or continuous,
# $\hmmobs_t \in \real^D$,
# or some combination,
# as we illustrate below.
# More details can be found in e.g.,
# {cite}`Rabiner89,Fraser08,Cappe05`.
# For an interactive introduction,
# see https://nipunbatra.github.io/hmm/.
# (sec:casino)=
# ## Example: Casino HMM
#
# To illustrate an HMM with a categorical observation model,
# we consider the "Occasionally dishonest casino" model from {cite}`Durbin98`.
# There are 2 hidden states, representing whether the die being used by the casino is fair or loaded.
# Each state defines a distribution over the 6 possible observations.
#
# The transition model is denoted by
# ```{math}
# p(z_t=j|z_{t-1}=i) = \hmmTrans_{ij}
# ```
# Here the $i$'th row of $\vA$ corresponds to the outgoing distribution from state $i$.
# This is a row stochastic matrix,
# meaning each row sums to one.
# We can visualize
# the non-zero entries in the transition matrix by creating a state transition diagram,
# as shown in
# {numref}`Figure %s <casino-fig>`.
#
# ```{figure} /figures/casino.png
# :scale: 50%
# :name: casino-fig
#
# Illustration of the casino HMM.
# ```
#
# The observation model
# $p(\obs_t|\hidden_t=j)$ has the form
# ```{math}
# p(\obs_t=k|\hidden_t=j) = \hmmObs_{jk}
# ```
# This is represented by the histograms associated with each
# state in {numref}`Figure %s <casino-fig>`.
#
# Finally,
# the initial state distribution is denoted by
# ```{math}
# p(z_1=j) = \hmmInit_j
# ```
#
# Collectively we denote all the parameters by $\vtheta=(\hmmTrans, \hmmObs, \hmmInit)$.
#
# Now let us implement this model in code.
# In[3]:


# state transition matrix
A = np.array([
    [0.95, 0.05],
    [0.10, 0.90]
])

# observation matrix
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6],       # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]  # loaded die
])

# initial state distribution
pi = np.array([0.5, 0.5])

(nstates, nobs) = np.shape(B)
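
# (Added illustrative check, not in the original notebook.)
# A and B should be row-stochastic, i.e. each row sums to one.
assert np.allclose(A.sum(axis=1), 1.0)
assert np.allclose(B.sum(axis=1), 1.0)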
# In[4]:


import distrax
from distrax import HMM

hmm = HMM(trans_dist=distrax.Categorical(probs=A),
          init_dist=distrax.Categorical(probs=pi),
          obs_dist=distrax.Categorical(probs=B))

print(hmm)


#
# Let's sample from the model. We will generate a sequence of latent states, $\hid_{1:T}$,
# which we then convert to a sequence of observations, $\obs_{1:T}$.

# In[5]:


seed = 314
n_samples = 300
z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)

z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60]
x_hist_str = "".join((np.array(x_hist) + 1).astype(str))[:60]

print("Printing sample observed/latent...")
print(f"x: {x_hist_str}")
print(f"z: {z_hist_str}")


# In[6]:


# Here is the source code for the sampling algorithm.
print_source(hmm.sample)
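

# As a complement to the library code above, here is a minimal sketch (not the distrax
# implementation) of what such an ancestral sampler looks like when written with
# `jax.lax.scan`, using the `A`, `B`, `pi` arrays defined earlier: at each step we sample
# the next hidden state from the relevant row of $\vA$, and then an observation from the
# corresponding row of $\vB$.

# In[ ]:


def hmm_sample_sketch(key, pi, A, B, seq_len):
    # Illustrative ancestral sampling for a categorical HMM (sketch, not the distrax API).
    A, B, pi = jnp.asarray(A), jnp.asarray(B), jnp.asarray(pi)
    key_init, key_scan = jax.random.split(key)

    # sample z_1 ~ pi and x_1 ~ B[z_1, :]
    key_z1, key_x1 = jax.random.split(key_init)
    z1 = jax.random.categorical(key_z1, jnp.log(pi))
    x1 = jax.random.categorical(key_x1, jnp.log(B[z1]))

    def step(z_prev, key_t):
        key_z, key_x = jax.random.split(key_t)
        z_t = jax.random.categorical(key_z, jnp.log(A[z_prev]))  # z_t ~ A[z_{t-1}, :]
        x_t = jax.random.categorical(key_x, jnp.log(B[z_t]))     # x_t ~ B[z_t, :]
        return z_t, (z_t, x_t)

    keys = jax.random.split(key_scan, seq_len - 1)
    _, (z_rest, x_rest) = lax.scan(step, z1, keys)
    z_hist = jnp.concatenate([z1[None], z_rest])
    x_hist = jnp.concatenate([x1[None], x_rest])
    return z_hist, x_hist

z_sketch, x_sketch = hmm_sample_sketch(PRNGKey(0), pi, A, B, 10)
print(z_sketch)
print(x_sketch)
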
# Our primary goal will be to infer the latent state from the observations,
# so we can detect if the casino is being dishonest or not. This will
# affect how we choose to gamble our money.
# We discuss various ways to perform this inference below.

# (sec:lillypad)=
# ## Example: Lillypad HMM
#
#
# If $\obs_t$ is continuous, it is common to use a Gaussian
# observation model:
# ```{math}
# p(\obs_t|\hidden_t=j) = \gauss(\obs_t|\vmu_j,\vSigma_j)
# ```
# This is sometimes called a Gaussian HMM.
#
# As a simple example, suppose we have an HMM with 3 hidden states,
# each of which generates observations from a 2d Gaussian.
# We can represent these Gaussian distributions as 2d ellipses,
# as we show below.
# We call these ``lilly pads'', because of their shape.
# We can imagine a frog hopping from one lilly pad to another.
# (This analogy is due to the late Sam Roweis.)
# The frog will stay on a pad for a while (corresponding to remaining in the same
# discrete state $\hidden_t$), and then jump to a new pad
# (corresponding to a transition to a new state).
# The data we see are just the 2d points (e.g., water droplets)
# coming from near the pad that the frog is currently on.
# Thus this model is like a Gaussian mixture model,
# in that it generates clusters of observations,
# except now there is temporal correlation between the data points.
#
# Let us now illustrate this model in code.
#
#
# In[7]:


# Let us create the model

initial_probs = jnp.array([0.3, 0.2, 0.5])

# transition matrix
A = jnp.array([
    [0.3, 0.4, 0.3],
    [0.1, 0.6, 0.3],
    [0.2, 0.3, 0.5]
])

# Observation model
mu_collection = jnp.array([
    [0.3, 0.3],
    [0.8, 0.5],
    [0.3, 0.8]
])

S1 = jnp.array([[1.1, 0], [0, 0.3]])
S2 = jnp.array([[0.3, -0.5], [-0.5, 1.3]])
S3 = jnp.array([[0.8, 0.4], [0.4, 0.5]])
cov_collection = jnp.array([S1, S2, S3]) / 60

import tensorflow_probability as tfp

if False:
    hmm = HMM(trans_dist=distrax.Categorical(probs=A),
              init_dist=distrax.Categorical(probs=initial_probs),
              obs_dist=distrax.MultivariateNormalFullCovariance(
                  loc=mu_collection, covariance_matrix=cov_collection))
else:
    hmm = HMM(trans_dist=distrax.Categorical(probs=A),
              init_dist=distrax.Categorical(probs=initial_probs),
              obs_dist=distrax.as_distribution(
                  tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
                      loc=mu_collection, covariance_matrix=cov_collection)))

print(hmm)
# In[8]:


n_samples, seed = 50, 10
samples_state, samples_obs = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)

print(samples_state.shape)
print(samples_obs.shape)


# In[9]:


# Let's plot the observed data in 2d
xmin, xmax = 0, 1
ymin, ymax = 0, 1.2
colors = ["tab:green", "tab:blue", "tab:red"]

def plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax, step=1e-2):
    obs_dist = hmm.obs_dist
    color_sample = [colors[i] for i in samples_state]

    # evaluate the emission densities on a 2d grid
    xs = jnp.arange(xmin, xmax, step)
    ys = jnp.arange(ymin, ymax, step)
    v_prob = vmap(lambda x, y: obs_dist.prob(jnp.array([x, y])), in_axes=(None, 0))
    z = vmap(v_prob, in_axes=(0, None))(xs, ys)

    grid = np.mgrid[xmin:xmax:step, ymin:ymax:step]

    # draw one contour ("lilly pad") per state
    for k, color in enumerate(colors):
        ax.contour(*grid, z[:, :, k], levels=[1], colors=color, linewidths=3)
        ax.text(*(obs_dist.mean()[k] + 0.13), f"$k$={k + 1}", fontsize=13, horizontalalignment="right")

    # plot the observations, colored by the generating state
    ax.plot(*samples_obs.T, c="black", alpha=0.3, zorder=1)
    ax.scatter(*samples_obs.T, c=color_sample, s=30, zorder=2, alpha=0.8)

    return ax, color_sample

fig, ax = plt.subplots()
_, color_sample = plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax)


# In[10]:


# Let's plot the hidden state sequence
fig, ax = plt.subplots()
ax.step(range(n_samples), samples_state, where="post", c="black", linewidth=1, alpha=0.3)
ax.scatter(range(n_samples), samples_state, c=color_sample, zorder=3)
# (sec:lds-intro)=
# # Linear Gaussian SSMs
#
#
# Consider the state space model in
# {eq}`eq:SSM-ar`
# where we assume the observations are conditionally iid given the
# hidden states and inputs (i.e. there are no auto-regressive dependencies
# between the observables).
# We can rewrite this model as
# a stochastic nonlinear dynamical system (NLDS)
# by defining the distribution of the next hidden state
# as a deterministic function of the past state
# plus random process noise $\vepsilon_t$
# \begin{align}
# \hmmhid_t &= \ssmDynFn(\hmmhid_{t-1}, \inputs_t, \vepsilon_t)
# \end{align}
# where $\vepsilon_t$ is drawn from the distribution such
# that the induced distribution
# on $\hmmhid_t$ matches $p(\hmmhid_t|\hmmhid_{t-1}, \inputs_t)$.
# Similarly we can rewrite the observation distributions
# as a deterministic function of the hidden state
# plus observation noise $\veta_t$:
# \begin{align}
# \hmmobs_t &= \ssmObsFn(\hmmhid_{t}, \inputs_t, \veta_t)
# \end{align}
#
#
# If we assume additive Gaussian noise,
# the model becomes
# \begin{align}
# \hmmhid_t &= \ssmDynFn(\hmmhid_{t-1}, \inputs_t) + \vepsilon_t \\
# \hmmobs_t &= \ssmObsFn(\hmmhid_{t}, \inputs_t) + \veta_t
# \end{align}
# where $\vepsilon_t \sim \gauss(\vzero,\vQ_t)$
# and $\veta_t \sim \gauss(\vzero,\vR_t)$.
# We will call these Gaussian SSMs.
#
# If we additionally assume
# the transition function $\ssmDynFn$
# and the observation function $\ssmObsFn$ are both linear,
# then we can rewrite the model as follows:
# \begin{align}
# p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) &= \gauss(\hmmhid_t|\ldsDyn_t \hmmhid_{t-1}
# + \ldsDynIn_t \inputs_t, \vQ_t)
# \\
# p(\hmmobs_t|\hmmhid_t,\inputs_t) &= \gauss(\hmmobs_t|\ldsObs_t \hmmhid_{t}
# + \ldsObsIn_t \inputs_t, \vR_t)
# \end{align}
# This is called a
# linear-Gaussian state space model
# (LG-SSM),
# or a
# linear dynamical system (LDS).
# We usually assume the parameters are independent of time, in which case
# the model is said to be time-invariant or homogeneous.
#
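# To make the LG-SSM equations concrete, here is a minimal sampling sketch (not part of the
# original notebook, and not the jsl API): one step of ancestral sampling, with hypothetical
# parameter names `A`, `B_in`, `C`, `D_in`, `Q`, `R` corresponding to
# $\ldsDyn, \ldsDynIn, \ldsObs, \ldsObsIn, \vQ, \vR$.

# In[ ]:


def lgssm_sample_step(key, z_prev, u_t, A, B_in, C, D_in, Q, R):
    # One step of ancestral sampling from a time-invariant LG-SSM (illustrative sketch).
    key_dyn, key_obs = jax.random.split(key)
    # z_t ~ N(A z_{t-1} + B_in u_t, Q)
    z_t = jax.random.multivariate_normal(key_dyn, A @ z_prev + B_in @ u_t, Q)
    # y_t ~ N(C z_t + D_in u_t, R)
    y_t = jax.random.multivariate_normal(key_obs, C @ z_t + D_in @ u_t, R)
    return z_t, y_t
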
# (sec:tracking-lds)=
# (sec:kalman-tracking)=
# ## Example: tracking a 2d point
#
#
#
# % Sarkka p43
# Consider an object moving in $\real^2$.
# Let the state be
# the position and velocity of the object,
# $$\vz_t =\begin{pmatrix} u_t & \dot{u}_t & v_t & \dot{v}_t \end{pmatrix}^{\top}$$.
# (We use $u$ and $v$ for the two coordinates,
# to avoid confusion with the state and observation variables.)
# If we use Euler discretization,
# the dynamics become
# \begin{align}
# \underbrace{\begin{pmatrix} u_t\\ \dot{u}_t \\ v_t \\ \dot{v}_t \end{pmatrix}}_{\vz_t}
# =
# \underbrace{
# \begin{pmatrix}
# 1 & \Delta & 0 & 0 \\
# 0 & 1 & 0 & 0 \\
# 0 & 0 & 1 & \Delta \\
# 0 & 0 & 0 & 1
# \end{pmatrix}
# }_{\ldsDyn}
# \
# \underbrace{\begin{pmatrix} u_{t-1} \\ \dot{u}_{t-1} \\ v_{t-1} \\ \dot{v}_{t-1} \end{pmatrix}}_{\vz_{t-1}}
# + \vepsilon_t
# \end{align}
# where $\vepsilon_t \sim \gauss(\vzero,\vQ)$ is
# the process noise.
#
# Let us assume
# that the process noise is
# a white noise process added to the velocity components
# of the state, but not to the location.
# (This is known as a random accelerations model.)
# We can approximate the resulting process in discrete time by assuming
# $\vQ = \diag(0, q, 0, q)$.
# (See {cite}`Sarkka13` p60 for a more accurate way
# to convert the continuous time process to discrete time.)
#
#
# Now suppose that at each discrete time point we
# observe the location,
# corrupted by Gaussian noise.
# Thus the observation model becomes
# \begin{align}
# \underbrace{\begin{pmatrix} y_{1,t} \\ y_{2,t} \end{pmatrix}}_{\vy_t}
# &=
# \underbrace{
# \begin{pmatrix}
# 1 & 0 & 0 & 0 \\
# 0 & 0 & 1 & 0
# \end{pmatrix}
# }_{\ldsObs}
# \
# \underbrace{\begin{pmatrix} u_t\\ \dot{u}_t \\ v_t \\ \dot{v}_t \end{pmatrix}}_{\vz_t}
# + \veta_t
# \end{align}
# where $\veta_t \sim \gauss(\vzero,\vR)$ is the observation noise.
# We see that the observation matrix $\ldsObs$ simply ``extracts'' the
# relevant parts of the state vector.
#
# Suppose we sample a trajectory and corresponding set
# of noisy observations from this model,
# $(\vz_{1:T}, \vy_{1:T}) \sim p(\vz,\vy|\vtheta)$.
# (We use diagonal observation noise,
# $\vR = \diag(\sigma_1^2, \sigma_2^2)$.)
# The results are shown below.
#
# In[11]:


key = jax.random.PRNGKey(314)
timesteps = 15
delta = 1.0

# state transition matrix
# NB: in this code the state is ordered as (u, v, udot, vdot), i.e. positions first, then velocities.
A = jnp.array([
    [1, 0, delta, 0],
    [0, 1, 0, delta],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
])

# observation matrix (extracts the two position components)
C = jnp.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0]
])

state_size, _ = A.shape
observation_size, _ = C.shape

Q = jnp.eye(state_size) * 0.001
R = jnp.eye(observation_size) * 1.0

# Prior parameter distribution
mu0 = jnp.array([8, 10, 1, 0]).astype(float)
Sigma0 = jnp.eye(state_size) * 1.0

from jsl.lds.kalman_filter import LDS, smooth, filter

lds = LDS(A, C, Q, R, mu0, Sigma0)
print(lds)
# In[12]:


from jsl.demos.plot_utils import plot_ellipse

def plot_tracking_values(observed, filtered, cov_hist, signal_label, ax):
    timesteps, _ = observed.shape
    ax.plot(observed[:, 0], observed[:, 1], marker="o", linewidth=0,
            markerfacecolor="none", markeredgewidth=2, markersize=8, label="observed", c="tab:green")
    ax.plot(*filtered[:, :2].T, label=signal_label, c="tab:red", marker="x", linewidth=2)
    # draw a 2-std covariance ellipse around each estimated position
    for t in range(0, timesteps, 1):
        covn = cov_hist[t][:2, :2]
        plot_ellipse(covn, filtered[t, :2], ax, n_std=2.0, plot_center=False)
    ax.axis("equal")
    ax.legend()
# In[13]:


z_hist, x_hist = lds.sample(key, timesteps)

fig_truth, axs = plt.subplots()
axs.plot(x_hist[:, 0], x_hist[:, 1],
         marker="o", linewidth=0, markerfacecolor="none",
         markeredgewidth=2, markersize=8,
         label="observed", c="tab:green")
axs.plot(z_hist[:, 0], z_hist[:, 1],
         linewidth=2, label="truth",
         marker="s", markersize=8)
axs.legend()
axs.axis("equal")
# The main task is to infer the hidden states given the noisy
# observations, i.e., $p(\vz|\vy,\vtheta)$. We discuss the topic of inference in {ref}`sec:inference`.

# (sec:nlds-intro)=
# # Nonlinear Gaussian SSMs
#
# In this section, we consider SSMs in which the dynamics and/or observation models are nonlinear,
# but the process noise and observation noise are Gaussian.
# That is,
# \begin{align}
# \hmmhid_t &= \ssmDynFn(\hmmhid_{t-1}, \inputs_t) + \vepsilon_t \\
# \hmmobs_t &= \ssmObsFn(\hmmhid_{t}, \inputs_t) + \veta_t
# \end{align}
# where $\vepsilon_t \sim \gauss(\vzero,\vQ_t)$
# and $\veta_t \sim \gauss(\vzero,\vR_t)$.
# This is a very widely used model class. We give some examples below.
# (sec:pendulum)=
# ## Example: tracking a 1d pendulum
#
# ```{figure} /figures/pendulum.png
# :scale: 100%
# :name: fig:pendulum
#
# Illustration of a pendulum swinging.
# $g$ is the force of gravity,
# $w(t)$ is a random external force,
# and $\alpha$ is the angle with respect to the vertical.
# Based on {cite}`Sarkka13` fig 3.10.
#
# ```
#
#
# % Sarkka p45, p74
# Consider a simple pendulum of unit mass and length swinging from
# a fixed attachment, as in {ref}`fig:pendulum`.
# Such an object is in principle entirely deterministic in its behavior.
# However, in the real world, there are often unknown forces at work
# (e.g., air turbulence, friction).
# We will model these by a continuous time random Gaussian noise process $w(t)$.
# This gives rise to the following differential equation:
# \begin{align}
# \frac{d^2 \alpha}{d t^2}
# = -g \sin(\alpha) + w(t)
# \end{align}
# We can write this as a nonlinear SSM by defining the state to be
# $z_1(t) = \alpha(t)$ and $z_2(t) = d\alpha(t)/dt$.
# Thus
# \begin{align}
# \frac{d \vz}{dt}
# = \begin{pmatrix} z_2 \\ -g \sin(z_1) \end{pmatrix}
# + \begin{pmatrix} 0 \\ 1 \end{pmatrix} w(t)
# \end{align}
# If we discretize this with step size $\Delta$,
# we get the following
# formulation {cite}`Sarkka13` p74:
# \begin{align}
# \underbrace{
# \begin{pmatrix} z_{1,t} \\ z_{2,t} \end{pmatrix}
# }_{\hmmhid_t}
# =
# \underbrace{
# \begin{pmatrix} z_{1,t-1} + z_{2,t-1} \Delta \\
# z_{2,t-1} -g \sin(z_{1,t-1}) \Delta \end{pmatrix}
# }_{\vf(\hmmhid_{t-1})}
# +\vq_{t-1}
# \end{align}
# where $\vq_{t-1} \sim \gauss(\vzero,\vQ)$ with
# \begin{align}
# \vQ = q^c \begin{pmatrix}
# \frac{\Delta^3}{3} & \frac{\Delta^2}{2} \\
# \frac{\Delta^2}{2} & \Delta
# \end{pmatrix}
# \end{align}
# where $q^c$ is the spectral density (continuous time variance)
# of the continuous-time noise process.
#
#
# If we observe the angular position, we
# get the linear observation model
# \begin{align}
# y_t = \alpha_t + r_t = h(\hmmhid_t) + r_t
# \end{align}
# where $h(\hmmhid_t) = z_{1,t}$
# and $r_t$ is the observation noise.
# If we only observe the horizontal position,
# we get the nonlinear observation model
# \begin{align}
# y_t = \sin(\alpha_t) + r_t = h(\hmmhid_t) + r_t
# \end{align}
# where $h(\hmmhid_t) = \sin(z_{1,t})$.
#
#
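# To make the discretization concrete, here is a small sketch (not part of the original
# notebook) of the discretized pendulum dynamics $\vf$, the process noise covariance $\vQ$,
# and the two observation functions $h$; the values of `g`, `delta` and `qc` are
# illustrative assumptions.

# In[ ]:


g, delta, qc = 9.8, 0.1, 0.1  # gravity, step size, spectral density (assumed values)

def pendulum_dynamics(z):
    # f(z) = (z1 + z2*Delta, z2 - g*sin(z1)*Delta)
    z1, z2 = z
    return jnp.array([z1 + z2 * delta, z2 - g * jnp.sin(z1) * delta])

# discrete-time process noise covariance Q
Q_pendulum = qc * jnp.array([[delta**3 / 3, delta**2 / 2],
                             [delta**2 / 2, delta]])

def obs_angular(z):
    return z[0]            # h(z) = z1 (linear observation model)

def obs_horizontal(z):
    return jnp.sin(z[0])   # h(z) = sin(z1) (nonlinear observation model)
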
# (sec:inference)=
# # Inferential goals
#
# ```{figure} /figures/inference-problems-tikz.png
# :scale: 100%
# :name: fig:dbn-inference
#
# The main kinds of inference for state space models.
# The shaded region is the interval for which we have data.
# The arrow represents the time step at which we want to perform inference.
# $t$ is the current time, $T$ is the sequence length,
# $\ell$ is the lag and $h$ is the prediction horizon.
# ```
#
#
#
# Given the sequence of observations, and a known model,
# one of the main tasks with SSMs
# is to perform posterior inference
# about the hidden states; this is also called
# state estimation.
# At each time step $t$,
# there are multiple forms of posterior we may be interested in computing,
# including the following:
# - the filtering distribution
# $p(\hmmhid_t|\hmmobs_{1:t})$
# - the smoothing distribution
# $p(\hmmhid_t|\hmmobs_{1:T})$ (note that this conditions on future data $T>t$)
# - the fixed-lag smoothing distribution
# $p(\hmmhid_{t-\ell}|\hmmobs_{1:t})$ (note that this
# infers $\ell$ steps in the past given data up to the present).
#
# We may also want to compute the
# predictive distribution $h$ steps into the future:
# \begin{align}
# p(\hmmobs_{t+h}|\hmmobs_{1:t})
# &= \sum_{\hmmhid_{t+h}} p(\hmmobs_{t+h}|\hmmhid_{t+h}) p(\hmmhid_{t+h}|\hmmobs_{1:t})
# \end{align}
# where the hidden state predictive distribution is
# \begin{align}
# p(\hmmhid_{t+h}|\hmmobs_{1:t})
# &= \sum_{\hmmhid_{t:t+h-1}}
# p(\hmmhid_t|\hmmobs_{1:t})
# p(\hmmhid_{t+1}|\hmmhid_{t})
# p(\hmmhid_{t+2}|\hmmhid_{t+1})
# \cdots
# p(\hmmhid_{t+h}|\hmmhid_{t+h-1})
# \end{align}
# See {ref}`fig:dbn-inference` for a summary of these distributions.
#
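# For a discrete-state HMM, the predictive distribution above reduces to a few matrix-vector
# products. The sketch below (not part of the original notebook) assumes a vector
# `filtered_t` holding $p(\hmmhid_t=k|\hmmobs_{1:t})$ and row-stochastic matrices
# `trans_mat` and `obs_mat`, like the casino HMM's $\vA$ and $\vB$.

# In[ ]:


def hmm_predict_sketch(filtered_t, trans_mat, obs_mat, horizon):
    # Push the filtered belief forward `horizon` steps through the transition matrix,
    # then map it through the observation model (illustrative sketch).
    belief = filtered_t
    for _ in range(horizon):
        belief = belief @ trans_mat   # p(z_{t+h} | y_{1:t})
    return belief @ obs_mat           # p(y_{t+h} | y_{1:t})
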
# In addition to computing posterior marginals,
# we may want to compute the most probable hidden sequence,
# i.e., the joint MAP estimate
# ```{math}
# \arg \max_{\hmmhid_{1:T}} p(\hmmhid_{1:T}|\hmmobs_{1:T})
# ```
# or sample sequences from the posterior
# ```{math}
# \hmmhid_{1:T} \sim p(\hmmhid_{1:T}|\hmmobs_{1:T})
# ```
#
# Algorithms for all these tasks are discussed in the following chapters,
# since the details depend on the form of the SSM.
#
#
#
#
#
# ## Example: inference in the casino HMM
#
# We now illustrate filtering, smoothing and MAP decoding applied
# to the casino HMM from {ref}`sec:casino`.
# In[14]:


# Call inference engine
filtered_dist, _, smoothed_dist, loglik = hmm.forward_backward(x_hist)
map_path = hmm.viterbi(x_hist)


# In[15]:


# Find the spans of timesteps during which the simulated system is in state 1 (the loaded die)
def find_dishonest_intervals(z_hist):
    spans = []
    x_init = 0
    for t, _ in enumerate(z_hist[:-1]):
        if z_hist[t + 1] == 0 and z_hist[t] == 1:
            # transition from loaded (1) back to fair (0): close the current span
            x_end = t
            spans.append((x_init, x_end))
        elif z_hist[t + 1] == 1 and z_hist[t] == 0:
            # transition from fair (0) to loaded (1): open a new span
            x_init = t + 1
    return spans
# In[16]:


# Plot posterior
def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False):
    n_samples = len(inference_values)
    xspan = np.arange(1, n_samples + 1)
    spans = find_dishonest_intervals(z_hist)
    if map_estimate:
        ax.step(xspan, inference_values, where="post")
    else:
        ax.plot(xspan, inference_values[:, state])

    # shade the intervals in which the loaded die was being used
    for span in spans:
        ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none")
    ax.set_xlim(1, n_samples)
    # ax.set_ylim(0, 1)
    ax.set_ylim(-0.1, 1.1)
    ax.set_xlabel("Observation number")


# In[17]:


# Filtering
fig, ax = plt.subplots()
plot_inference(filtered_dist, z_hist, ax)
ax.set_ylabel("p(loaded)")
ax.set_title("Filtered")


# In[12]:


# Smoothing
fig, ax = plt.subplots()
plot_inference(smoothed_dist, z_hist, ax)
ax.set_ylabel("p(loaded)")
ax.set_title("Smoothed")


# In[ ]:


# MAP estimation
fig, ax = plt.subplots()
plot_inference(map_path, z_hist, ax, map_estimate=True)
ax.set_ylabel("MAP state")
ax.set_title("Viterbi")


# In[ ]:


# TODO: posterior samples


# ## Example: inference in the tracking SSM
#
# We now illustrate filtering, smoothing and MAP decoding applied
# to the 2d tracking LG-SSM from {ref}`sec:tracking-lds`.