
#!/usr/bin/env python
# coding: utf-8

# In[1]:

# meta-data does not work yet in VScode
# https://github.com/microsoft/vscode-jupyter/issues/1121
{
    "tags": [
        "hide-cell"
    ]
}

### Install necessary libraries

try:
    import jax
except ModuleNotFoundError:
    # For cuda version, see https://github.com/google/jax#installation
    get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
    import jax

try:
    import distrax
except ModuleNotFoundError:
    get_ipython().run_line_magic('pip', 'install --upgrade distrax')
    import distrax

try:
    import jsl
except ModuleNotFoundError:
    get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
    import jsl

try:
    import rich
except ModuleNotFoundError:
    get_ipython().run_line_magic('pip', 'install rich')
    import rich
# In[2]:

{
    "tags": [
        "hide-cell"
    ]
}

### Import standard libraries

import abc
from dataclasses import dataclass
import functools
import itertools
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.scipy.special import logit
from jax.nn import softmax
from functools import partial
from jax.random import PRNGKey, split

import inspect as py_inspect
import rich
from rich import inspect as r_inspect
from rich import print as r_print

def print_source(fname):
    # pretty-print the source code of a function or method
    r_print(py_inspect.getsource(fname))
# ```{math}
#
# \newcommand\floor[1]{\lfloor#1\rfloor}
#
# \newcommand{\real}{\mathbb{R}}
#
# % Numbers
# \newcommand{\vzero}{\boldsymbol{0}}
# \newcommand{\vone}{\boldsymbol{1}}
#
# % Greek https://www.latex-tutorial.com/symbols/greek-alphabet/
# \newcommand{\valpha}{\boldsymbol{\alpha}}
# \newcommand{\vbeta}{\boldsymbol{\beta}}
# \newcommand{\vchi}{\boldsymbol{\chi}}
# \newcommand{\vdelta}{\boldsymbol{\delta}}
# \newcommand{\vDelta}{\boldsymbol{\Delta}}
# \newcommand{\vepsilon}{\boldsymbol{\epsilon}}
# \newcommand{\vzeta}{\boldsymbol{\zeta}}
# \newcommand{\vXi}{\boldsymbol{\Xi}}
# \newcommand{\vell}{\boldsymbol{\ell}}
# \newcommand{\veta}{\boldsymbol{\eta}}
# %\newcommand{\vEta}{\boldsymbol{\Eta}}
# \newcommand{\vgamma}{\boldsymbol{\gamma}}
# \newcommand{\vGamma}{\boldsymbol{\Gamma}}
# \newcommand{\vmu}{\boldsymbol{\mu}}
# \newcommand{\vmut}{\boldsymbol{\tilde{\mu}}}
# \newcommand{\vnu}{\boldsymbol{\nu}}
# \newcommand{\vkappa}{\boldsymbol{\kappa}}
# \newcommand{\vlambda}{\boldsymbol{\lambda}}
# \newcommand{\vLambda}{\boldsymbol{\Lambda}}
# \newcommand{\vLambdaBar}{\overline{\vLambda}}
# %\newcommand{\vnu}{\boldsymbol{\nu}}
# \newcommand{\vomega}{\boldsymbol{\omega}}
# \newcommand{\vOmega}{\boldsymbol{\Omega}}
# \newcommand{\vphi}{\boldsymbol{\phi}}
# \newcommand{\vvarphi}{\boldsymbol{\varphi}}
# \newcommand{\vPhi}{\boldsymbol{\Phi}}
# \newcommand{\vpi}{\boldsymbol{\pi}}
# \newcommand{\vPi}{\boldsymbol{\Pi}}
# \newcommand{\vpsi}{\boldsymbol{\psi}}
# \newcommand{\vPsi}{\boldsymbol{\Psi}}
# \newcommand{\vrho}{\boldsymbol{\rho}}
# \newcommand{\vtheta}{\boldsymbol{\theta}}
# \newcommand{\vthetat}{\boldsymbol{\tilde{\theta}}}
# \newcommand{\vTheta}{\boldsymbol{\Theta}}
# \newcommand{\vsigma}{\boldsymbol{\sigma}}
# \newcommand{\vSigma}{\boldsymbol{\Sigma}}
# \newcommand{\vSigmat}{\boldsymbol{\tilde{\Sigma}}}
# \newcommand{\vsigmoid}{\vsigma}
# \newcommand{\vtau}{\boldsymbol{\tau}}
# \newcommand{\vxi}{\boldsymbol{\xi}}
#
#
# % Lower Roman (Vectors)
# \newcommand{\va}{\mathbf{a}}
# \newcommand{\vb}{\mathbf{b}}
# \newcommand{\vBt}{\mathbf{\tilde{B}}}
# \newcommand{\vc}{\mathbf{c}}
# \newcommand{\vct}{\mathbf{\tilde{c}}}
# \newcommand{\vd}{\mathbf{d}}
# \newcommand{\ve}{\mathbf{e}}
# \newcommand{\vf}{\mathbf{f}}
# \newcommand{\vg}{\mathbf{g}}
# \newcommand{\vh}{\mathbf{h}}
# %\newcommand{\myvh}{\mathbf{h}}
# \newcommand{\vi}{\mathbf{i}}
# \newcommand{\vj}{\mathbf{j}}
# \newcommand{\vk}{\mathbf{k}}
# \newcommand{\vl}{\mathbf{l}}
# \newcommand{\vm}{\mathbf{m}}
# \newcommand{\vn}{\mathbf{n}}
# \newcommand{\vo}{\mathbf{o}}
# \newcommand{\vp}{\mathbf{p}}
# \newcommand{\vq}{\mathbf{q}}
# \newcommand{\vr}{\mathbf{r}}
# \newcommand{\vs}{\mathbf{s}}
# \newcommand{\vt}{\mathbf{t}}
# \newcommand{\vu}{\mathbf{u}}
# \newcommand{\vv}{\mathbf{v}}
# \newcommand{\vw}{\mathbf{w}}
# \newcommand{\vws}{\vw_s}
# \newcommand{\vwt}{\mathbf{\tilde{w}}}
# \newcommand{\vWt}{\mathbf{\tilde{W}}}
# \newcommand{\vwh}{\hat{\vw}}
# \newcommand{\vx}{\mathbf{x}}
# %\newcommand{\vx}{\mathbf{x}}
# \newcommand{\vxt}{\mathbf{\tilde{x}}}
# \newcommand{\vy}{\mathbf{y}}
# \newcommand{\vyt}{\mathbf{\tilde{y}}}
# \newcommand{\vz}{\mathbf{z}}
# %\newcommand{\vzt}{\mathbf{\tilde{z}}}
#
#
# % Upper Roman (Matrices)
# \newcommand{\vA}{\mathbf{A}}
# \newcommand{\vB}{\mathbf{B}}
# \newcommand{\vC}{\mathbf{C}}
# \newcommand{\vD}{\mathbf{D}}
# \newcommand{\vE}{\mathbf{E}}
# \newcommand{\vF}{\mathbf{F}}
# \newcommand{\vG}{\mathbf{G}}
# \newcommand{\vH}{\mathbf{H}}
# \newcommand{\vI}{\mathbf{I}}
# \newcommand{\vJ}{\mathbf{J}}
# \newcommand{\vK}{\mathbf{K}}
# \newcommand{\vL}{\mathbf{L}}
# \newcommand{\vM}{\mathbf{M}}
# \newcommand{\vMt}{\mathbf{\tilde{M}}}
# \newcommand{\vN}{\mathbf{N}}
# \newcommand{\vO}{\mathbf{O}}
# \newcommand{\vP}{\mathbf{P}}
# \newcommand{\vQ}{\mathbf{Q}}
# \newcommand{\vR}{\mathbf{R}}
# \newcommand{\vS}{\mathbf{S}}
# \newcommand{\vT}{\mathbf{T}}
# \newcommand{\vU}{\mathbf{U}}
# \newcommand{\vV}{\mathbf{V}}
# \newcommand{\vW}{\mathbf{W}}
# \newcommand{\vX}{\mathbf{X}}
# %\newcommand{\vXs}{\vX_{\vs}}
# \newcommand{\vXs}{\vX_{s}}
# \newcommand{\vXt}{\mathbf{\tilde{X}}}
# \newcommand{\vY}{\mathbf{Y}}
# \newcommand{\vZ}{\mathbf{Z}}
# \newcommand{\vZt}{\mathbf{\tilde{Z}}}
# \newcommand{\vzt}{\mathbf{\tilde{z}}}
#
#
# %%%%
# \newcommand{\hidden}{\vz}
# \newcommand{\hid}{\hidden}
# \newcommand{\observed}{\vy}
# \newcommand{\obs}{\observed}
# \newcommand{\inputs}{\vu}
# \newcommand{\input}{\inputs}
#
# \newcommand{\hmmTrans}{\vA}
# \newcommand{\hmmObs}{\vB}
# \newcommand{\hmmInit}{\vpi}
# \newcommand{\hmmhid}{\hidden}
# \newcommand{\hmmobs}{\obs}
#
# \newcommand{\ldsDyn}{\vA}
# \newcommand{\ldsObs}{\vC}
# \newcommand{\ldsDynIn}{\vB}
# \newcommand{\ldsObsIn}{\vD}
# \newcommand{\ldsDynNoise}{\vQ}
# \newcommand{\ldsObsNoise}{\vR}
#
# \newcommand{\ssmDynFn}{f}
# \newcommand{\ssmObsFn}{h}
#
#
# %%%
# \newcommand{\gauss}{\mathcal{N}}
#
# \newcommand{\diag}{\mathrm{diag}}
# ```
#
# (sec:hmm-intro)=
# # Hidden Markov Models
#
# In this section, we discuss the
# hidden Markov model or HMM,
# which is a state space model in which the hidden states
# are discrete, so $\hmmhid_t \in \{1,\ldots, K\}$.
# The observations may be discrete,
# $\hmmobs_t \in \{1,\ldots, C\}$,
# or continuous,
# $\hmmobs_t \in \real^D$,
# or some combination,
# as we illustrate below.
# More details can be found in e.g.,
# {cite}`Rabiner89,Fraser08,Cappe05`.
# For an interactive introduction,
# see https://nipunbatra.github.io/hmm/.

# (sec:casino)=
# ## Example: Casino HMM
#
# To illustrate HMMs with a categorical observation model,
# we consider the "Occasionally dishonest casino" model from {cite}`Durbin98`.
# There are 2 hidden states, representing whether the die being used in the casino is fair or loaded.
# Each state defines a distribution over the 6 possible observations.
#
# The transition model is denoted by
# ```{math}
# p(z_t=j|z_{t-1}=i) = \hmmTrans_{ij}
# ```
# Here the $i$'th row of $\vA$ corresponds to the outgoing distribution from state $i$.
# This is a row stochastic matrix,
# meaning each row sums to one.
# We can visualize
# the non-zero entries in the transition matrix by creating a state transition diagram,
# as shown in
# {numref}`fig:casino`.
#
# ```{figure} /figures/casino.png
# :scale: 50%
# :name: fig:casino
#
# Illustration of the casino HMM.
# ```
#
# The observation model
# $p(\obs_t|\hidden_t=j)$ has the form
# ```{math}
# p(\obs_t=k|\hidden_t=j) = \hmmObs_{jk}
# ```
# This is represented by the histograms associated with each
# state in {numref}`fig:casino`.
#
# Finally,
# the initial state distribution is denoted by
# ```{math}
# p(z_1=j) = \hmmInit_j
# ```
#
# Collectively we denote all the parameters by $\vtheta=(\hmmTrans, \hmmObs, \hmmInit)$.
#
# Now let us implement this model in code.
# In[3]:

# state transition matrix
A = np.array([
    [0.95, 0.05],
    [0.10, 0.90]
])

# observation matrix
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6],       # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]  # loaded die
])

# initial state distribution
pi = np.array([0.5, 0.5])

(nstates, nobs) = np.shape(B)
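
# As a quick sanity check (our own addition, not in the original notebook),
# verify that A and B are row stochastic and that pi sums to one.
assert np.allclose(A.sum(axis=1), 1.0)  # each row of A is a distribution over next states
assert np.allclose(B.sum(axis=1), 1.0)  # each row of B is a distribution over symbols
assert np.allclose(pi.sum(), 1.0)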
# In[4]:

import distrax
from distrax import HMM

hmm = HMM(trans_dist=distrax.Categorical(probs=A),
          init_dist=distrax.Categorical(probs=pi),
          obs_dist=distrax.Categorical(probs=B))
print(hmm)
#
# Let's sample from the model. We will generate a sequence of latent states, $\hid_{1:T}$,
# which we then convert to a sequence of observations, $\obs_{1:T}$.

# In[5]:

seed = 314
n_samples = 300

z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)

# show the first 60 samples, using 1-indexed states and symbols
z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60]
x_hist_str = "".join((np.array(x_hist) + 1).astype(str))[:60]

print("Printing sample observed/latent...")
print(f"x: {x_hist_str}")
print(f"z: {z_hist_str}")
# Below is the source code for the sampling algorithm.

# In[6]:

print_source(hmm.sample)
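
# For intuition, here is a minimal sketch of ancestral sampling for an HMM
# using `jax.lax.scan`. This is our own illustrative code, not the distrax
# implementation printed above (the function name and signature are our
# assumptions): at each step we emit an observation from the current state
# and then transition to the next state.

def hmm_sample_sketch(rng_key, init_probs, trans_probs, obs_probs, seq_len):
    def draw(key, probs):
        # sample one index from a categorical distribution
        return jax.random.choice(key, probs.shape[-1], p=probs)

    def step(z, key):
        key_x, key_z = split(key)
        x = draw(key_x, obs_probs[z])         # emit from the current state
        z_next = draw(key_z, trans_probs[z])  # transition to the next state
        return z_next, (z, x)

    key_init, key_scan = split(rng_key)
    z0 = draw(key_init, init_probs)           # sample the initial state
    _, (zs, xs) = lax.scan(step, z0, split(key_scan, seq_len))
    return zs, xs

# Sequences from this sketch should have the same statistics as hmm.sample.
zs, xs = hmm_sample_sketch(PRNGKey(0), jnp.array(pi), jnp.array(A), jnp.array(B), 10)
print(zs, xs)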
# Let us check correctness by computing empirical pairwise statistics.
#
# We will compute the number of i->j latent state transitions, and check that it is close to the true
# A[i,j] transition probabilities.

# In[7]:

import collections

def compute_counts(state_seq, nstates):
    # count the number of i->j transitions in the state sequence
    wseq = np.array(state_seq)
    word_pairs = [pair for pair in zip(wseq[:-1], wseq[1:])]
    counter_pairs = collections.Counter(word_pairs)
    counts = np.zeros((nstates, nstates))
    for (k, v) in counter_pairs.items():
        counts[k[0], k[1]] = v
    return counts

def normalize(u, axis=0, eps=1e-15):
    # normalize u along the given axis, guarding against division by zero
    u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))
    c = u.sum(axis=axis)
    c = jnp.where(c == 0, 1, c)
    return u / c, c

def normalize_counts(counts):
    # row-normalize a count matrix into an empirical transition matrix
    ncounts = vmap(lambda v: normalize(v)[0], in_axes=0)(counts)
    return ncounts

init_dist = jnp.array([1.0, 0.0])
trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]])
obs_mat = jnp.eye(2)

hmm = HMM(trans_dist=distrax.Categorical(probs=trans_mat),
          init_dist=distrax.Categorical(probs=init_dist),
          obs_dist=distrax.Categorical(probs=obs_mat))

rng_key = jax.random.PRNGKey(0)
seq_len = 500
state_seq, _ = hmm.sample(seed=rng_key, seq_len=seq_len)

counts = compute_counts(state_seq, nstates=2)
print(counts)

trans_mat_empirical = normalize_counts(counts)
print(trans_mat_empirical)
assert jnp.allclose(trans_mat, trans_mat_empirical, atol=1e-1)
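
# As our own extension (not in the original notebook), we can reduce the
# variance of the empirical estimate by averaging transition counts over
# several independent sequences, each sampled with a different PRNG key.
n_seqs = 20
keys = split(rng_key, n_seqs)
all_counts = [compute_counts(hmm.sample(seed=k, seq_len=seq_len)[0], nstates=2)
              for k in keys]
counts_avg = np.mean(all_counts, axis=0)
print(normalize_counts(counts_avg))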
# Our primary goal will be to infer the latent state from the observations,
# so we can detect if the casino is being dishonest or not. This will
# affect how we choose to gamble our money.
# We discuss various ways to perform this inference below.

# (sec:lillypad)=
# ## Example: Lillypad HMM
#
# If $\obs_t$ is continuous, it is common to use a Gaussian
# observation model:
# ```{math}
# p(\obs_t|\hidden_t=j) = \gauss(\obs_t|\vmu_j,\vSigma_j)
# ```
# This is sometimes called a Gaussian HMM.
#
# As a simple example, suppose we have an HMM with 3 hidden states,
# each of which generates a 2d Gaussian.
# We can represent these Gaussian distributions as 2d ellipses,
# as we show below.
# We call these "lilly pads", because of their shape.
# We can imagine a frog hopping from one lilly pad to another.
# (This analogy is due to the late Sam Roweis.)
# The frog will stay on a pad for a while (corresponding to remaining in the same
# discrete state $\hidden_t$), and then jump to a new pad
# (corresponding to a transition to a new state).
# The data we see are just the 2d points (e.g., water droplets)
# coming from near the pad that the frog is currently on.
# Thus this model is like a Gaussian mixture model,
# in that it generates clusters of observations,
# except now there is temporal correlation between the data points.
#
# Let us now illustrate this model in code.
# In[8]:

# Let us create the model
initial_probs = jnp.array([0.3, 0.2, 0.5])

# transition matrix
A = jnp.array([
    [0.3, 0.4, 0.3],
    [0.1, 0.6, 0.3],
    [0.2, 0.3, 0.5]
])

# observation model: one 2d Gaussian per hidden state
mu_collection = jnp.array([
    [0.3, 0.3],
    [0.8, 0.5],
    [0.3, 0.8]
])
S1 = jnp.array([[1.1, 0], [0, 0.3]])
S2 = jnp.array([[0.3, -0.5], [-0.5, 1.3]])
S3 = jnp.array([[0.8, 0.4], [0.4, 0.5]])
cov_collection = jnp.array([S1, S2, S3]) / 60

import tensorflow_probability as tfp

# Two equivalent ways to build the Gaussian observation model;
# here we wrap the TFP (JAX substrate) distribution with distrax.
if False:
    hmm = HMM(trans_dist=distrax.Categorical(probs=A),
              init_dist=distrax.Categorical(probs=initial_probs),
              obs_dist=distrax.MultivariateNormalFullCovariance(
                  loc=mu_collection, covariance_matrix=cov_collection))
else:
    hmm = HMM(trans_dist=distrax.Categorical(probs=A),
              init_dist=distrax.Categorical(probs=initial_probs),
              obs_dist=distrax.as_distribution(
                  tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
                      loc=mu_collection, covariance_matrix=cov_collection)))
print(hmm)
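
# As a quick check (our own addition), evaluate the emission density of all
# 3 states at one 2d point; obs_dist is a batch of 3 Gaussians, so .prob
# broadcasts and returns one density per state.
point = jnp.array([0.5, 0.5])
print(hmm.obs_dist.prob(point))  # shape (3,)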
# In[9]:

n_samples, seed = 50, 10
samples_state, samples_obs = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)

print(samples_state.shape)
print(samples_obs.shape)
# In[10]:

# Let's plot the observed data in 2d
xmin, xmax = 0, 1
ymin, ymax = 0, 1.2
colors = ["tab:green", "tab:blue", "tab:red"]

def plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax, step=1e-2):
    # plot the observations (colored by latent state) and one
    # emission-density contour per state
    obs_dist = hmm.obs_dist
    color_sample = [colors[i] for i in samples_state]

    # evaluate the emission density of every state on a 2d grid
    xs = jnp.arange(xmin, xmax, step)
    ys = jnp.arange(ymin, ymax, step)
    v_prob = vmap(lambda x, y: obs_dist.prob(jnp.array([x, y])), in_axes=(None, 0))
    z = vmap(v_prob, in_axes=(0, None))(xs, ys)

    grid = np.mgrid[xmin:xmax:step, ymin:ymax:step]
    for k, color in enumerate(colors):
        ax.contour(*grid, z[:, :, k], levels=[1], colors=color, linewidths=3)
        ax.text(*(obs_dist.mean()[k] + 0.13), f"$k$={k + 1}", fontsize=13, horizontalalignment="right")

    ax.plot(*samples_obs.T, c="black", alpha=0.3, zorder=1)
    ax.scatter(*samples_obs.T, c=color_sample, s=30, zorder=2, alpha=0.8)
    return ax, color_sample

fig, ax = plt.subplots()
_, color_sample = plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax)
# In[11]:

# Let's plot the hidden state sequence
fig, ax = plt.subplots()
ax.step(range(n_samples), samples_state, where="post", c="black", linewidth=1, alpha=0.3)
ax.scatter(range(n_samples), samples_state, c=color_sample, zorder=3)