ssm.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # In[1]:
  4. {
  5. "tags": [
  6. "hide-cell"
  7. ]
  8. }
  9. ### Install necessary libraries
  10. try:
  11. import jax
  12. except:
  13. # For cuda version, see https://github.com/google/jax#installation
  14. get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
  15. import jax
  16. try:
  17. import jsl
  18. except:
  19. get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
  20. import jsl
  21. try:
  22. import rich
  23. except:
  24. get_ipython().run_line_magic('pip', 'install rich')
  25. import rich
# In[2]:
# Jupyter cell metadata carried over by the notebook-to-script export. This
# dict literal is evaluated and discarded, so it has no runtime effect.
{
"tags": [
"hide-cell"
]
}
### Import standard-library and third-party modules used by this notebook
import abc
from dataclasses import dataclass
import functools
import itertools
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
import matplotlib.pyplot as plt
import numpy as np
# NOTE(review): `jax` is also imported by the install cell above. Re-importing
# is harmless because Python caches loaded modules.
import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.scipy.special import logit
from jax.nn import softmax
from functools import partial
from jax.random import PRNGKey, split
# The stdlib `inspect` module is bound under two names. `py_inspect` keeps it
# distinguishable from rich's `inspect`, which is imported below as `r_inspect`.
import inspect
import inspect as py_inspect
from rich import inspect as r_inspect
from rich import print as r_print
  51. def print_source(fname):
  52. r_print(py_inspect.getsource(fname))
  53. # (sec:ssm-intro)=
  54. # # What are State Space Models?
  55. #
  56. #
# A state space model (SSM)
  58. # is a partially observed Markov model,
  59. # in which the hidden state, $z_t$,
  60. # evolves over time according to a Markov process.
  61. #
  62. #
  63. # ```{figure} /figures/SSM-AR-inputs.png
  64. # :scale: 100%
  65. # :name: ssm-ar
  66. #
  67. # Illustration of an SSM as a graphical model.
  68. # ```
  69. #
  70. # ```{figure} /figures/SSM-simplified.png
  71. # :scale: 100%
  72. # :name: ssm-simplifed
  73. #
  74. # Illustration of a simplified SSM.
  75. # ```
  76. # (sec:casino-ex)=
  77. # ## Example: Casino HMM
  78. #
# We first create the "occasionally dishonest casino" model from {cite}`Durbin98`.
  80. #
  81. #
  82. #
# There are 2 hidden states (a fair die and a loaded die), each of which emits one of 6 possible observations (the die faces).
  84. # In[3]:
  85. # state transition matrix
  86. A = np.array([
  87. [0.95, 0.05],
  88. [0.10, 0.90]
  89. ])
  90. # observation matrix
  91. B = np.array([
  92. [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
  93. [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
  94. ])
  95. pi, _ = normalize(np.array([1, 1]))
  96. pi = np.array(pi)
  97. (nstates, nobs) = np.shape(B)
  98. # In[ ]: