#!/usr/bin/env python
# coding: utf-8

# In[1]:

# meta-data does not work yet in VScode
# https://github.com/microsoft/vscode-jupyter/issues/1121
{"tags": ["hide-cell"]}

### Install necessary libraries

try:
    import jax
except ModuleNotFoundError:
    # For cuda version, see https://github.com/google/jax#installation
    get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
    import jax

try:
    import distrax
except ModuleNotFoundError:
    get_ipython().run_line_magic('pip', 'install --upgrade distrax')
    import distrax

try:
    import jsl
except ModuleNotFoundError:
    get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
    import jsl

try:
    import rich
except ModuleNotFoundError:
    get_ipython().run_line_magic('pip', 'install rich')
    import rich
# In[2]:

{"tags": ["hide-cell"]}
### Import standard libraries

import abc
from dataclasses import dataclass
import functools
from functools import partial
import itertools
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.scipy.special import logit
from jax.nn import softmax
from jax.random import PRNGKey, split

import inspect
import inspect as py_inspect
import rich
from rich import inspect as r_inspect
from rich import print as r_print


def print_source(fname):
    r_print(py_inspect.getsource(fname))
# ```{math}
#
# \newcommand\floor[1]{\lfloor#1\rfloor}
#
# \newcommand{\real}{\mathbb{R}}
#
# % Numbers
# \newcommand{\vzero}{\boldsymbol{0}}
# \newcommand{\vone}{\boldsymbol{1}}
#
# % Greek https://www.latex-tutorial.com/symbols/greek-alphabet/
# \newcommand{\valpha}{\boldsymbol{\alpha}}
# \newcommand{\vbeta}{\boldsymbol{\beta}}
# \newcommand{\vchi}{\boldsymbol{\chi}}
# \newcommand{\vdelta}{\boldsymbol{\delta}}
# \newcommand{\vDelta}{\boldsymbol{\Delta}}
# \newcommand{\vepsilon}{\boldsymbol{\epsilon}}
# \newcommand{\vzeta}{\boldsymbol{\zeta}}
# \newcommand{\vXi}{\boldsymbol{\Xi}}
# \newcommand{\vell}{\boldsymbol{\ell}}
# \newcommand{\veta}{\boldsymbol{\eta}}
# %\newcommand{\vEta}{\boldsymbol{\Eta}}
# \newcommand{\vgamma}{\boldsymbol{\gamma}}
# \newcommand{\vGamma}{\boldsymbol{\Gamma}}
# \newcommand{\vmu}{\boldsymbol{\mu}}
# \newcommand{\vmut}{\boldsymbol{\tilde{\mu}}}
# \newcommand{\vnu}{\boldsymbol{\nu}}
# \newcommand{\vkappa}{\boldsymbol{\kappa}}
# \newcommand{\vlambda}{\boldsymbol{\lambda}}
# \newcommand{\vLambda}{\boldsymbol{\Lambda}}
# \newcommand{\vLambdaBar}{\overline{\vLambda}}
# %\newcommand{\vnu}{\boldsymbol{\nu}}
# \newcommand{\vomega}{\boldsymbol{\omega}}
# \newcommand{\vOmega}{\boldsymbol{\Omega}}
# \newcommand{\vphi}{\boldsymbol{\phi}}
# \newcommand{\vvarphi}{\boldsymbol{\varphi}}
# \newcommand{\vPhi}{\boldsymbol{\Phi}}
# \newcommand{\vpi}{\boldsymbol{\pi}}
# \newcommand{\vPi}{\boldsymbol{\Pi}}
# \newcommand{\vpsi}{\boldsymbol{\psi}}
# \newcommand{\vPsi}{\boldsymbol{\Psi}}
# \newcommand{\vrho}{\boldsymbol{\rho}}
# \newcommand{\vtheta}{\boldsymbol{\theta}}
# \newcommand{\vthetat}{\boldsymbol{\tilde{\theta}}}
# \newcommand{\vTheta}{\boldsymbol{\Theta}}
# \newcommand{\vsigma}{\boldsymbol{\sigma}}
# \newcommand{\vSigma}{\boldsymbol{\Sigma}}
# \newcommand{\vSigmat}{\boldsymbol{\tilde{\Sigma}}}
# \newcommand{\vsigmoid}{\vsigma}
# \newcommand{\vtau}{\boldsymbol{\tau}}
# \newcommand{\vxi}{\boldsymbol{\xi}}
#
#
# % Lower Roman (Vectors)
# \newcommand{\va}{\mathbf{a}}
# \newcommand{\vb}{\mathbf{b}}
# \newcommand{\vBt}{\mathbf{\tilde{B}}}
# \newcommand{\vc}{\mathbf{c}}
# \newcommand{\vct}{\mathbf{\tilde{c}}}
# \newcommand{\vd}{\mathbf{d}}
# \newcommand{\ve}{\mathbf{e}}
# \newcommand{\vf}{\mathbf{f}}
# \newcommand{\vg}{\mathbf{g}}
# \newcommand{\vh}{\mathbf{h}}
# %\newcommand{\myvh}{\mathbf{h}}
# \newcommand{\vi}{\mathbf{i}}
# \newcommand{\vj}{\mathbf{j}}
# \newcommand{\vk}{\mathbf{k}}
# \newcommand{\vl}{\mathbf{l}}
# \newcommand{\vm}{\mathbf{m}}
# \newcommand{\vn}{\mathbf{n}}
# \newcommand{\vo}{\mathbf{o}}
# \newcommand{\vp}{\mathbf{p}}
# \newcommand{\vq}{\mathbf{q}}
# \newcommand{\vr}{\mathbf{r}}
# \newcommand{\vs}{\mathbf{s}}
# \newcommand{\vt}{\mathbf{t}}
# \newcommand{\vu}{\mathbf{u}}
# \newcommand{\vv}{\mathbf{v}}
# \newcommand{\vw}{\mathbf{w}}
# \newcommand{\vws}{\vw_s}
# \newcommand{\vwt}{\mathbf{\tilde{w}}}
# \newcommand{\vWt}{\mathbf{\tilde{W}}}
# \newcommand{\vwh}{\hat{\vw}}
# \newcommand{\vx}{\mathbf{x}}
# %\newcommand{\vx}{\mathbf{x}}
# \newcommand{\vxt}{\mathbf{\tilde{x}}}
# \newcommand{\vy}{\mathbf{y}}
# \newcommand{\vyt}{\mathbf{\tilde{y}}}
# \newcommand{\vz}{\mathbf{z}}
# %\newcommand{\vzt}{\mathbf{\tilde{z}}}
#
#
# % Upper Roman (Matrices)
# \newcommand{\vA}{\mathbf{A}}
# \newcommand{\vB}{\mathbf{B}}
# \newcommand{\vC}{\mathbf{C}}
# \newcommand{\vD}{\mathbf{D}}
# \newcommand{\vE}{\mathbf{E}}
# \newcommand{\vF}{\mathbf{F}}
# \newcommand{\vG}{\mathbf{G}}
# \newcommand{\vH}{\mathbf{H}}
# \newcommand{\vI}{\mathbf{I}}
# \newcommand{\vJ}{\mathbf{J}}
# \newcommand{\vK}{\mathbf{K}}
# \newcommand{\vL}{\mathbf{L}}
# \newcommand{\vM}{\mathbf{M}}
# \newcommand{\vMt}{\mathbf{\tilde{M}}}
# \newcommand{\vN}{\mathbf{N}}
# \newcommand{\vO}{\mathbf{O}}
# \newcommand{\vP}{\mathbf{P}}
# \newcommand{\vQ}{\mathbf{Q}}
# \newcommand{\vR}{\mathbf{R}}
# \newcommand{\vS}{\mathbf{S}}
# \newcommand{\vT}{\mathbf{T}}
# \newcommand{\vU}{\mathbf{U}}
# \newcommand{\vV}{\mathbf{V}}
# \newcommand{\vW}{\mathbf{W}}
# \newcommand{\vX}{\mathbf{X}}
# %\newcommand{\vXs}{\vX_{\vs}}
# \newcommand{\vXs}{\vX_{s}}
# \newcommand{\vXt}{\mathbf{\tilde{X}}}
# \newcommand{\vY}{\mathbf{Y}}
# \newcommand{\vZ}{\mathbf{Z}}
# \newcommand{\vZt}{\mathbf{\tilde{Z}}}
# \newcommand{\vzt}{\mathbf{\tilde{z}}}
#
#
# %%%%
# \newcommand{\hidden}{\vz}
# \newcommand{\hid}{\hidden}
# \newcommand{\observed}{\vy}
# \newcommand{\obs}{\observed}
# \newcommand{\inputs}{\vu}
# \newcommand{\input}{\inputs}
#
# \newcommand{\hmmTrans}{\vA}
# \newcommand{\hmmObs}{\vB}
# \newcommand{\hmmInit}{\vpi}
# \newcommand{\hmmhid}{\hidden}
# \newcommand{\hmmobs}{\obs}
#
# \newcommand{\ldsDyn}{\vA}
# \newcommand{\ldsObs}{\vC}
# \newcommand{\ldsDynIn}{\vB}
# \newcommand{\ldsObsIn}{\vD}
# \newcommand{\ldsDynNoise}{\vQ}
# \newcommand{\ldsObsNoise}{\vR}
#
# \newcommand{\ssmDyn}{f}
# \newcommand{\ssmObs}{h}
# ```
#
# (sec:ssm-intro)=
# # What are State Space Models?
#
#
# A state space model or SSM
# is a partially observed Markov model,
# in which the hidden state, $\hidden_t$,
# evolves over time according to a Markov process,
# possibly conditional on external inputs or controls $\input_t$,
# and each hidden state generates some
# observations $\obs_t$ at each time step.
# (In this book, we mostly focus on discrete-time systems,
# although we consider the continuous-time case in XXX.)
# We get to see the observations, but not the hidden state.
# Our main goal is to infer the hidden state given the observations.
# However, we can also use the model to predict future observations,
# by first predicting future hidden states, and then predicting
# what observations they might generate.
# By using a hidden state $\hidden_t$
# to represent the past observations, $\obs_{1:t-1}$,
# the model can have "infinite" memory,
# unlike a standard Markov model.
#
# Formally we can define an SSM
# as the following joint distribution:
# ```{math}
# :label: SSMfull
# p(\hmmobs_{1:T},\hmmhid_{1:T}|\inputs_{1:T})
# = \left[ p(\hmmhid_1|\inputs_1) \prod_{t=2}^{T}
# p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) \right]
# \left[ \prod_{t=1}^T p(\hmmobs_t|\hmmhid_t, \inputs_t, \hmmobs_{t-1}) \right]
# ```
# where $p(\hmmhid_t|\hmmhid_{t-1},\inputs_t)$ is the
# transition model,
# $p(\hmmobs_t|\hmmhid_t, \inputs_t, \hmmobs_{t-1})$ is the
# observation model,
# and $\inputs_{t}$ is an optional input or action.
# See {numref}`Figure %s <ssm-ar>`
# for an illustration of the corresponding graphical model.
#
#
# ```{figure} /figures/SSM-AR-inputs.png
# :scale: 100%
# :name: ssm-ar
#
# Illustration of an SSM as a graphical model.
# ```
#
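# To make this factorization concrete, the following cell is a minimal sketch
# (our addition, not the book's API) of how {eq}`SSMfull` turns into code;
# `init_logprob`, `trans_logprob` and `obs_logprob` are hypothetical callables
# for the three model components.

# In[ ]:

def ssm_log_joint(zs, ys, us, init_logprob, trans_logprob, obs_logprob):
    """Evaluate log p(y_{1:T}, z_{1:T} | u_{1:T}) for a generic SSM."""
    T = len(ys)
    lp = init_logprob(zs[0], us[0])                     # log p(z_1 | u_1)
    for t in range(1, T):
        lp += trans_logprob(zs[t], zs[t - 1], us[t])    # log p(z_t | z_{t-1}, u_t)
    for t in range(T):
        y_prev = ys[t - 1] if t > 0 else None           # no previous obs at t=1
        lp += obs_logprob(ys[t], zs[t], us[t], y_prev)  # log p(y_t | z_t, u_t, y_{t-1})
    return lp
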
# We often consider a simpler setting in which there
# are no external inputs,
# and the observations are conditionally independent of each other
# (rather than having Markovian dependencies) given the hidden state.
# In this case the joint simplifies to
# ```{math}
# :label: SSMsimplified
# p(\hmmobs_{1:T},\hmmhid_{1:T})
# = \left[ p(\hmmhid_1) \prod_{t=2}^{T}
# p(\hmmhid_t|\hmmhid_{t-1}) \right]
# \left[ \prod_{t=1}^T p(\hmmobs_t|\hmmhid_t) \right]
# ```
# See {numref}`Figure %s <ssm-simplified>`
# for an illustration of the corresponding graphical model.
# Compare {eq}`SSMfull` and {eq}`SSMsimplified`.
#
#
# ```{figure} /figures/SSM-simplified.png
# :scale: 100%
# :name: ssm-simplified
#
# Illustration of a simplified SSM.
# ```
#
#
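# The factorization in {eq}`SSMsimplified` suggests a simple ancestral sampling
# scheme: sample $\hmmhid_1$, then alternate between sampling the next state and
# its observation. Below is a minimal sketch (our addition); `sample_init`,
# `sample_trans` and `sample_obs` are hypothetical per-step samplers.

# In[ ]:

def ssm_sample(key, T, sample_init, sample_trans, sample_obs):
    """Ancestral sampling: z_1 ~ p(z_1), z_t ~ p(z_t|z_{t-1}), y_t ~ p(y_t|z_t)."""
    zs, ys = [], []
    z = None
    for t in range(T):
        key, subkey = split(key)
        z = sample_init(subkey) if t == 0 else sample_trans(subkey, z)
        key, subkey = split(key)
        ys.append(sample_obs(subkey, z))
        zs.append(z)
    return jnp.stack(zs), jnp.stack(ys)
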
# (sec:hmm-intro)=
# # Hidden Markov Models
#
# In this section, we discuss the
# hidden Markov model or HMM,
# which is a state space model in which the hidden states
# are discrete, so $\hmmhid_t \in \{1,\ldots, K\}$.
# The observations may be discrete,
# $\hmmobs_t \in \{1,\ldots, C\}$,
# or continuous,
# $\hmmobs_t \in \real^D$,
# or some combination,
# as we illustrate below.
# More details can be found in, e.g.,
# {cite}`Rabiner89,Fraser08,Cappe05`.
# For an interactive introduction,
# see https://nipunbatra.github.io/hmm/.
#
# (sec:casino)=
# ### Example: Casino HMM
#
# To illustrate HMMs with a categorical observation model,
# we consider the "occasionally dishonest casino" model from {cite}`Durbin98`.
# There are 2 hidden states, representing whether the die being used in the casino is fair or loaded.
# Each state defines a distribution over the 6 possible observations.
#
# The transition model is denoted by
# ```{math}
# p(z_t=j|z_{t-1}=i) = \hmmTrans_{ij}
# ```
# Here the $i$'th row of $\vA$ corresponds to the outgoing distribution from state $i$.
# This is a row stochastic matrix,
# meaning each row sums to one.
# We can visualize
# the non-zero entries in the transition matrix by creating a state transition diagram,
# as shown in
# {numref}`Figure %s <casino-fig>`.
#
# ```{figure} /figures/casino.png
# :scale: 50%
# :name: casino-fig
#
# Illustration of the casino HMM.
# ```
#
# The observation model
# $p(\obs_t|\hidden_t=j)$ has the form
# ```{math}
# p(\obs_t=k|\hidden_t=j) = \hmmObs_{jk}
# ```
# This is represented by the histograms associated with each
# state in {numref}`Figure %s <casino-fig>`.
#
# Finally,
# the initial state distribution is denoted by
# ```{math}
# p(z_1=j) = \hmmInit_j
# ```
#
# Collectively we denote all the parameters by $\vtheta=(\hmmTrans, \hmmObs, \hmmInit)$.
#
# Now let us implement this model in code.
# In[3]:

# state transition matrix
A = np.array([
    [0.95, 0.05],
    [0.10, 0.90]
])

# observation matrix
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6],       # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]  # loaded die
])

pi = np.array([0.5, 0.5])

(nstates, nobs) = np.shape(B)
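# A quick sanity check (our addition): both $\hmmTrans$ and $\hmmObs$ should be
# row stochastic, i.e. each row is a probability distribution summing to one.

# In[ ]:

assert np.allclose(A.sum(axis=1), 1.0)  # each row of A sums to 1
assert np.allclose(B.sum(axis=1), 1.0)  # each row of B sums to 1
print(f"{nstates} hidden states, {nobs} observation symbols")
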
# In[4]:

import distrax
from distrax import HMM

hmm = HMM(trans_dist=distrax.Categorical(probs=A),
          init_dist=distrax.Categorical(probs=pi),
          obs_dist=distrax.Categorical(probs=B))

print(hmm)
# Let's sample from the model. We will generate a sequence of latent states, $\hid_{1:T}$,
# which we then convert to a sequence of observations, $\obs_{1:T}$.

# In[5]:

seed = 314
n_samples = 300
z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)

z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60]
x_hist_str = "".join((np.array(x_hist) + 1).astype(str))[:60]

print("Printing sample observed/latent...")
print(f"x: {x_hist_str}")
print(f"z: {z_hist_str}")
# In[6]:

# Here is the source code for the sampling algorithm.
print_source(hmm.sample)


# Our primary goal will be to infer the latent state from the observations,
# so we can detect if the casino is being dishonest or not. This will
# affect how we choose to gamble our money.
# We discuss various ways to perform this inference below.
# # Linear Gaussian SSMs
#
# Blah blah
#
# (sec:tracking-lds)=
# ## Example: model for 2d tracking
#
# Blah blah
#
# (sec:inference)=
# # Inferential goals
#
# ```{figure} /figures/dbn-inference-problems.png
# :scale: 100%
# :name: dbn-inference
#
# Illustration of the main kinds of inference in a state-space model.
# The shaded region is the interval for which we have data.
# The arrow represents the time step at which we want to perform inference.
# $t$ is the current time, $T$ is the sequence length,
# $\ell$ is the lag and $h$ is the prediction horizon.
# ```
#
#
# Given the sequence of observations, and a known model,
# one of the main tasks with SSMs
# is to perform posterior inference
# about the hidden states; this is also called
# state estimation.
# At each time step $t$,
# there are multiple forms of posterior we may be interested in computing,
# including the following:
# - the filtering distribution
# $p(\hmmhid_t|\hmmobs_{1:t})$ (see the sketch below)
# - the smoothing distribution
# $p(\hmmhid_t|\hmmobs_{1:T})$ (note that this conditions on future data $T>t$)
# - the fixed-lag smoothing distribution
# $p(\hmmhid_{t-\ell}|\hmmobs_{1:t})$ (note that this
# infers $\ell$ steps in the past given data up to the present)
#
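# As a concrete example, for an HMM the filtering distribution can be computed
# with the classic forward recursion (predict, then update). The cell below is a
# minimal numpy sketch (our addition); the distrax call `hmm.forward_backward`
# used later computes the same quantity in a numerically robust way.

# In[ ]:

def hmm_filter(X, A, B, pi):
    """Forward algorithm: returns alpha[t, j] = p(z_t = j | x_{1:t})."""
    T = len(X)
    alpha = np.zeros((T, len(pi)))
    predicted = pi                        # p(z_1)
    for t in range(T):
        unnorm = predicted * B[:, X[t]]   # p(z_t | x_{1:t-1}) p(x_t | z_t)
        alpha[t] = unnorm / unnorm.sum()  # normalize to get p(z_t | x_{1:t})
        predicted = alpha[t] @ A          # one-step-ahead prediction
    return alpha
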
# We may also want to compute the
# predictive distribution $h$ steps into the future:
# \begin{align}
# p(\hmmobs_{t+h}|\hmmobs_{1:t})
# &= \sum_{\hmmhid_{t+h}} p(\hmmobs_{t+h}|\hmmhid_{t+h}) p(\hmmhid_{t+h}|\hmmobs_{1:t})
# \end{align}
# where the hidden state predictive distribution is
# \begin{align}
# p(\hmmhid_{t+h}|\hmmobs_{1:t})
# &= \sum_{\hmmhid_{t:t+h-1}}
# p(\hmmhid_t|\hmmobs_{1:t})
# p(\hmmhid_{t+1}|\hmmhid_{t})
# p(\hmmhid_{t+2}|\hmmhid_{t+1})
# \cdots
# p(\hmmhid_{t+h}|\hmmhid_{t+h-1})
# \end{align}
# See {numref}`Figure %s <dbn-inference>` for a summary of these distributions.
#
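# For an HMM, the sums above reduce to matrix-vector products: push the
# filtering distribution through the transition matrix $h$ times, then through
# the observation matrix. A minimal sketch (our addition); `filtered_t` is a
# hypothetical filtering vector $p(\hmmhid_t|\hmmobs_{1:t})$.

# In[ ]:

def hmm_predict(filtered_t, A, B, h):
    hidden_pred = filtered_t
    for _ in range(h):
        hidden_pred = hidden_pred @ A  # p(z_{t+h} | x_{1:t})
    obs_pred = hidden_pred @ B         # p(x_{t+h} | x_{1:t})
    return hidden_pred, obs_pred
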
# In addition to computing posterior marginals,
# we may want to compute the most probable hidden sequence,
# i.e., the joint MAP estimate
# ```{math}
# \arg \max_{\hmmhid_{1:T}} p(\hmmhid_{1:T}|\hmmobs_{1:T})
# ```
# or sample sequences from the posterior
# ```{math}
# \hmmhid_{1:T} \sim p(\hmmhid_{1:T}|\hmmobs_{1:T})
# ```
#
# Algorithms for all these tasks are discussed in the following chapters,
# since the details depend on the form of the SSM.
#
# ## Example: inference in the casino HMM
#
# We now illustrate filtering, smoothing and MAP decoding applied
# to the casino HMM from {ref}`sec:casino`.
# In[7]:

# Call inference engine
filtered_dist, _, smoothed_dist, loglik = hmm.forward_backward(x_hist)
map_path = hmm.viterbi(x_hist)
# In[8]:

# Find the spans of timesteps during which the simulated system is in state 1 (loaded)
def find_dishonest_intervals(z_hist):
    spans = []
    x_init = 0
    for t, _ in enumerate(z_hist[:-1]):
        if z_hist[t + 1] == 0 and z_hist[t] == 1:
            x_end = t
            spans.append((x_init, x_end))
        elif z_hist[t + 1] == 1 and z_hist[t] == 0:
            x_init = t + 1
    return spans
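# Example usage (our addition): recover the loaded-die intervals from the state
# sequence sampled earlier.

# In[ ]:

spans = find_dishonest_intervals(np.array(z_hist))
print(f"first few dishonest intervals: {spans[:3]}")
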
# In[9]:

# Plot posterior
def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False):
    n_samples = len(inference_values)
    xspan = np.arange(1, n_samples + 1)
    spans = find_dishonest_intervals(z_hist)
    if map_estimate:
        ax.step(xspan, inference_values, where="post")
    else:
        ax.plot(xspan, inference_values[:, state])
    for span in spans:
        ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none")
    ax.set_xlim(1, n_samples)
    # ax.set_ylim(0, 1)
    ax.set_ylim(-0.1, 1.1)
    ax.set_xlabel("Observation number")
# In[10]:

# Filtering
fig, ax = plt.subplots()
plot_inference(filtered_dist, z_hist, ax)
ax.set_ylabel("p(loaded)")
ax.set_title("Filtered")


# In[11]:

# Smoothing
fig, ax = plt.subplots()
plot_inference(smoothed_dist, z_hist, ax)
ax.set_ylabel("p(loaded)")
ax.set_title("Smoothed")


# In[12]:

# MAP estimation
fig, ax = plt.subplots()
plot_inference(map_path, z_hist, ax, map_estimate=True)
ax.set_ylabel("MAP state")
ax.set_title("Viterbi")
# In[13]:

# TODO: posterior samples


# ## Example: inference in the tracking SSM
#
# We now illustrate filtering, smoothing and MAP decoding applied
# to the 2d tracking LDS from {ref}`sec:tracking-lds`.