#!/usr/bin/env python
# coding: utf-8

# (sec:hmm-intro)=
# # Hidden Markov Models
#
# In this section, we discuss the
# hidden Markov model, or HMM,
# which is a state space model in which the hidden states
# are discrete, so $\hidden_t \in \{1,\ldots, \nstates\}$.
# The observations may be discrete,
# $\obs_t \in \{1,\ldots, \nsymbols\}$,
# or continuous,
# $\obs_t \in \real^D$,
# or some combination,
# as we illustrate below.
# More details can be found in e.g.,
# {cite}`Rabiner89,Fraser08,Cappe05`.
# For an interactive introduction,
# see https://nipunbatra.github.io/hmm/.
# In[1]:


{
    "tags": [
        "hide-cell"
    ]
}

### Import standard libraries

import abc
from dataclasses import dataclass
import functools
from functools import partial
import itertools

import matplotlib.pyplot as plt
import numpy as np
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
#from jax.scipy.special import logit
#from jax.nn import softmax
import jax.random as jr

import distrax
import optax

import jsl
import ssm_jax

import inspect
import inspect as py_inspect
import rich
from rich import inspect as r_inspect
from rich import print as r_print

def print_source(fname):
    r_print(py_inspect.getsource(fname))
# (sec:casino)=
# ## Example: Casino HMM
#
# To illustrate HMMs with a categorical observation model,
# we consider the "Occasionally dishonest casino" model from {cite}`Durbin98`.
# There are 2 hidden states, representing whether the die being used in the casino is fair or loaded.
# Each state defines a distribution over the 6 possible observations.
#
# The transition model is denoted by
# ```{math}
# p(\hidden_t=j|\hidden_{t-1}=i) = \hmmTransScalar_{ij}
# ```
# Here the $i$'th row of $\hmmTrans$ corresponds to the outgoing distribution from state $i$.
# This is a row stochastic matrix,
# meaning each row sums to one.
# We can visualize
# the non-zero entries in the transition matrix by creating a state transition diagram,
# as shown in
# {numref}`fig:casino`.
#
# ```{figure} /figures/casino.png
# :scale: 50%
# :name: fig:casino
#
# Illustration of the casino HMM.
# ```
#
# The observation model
# $p(\obs_t|\hidden_t=j)$ has the form
# ```{math}
# p(\obs_t=k|\hidden_t=j) = \hmmObsScalar_{jk}
# ```
# This is represented by the histograms associated with each
# state in {numref}`fig:casino`.
#
# Finally,
# the initial state distribution is denoted by
# ```{math}
# p(\hidden_1=j) = \hmmInitScalar_j
# ```
#
# Collectively we denote all the parameters by $\params=(\hmmTrans, \hmmObs, \hmmInit)$.
#
# Now let us implement this model in code.
# In[2]:


# state transition matrix
A = np.array([
    [0.95, 0.05],
    [0.10, 0.90]
])

# observation matrix
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6],       # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]  # loaded die
])

pi = np.array([0.5, 0.5])

(nstates, nobs) = np.shape(B)
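# As a quick sanity check, we can verify that $\hmmTrans$ and $\hmmObs$ are row
# stochastic and that $\hmmInit$ is normalized.

# In[ ]:


assert np.allclose(A.sum(axis=1), 1.0)   # each row of the transition matrix sums to one
assert np.allclose(B.sum(axis=1), 1.0)   # each row of the observation matrix sums to one
assert np.allclose(pi.sum(), 1.0)        # the initial state distribution sums to one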
# In[3]:


import distrax
from distrax import HMM

hmm = HMM(trans_dist=distrax.Categorical(probs=A),
          init_dist=distrax.Categorical(probs=pi),
          obs_dist=distrax.Categorical(probs=B))

print(hmm)
#
# Let's sample from the model. We will generate a sequence of latent states, $\hidden_{1:T}$,
# which we then convert to a sequence of observations, $\obs_{1:T}$.

# In[4]:


seed = 314
n_samples = 300

z_hist, x_hist = hmm.sample(seed=jr.PRNGKey(seed), seq_len=n_samples)

z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60]
x_hist_str = "".join((np.array(x_hist) + 1).astype(str))[:60]

print("Printing sample observed/latent...")
print(f"x: {x_hist_str}")
print(f"z: {z_hist_str}")

# Below is the source code for the sampling algorithm.
#
#

# In[5]:


print_source(hmm.sample)
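# To make the generative process concrete, here is a minimal sketch of ancestral
# sampling written directly in terms of the matrices `A`, `B` and `pi`, using
# `jax.lax.scan`. The helper name `hmm_ancestral_sample` is just for illustration;
# this is a re-implementation under stated assumptions, not the distrax source
# shown above.

# In[ ]:


def hmm_ancestral_sample(rng_key, init_probs, trans_probs, obs_probs, seq_len):
    # Sample z_1 ~ pi and x_1 ~ p(.|z_1), then roll the chain forward with a scan.
    key_init, key_obs1, key_scan = jr.split(rng_key, 3)
    z1 = jr.categorical(key_init, jnp.log(init_probs))
    x1 = jr.categorical(key_obs1, jnp.log(obs_probs[z1]))

    def step(z_prev, key):
        key_trans, key_obs = jr.split(key)
        z = jr.categorical(key_trans, jnp.log(trans_probs[z_prev]))  # z_t ~ p(.|z_{t-1})
        x = jr.categorical(key_obs, jnp.log(obs_probs[z]))           # x_t ~ p(.|z_t)
        return z, (z, x)

    _, (zs, xs) = lax.scan(step, z1, jr.split(key_scan, seq_len - 1))
    z_hist = jnp.concatenate([z1[None], zs])
    x_hist = jnp.concatenate([x1[None], xs])
    return z_hist, x_hist

z_check, x_check = hmm_ancestral_sample(jr.PRNGKey(seed), jnp.array(pi), jnp.array(A), jnp.array(B), 10)
print(z_check)
print(x_check)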
# Let us check correctness by computing empirical pairwise statistics.
#
# We will compute the number of i->j latent state transitions, and check that the
# empirical transition frequencies are close to the true A[i,j] transition probabilities.

# In[6]:


import collections

def compute_counts(state_seq, nstates):
    wseq = np.array(state_seq)
    word_pairs = [pair for pair in zip(wseq[:-1], wseq[1:])]
    counter_pairs = collections.Counter(word_pairs)
    counts = np.zeros((nstates, nstates))
    for (k, v) in counter_pairs.items():
        counts[k[0], k[1]] = v
    return counts

def normalize(u, axis=0, eps=1e-15):
    u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))
    c = u.sum(axis=axis)
    c = jnp.where(c == 0, 1, c)
    return u / c, c

def normalize_counts(counts):
    ncounts = vmap(lambda v: normalize(v)[0], in_axes=0)(counts)
    return ncounts

init_dist = jnp.array([1.0, 0.0])
trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]])
obs_mat = jnp.eye(2)

hmm = HMM(trans_dist=distrax.Categorical(probs=trans_mat),
          init_dist=distrax.Categorical(probs=init_dist),
          obs_dist=distrax.Categorical(probs=obs_mat))

rng_key = jax.random.PRNGKey(0)
seq_len = 500
state_seq, _ = hmm.sample(seed=rng_key, seq_len=seq_len)

counts = compute_counts(state_seq, nstates=2)
print(counts)

trans_mat_empirical = normalize_counts(counts)
print(trans_mat_empirical)

assert jnp.allclose(trans_mat, trans_mat_empirical, atol=1e-1)

# Our primary goal will be to infer the latent state from the observations,
# so we can detect if the casino is being dishonest or not. This will
# affect how we choose to gamble our money.
# We discuss various ways to perform this inference below.
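# As a quick preview, the cell below sketches one such inference computation:
# we rebuild the casino HMM (the `hmm` variable was re-bound in the sanity check
# above) and run forward-backward smoothing to get the posterior marginals
# $p(\hidden_t|\obs_{1:T})$. Here we assume the distrax `HMM` class exposes a
# `forward_backward` method returning `(alphas, betas, marginals, log_prob)`;
# if your installed version differs, adapt the unpacking accordingly.

# In[ ]:


casino_hmm = HMM(trans_dist=distrax.Categorical(probs=A),
                 init_dist=distrax.Categorical(probs=pi),
                 obs_dist=distrax.Categorical(probs=B))

# Posterior marginals over the hidden state at each step, given all observations.
# (Assumed return order: alphas, betas, marginals, log_prob.)
alphas, betas, marginals, log_prob = casino_hmm.forward_backward(x_hist)
print("log-likelihood of the sampled sequence:", log_prob)

# Probability that the loaded die (state 2, index 1) was in use at each step.
p_loaded = marginals[:, 1]
print(np.round(np.array(p_loaded[:20]), 2))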
# (sec:lillypad)=
# ## Example: Lillypad HMM
#
#
# If $\obs_t$ is continuous, it is common to use a Gaussian
# observation model:
# ```{math}
# p(\obs_t|\hidden_t=j) = \gauss(\obs_t|\vmu_j,\vSigma_j)
# ```
# This is sometimes called a Gaussian HMM.
#
# As a simple example, suppose we have an HMM with 3 hidden states,
# each of which generates a 2d Gaussian.
# We can represent these Gaussian distributions as 2d ellipses,
# as we show below.
# We call these "lilly pads", because of their shape.
# We can imagine a frog hopping from one lilly pad to another.
# (This analogy is due to the late Sam Roweis.)
# The frog will stay on a pad for a while (corresponding to remaining in the same
# discrete state $\hidden_t$), and then jump to a new pad
# (corresponding to a transition to a new state).
# The data we see are just the 2d points (e.g., water droplets)
# coming from near the pad that the frog is currently on.
# Thus this model is like a Gaussian mixture model,
# in that it generates clusters of observations,
# except now there is temporal correlation between the data points.
#
# Let us now illustrate this model in code.
#
#

# In[19]:
# Let us create the model

initial_probs = jnp.array([0.3, 0.2, 0.5])

# transition matrix
A = jnp.array([
    [0.3, 0.4, 0.3],
    [0.1, 0.6, 0.3],
    [0.2, 0.3, 0.5]
])

# Observation model
mu_collection = jnp.array([
    [0.3, 0.3],
    [0.8, 0.5],
    [0.3, 0.8]
])

S1 = jnp.array([[1.1, 0], [0, 0.3]])
S2 = jnp.array([[0.3, -0.5], [-0.5, 1.3]])
S3 = jnp.array([[0.8, 0.4], [0.4, 0.5]])
cov_collection = jnp.array([S1, S2, S3]) / 60

import tensorflow_probability as tfp

if False:
    hmm = HMM(trans_dist=distrax.Categorical(probs=A),
              init_dist=distrax.Categorical(probs=initial_probs),
              obs_dist=distrax.MultivariateNormalFullCovariance(
                  loc=mu_collection, covariance_matrix=cov_collection))
else:
    hmm = HMM(trans_dist=distrax.Categorical(probs=A),
              init_dist=distrax.Categorical(probs=initial_probs),
              obs_dist=distrax.as_distribution(
                  tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
                      loc=mu_collection, covariance_matrix=cov_collection)))

print(hmm)
# In[22]:


n_samples, seed = 50, 10
samples_state, samples_obs = hmm.sample(seed=jr.PRNGKey(seed), seq_len=n_samples)

print(samples_state.shape)
print(samples_obs.shape)
# In[25]:


# Let's plot the observed data in 2d
xmin, xmax = 0, 1
ymin, ymax = 0, 1.2
colors = ["tab:green", "tab:blue", "tab:red"]

def plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax, step=1e-2):
    obs_dist = hmm.obs_dist
    color_sample = [colors[i] for i in samples_state]

    # Evaluate each state's observation density on a 2d grid.
    xs = jnp.arange(xmin, xmax, step)
    ys = jnp.arange(ymin, ymax, step)
    v_prob = vmap(lambda x, y: obs_dist.prob(jnp.array([x, y])), in_axes=(None, 0))
    z = vmap(v_prob, in_axes=(0, None))(xs, ys)

    grid = np.mgrid[xmin:xmax:step, ymin:ymax:step]
    for k, color in enumerate(colors):
        # Draw one contour ("lilly pad") per hidden state and label it.
        ax.contour(*grid, z[:, :, k], levels=[1], colors=color, linewidths=3)
        ax.text(*(obs_dist.mean()[k] + 0.13), f"$k$={k + 1}", fontsize=13, horizontalalignment="right")

    # Plot the observation sequence, colored by the generating hidden state.
    ax.plot(*samples_obs.T, c="black", alpha=0.3, zorder=1)
    ax.scatter(*samples_obs.T, c=color_sample, s=30, zorder=2, alpha=0.8)
    return ax, color_sample

fig, ax = plt.subplots()
_, color_sample = plot_2dhmm(hmm, samples_obs, samples_state, colors, ax, xmin, xmax, ymin, ymax)

# In[26]:


# Let's plot the hidden state sequence
fig, ax = plt.subplots()
ax.step(range(n_samples), samples_state, where="post", c="black", linewidth=1, alpha=0.3)
ax.scatter(range(n_samples), samples_state, c=color_sample, zorder=3)
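# Finally, as a preview of inference in the Gaussian HMM, we can decode the most
# probable hidden path and compare it with the sampled states. This sketch assumes
# the distrax `HMM` class provides a `viterbi` method; if your installed version
# differs, adjust accordingly.

# In[ ]:


# Viterbi decoding: most probable state sequence given the observations (assumed API).
map_path = hmm.viterbi(samples_obs)
frac_match = jnp.mean(map_path == samples_state)
print("Fraction of time steps where the MAP path matches the sampled states:", frac_match)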