
#!/usr/bin/env python
# coding: utf-8

# (sec:hmm-ex)=
# # Hidden Markov Models
#
# In this section, we introduce Hidden Markov Models (HMMs).

# ## Boilerplate

# In[1]:

# Install necessary libraries

try:
    import jax
except ModuleNotFoundError:
    # For the CUDA version, see https://github.com/google/jax#installation
    get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
    import jax

try:
    import jsl
except ModuleNotFoundError:
    get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
    import jsl

try:
    import rich
except ModuleNotFoundError:
    get_ipython().run_line_magic('pip', 'install rich')
    import rich
# In[2]:


# Import standard libraries
import abc
from dataclasses import dataclass
import functools
from functools import partial
import itertools
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.scipy.special import logit
from jax.nn import softmax
from jax.random import PRNGKey, split

import inspect as py_inspect
from rich import inspect as r_inspect
from rich import print as r_print

def print_source(fname):
    r_print(py_inspect.getsource(fname))
# ## Utility code

# In[3]:


def normalize(u, axis=0, eps=1e-15):
    '''
    Normalizes the values along the given axis so that they sum to 1.

    Parameters
    ----------
    u : array
    axis : int
    eps : float
        Lower threshold applied to small nonzero entries before normalizing

    Returns
    -------
    * array
        Normalized version of the given matrix
    * array
        The values of the normalizer
    '''
    u = jnp.where(u == 0, 0, jnp.where(u < eps, eps, u))
    c = u.sum(axis=axis)
    c = jnp.where(c == 0, 1, c)
    return u / c, c
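# As a quick illustration of `normalize` (a usage sketch, not part of the model
# code below): a nonnegative vector is rescaled to sum to one, and the
# normalization constant is returned alongside it.

# In[ ]:


u = jnp.array([1.0, 3.0, 0.0])
p, c = normalize(u)
print(p)  # [0.25 0.75 0.  ]
print(c)  # 4.0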
# (sec:casino-ex)=
# ## Example: Casino HMM
#
# We first create the "occasionally dishonest casino" model from {cite}`Durbin98`.
#
# ```{figure} /figures/casino.png
# :scale: 50%
# :name: casino-fig
#
# Illustration of the casino HMM.
# ```
#
# There are 2 hidden states, each of which emits 6 possible observations.

# In[4]:

# state transition matrix
A = np.array([
    [0.95, 0.05],
    [0.10, 0.90]
])

# observation matrix
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6],       # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]  # loaded die
])

# initial state distribution (uniform)
pi, _ = normalize(np.array([1, 1]))
pi = np.array(pi)

(nstates, nobs) = np.shape(B)
# Let's make a little data structure to store all the parameters.
# We use NamedTuple rather than dataclass, since we assume these are immutable.
# (Also, standard Python dataclasses do not work well with JAX, which requires
# parameters to be pytrees, as discussed in https://github.com/google/jax/issues/2371.)

# In[5]:

Array = Union[np.ndarray, jnp.ndarray]

class HMM(NamedTuple):
    trans_mat: Array  # A : (n_states, n_states)
    obs_mat: Array    # B : (n_states, n_obs)
    init_dist: Array  # pi : (n_states)

params_np = HMM(A, B, pi)
print(params_np)
print(type(params_np.trans_mat))

# Convert the numpy arrays to JAX arrays, leaf by leaf.
params = jax.tree_map(lambda x: jnp.array(x), params_np)
print(params)
print(type(params.trans_mat))
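# Since NamedTuples are registered pytrees, JAX utilities such as tree_map
# (used above) see an HMM simply as a container of array leaves. A small
# illustrative check (the leaves come out in field-declaration order):

# In[ ]:


leaves = jax.tree_util.tree_leaves(params)
print([leaf.shape for leaf in leaves])  # [(2, 2), (2, 6), (2,)]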
# ## Sampling from the joint
#
# Let's write code to sample from this model.
#
# ### Numpy version
#
# First we code it in numpy using a for loop.

# In[6]:

def hmm_sample_np(params, seq_len, random_state=0):
    np.random.seed(random_state)
    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist
    n_states, n_obs = obs_mat.shape
    state_seq = np.zeros(seq_len, dtype=int)
    obs_seq = np.zeros(seq_len, dtype=int)
    for t in range(seq_len):
        if t == 0:
            zt = np.random.choice(n_states, p=init_dist)
        else:
            zt = np.random.choice(n_states, p=trans_mat[zt])
        yt = np.random.choice(n_obs, p=obs_mat[zt])
        state_seq[t] = zt
        obs_seq[t] = yt
    return state_seq, obs_seq
# In[7]:


seq_len = 100
state_seq, obs_seq = hmm_sample_np(params_np, seq_len, random_state=1)
print(state_seq)
print(obs_seq)
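# As a rough sanity check on the sampler (an illustrative sketch; the variable
# names below are not used elsewhere), the empirical frequency of rolling a six
# (observation index 5) in each hidden state should be close to B[z, 5]:
# about 1/6 when the die is fair (state 0) and about 1/2 when loaded (state 1).

# In[ ]:


state_seq_long, obs_seq_long = hmm_sample_np(params_np, 10000, random_state=0)
for z in range(nstates):
    freq6 = np.mean(obs_seq_long[state_seq_long == z] == 5)
    print(f"state {z}: empirical P(six) = {freq6:.3f}, true = {B[z, 5]:.3f}")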
# ### JAX version
#
# Now let's write a JAX version using jax.lax.scan (for the inter-dependent states)
# and vmap (for the observations).
# This is harder to read than the numpy version, but faster.

# In[8]:

#@partial(jit, static_argnums=(3,))
def markov_chain_sample(rng_key, init_dist, trans_mat, seq_len):
    n_states = len(init_dist)

    def draw_state(prev_state, key):
        state = jax.random.choice(key, n_states, p=trans_mat[prev_state])
        return state, state

    rng_key, rng_state = jax.random.split(rng_key, 2)
    keys = jax.random.split(rng_state, seq_len - 1)
    initial_state = jax.random.choice(rng_key, n_states, p=init_dist)
    final_state, states = jax.lax.scan(draw_state, initial_state, keys)
    state_seq = jnp.append(jnp.array([initial_state]), states)
    return state_seq
# In[9]:


#@partial(jit, static_argnums=(2,))
def hmm_sample(rng_key, params, seq_len):
    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist
    n_states, n_obs = obs_mat.shape
    rng_key, rng_obs = jax.random.split(rng_key, 2)
    state_seq = markov_chain_sample(rng_key, init_dist, trans_mat, seq_len)

    def draw_obs(z, key):
        obs = jax.random.choice(key, n_obs, p=obs_mat[z])
        return obs

    keys = jax.random.split(rng_obs, seq_len)
    obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)
    return state_seq, obs_seq
# In[10]:


# hmm_sample2 is a fused version of the above: it samples the states and the
# observations inside a single function.
#@partial(jit, static_argnums=(2,))
def hmm_sample2(rng_key, params, seq_len):
    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist
    n_states, n_obs = obs_mat.shape

    def draw_state(prev_state, key):
        state = jax.random.choice(key, n_states, p=trans_mat[prev_state])
        return state, state

    rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3)
    keys = jax.random.split(rng_state, seq_len - 1)
    initial_state = jax.random.choice(rng_key, n_states, p=init_dist)
    final_state, states = jax.lax.scan(draw_state, initial_state, keys)
    state_seq = jnp.append(jnp.array([initial_state]), states)

    def draw_obs(z, key):
        obs = jax.random.choice(key, n_obs, p=obs_mat[z])
        return obs

    keys = jax.random.split(rng_obs, seq_len)
    obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)
    return state_seq, obs_seq
# In[11]:


key = PRNGKey(2)
seq_len = 100
state_seq, obs_seq = hmm_sample(key, params, seq_len)
print(state_seq)
print(obs_seq)
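# We can check the speed claim empirically. Below is a minimal timing sketch:
# we jit-compile hmm_sample with seq_len marked static (as in the commented-out
# decorator above) and compare against the numpy version. Exact timings are
# machine-dependent.

# In[ ]:


hmm_sample_jit = jit(hmm_sample, static_argnums=(2,))
_ = hmm_sample_jit(key, params, seq_len)  # trigger compilation once
# block_until_ready() accounts for JAX's asynchronous dispatch
get_ipython().run_line_magic('timeit', 'hmm_sample_jit(key, params, seq_len)[0].block_until_ready()')
get_ipython().run_line_magic('timeit', 'hmm_sample_np(params_np, seq_len)')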
# ### Check correctness by computing empirical pairwise statistics
#
# We will compute the number of i->j transitions in a sampled state sequence,
# and check that the normalized counts are close to the true transition
# probabilities A[i,j].

# In[12]:

import collections

def compute_counts(state_seq, nstates):
    wseq = np.array(state_seq)
    word_pairs = list(zip(wseq[:-1], wseq[1:]))
    counter_pairs = collections.Counter(word_pairs)
    counts = np.zeros((nstates, nstates))
    for (k, v) in counter_pairs.items():
        counts[k[0], k[1]] = v
    return counts

def normalize_counts(counts):
    ncounts = vmap(lambda v: normalize(v)[0], in_axes=0)(counts)
    return ncounts

init_dist = jnp.array([1.0, 0.0])
trans_mat = jnp.array([[0.7, 0.3], [0.5, 0.5]])
rng_key = jax.random.PRNGKey(0)
seq_len = 500
state_seq = markov_chain_sample(rng_key, init_dist, trans_mat, seq_len)
print(state_seq)

counts = compute_counts(state_seq, nstates=2)
print(counts)

trans_mat_empirical = normalize_counts(counts)
print(trans_mat_empirical)
assert jnp.allclose(trans_mat, trans_mat_empirical, atol=1e-1)
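# The same idea extends to the observation model: for a long sampled sequence,
# the empirical distribution of symbols emitted in each state should approach
# the corresponding row of B. A sketch of this check for the casino HMM:

# In[ ]:


key = PRNGKey(42)
state_seq, obs_seq = hmm_sample(key, params, 20000)
obs_counts = np.zeros((nstates, nobs))
for z, y in zip(np.array(state_seq), np.array(obs_seq)):
    obs_counts[z, y] += 1
obs_mat_empirical = normalize_counts(obs_counts)
print(obs_mat_empirical)
assert jnp.allclose(params.obs_mat, obs_mat_empirical, atol=1e-1)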