scratchpad.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # In[1]:
  4. ### Import standard libraries
  5. import abc
  6. from dataclasses import dataclass
  7. import functools
  8. from functools import partial
  9. import itertools
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12. from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
  13. import jax
  14. import jax.numpy as jnp
  15. from jax import lax, vmap, jit, grad
  16. #from jax.scipy.special import logit
  17. #from jax.nn import softmax
  18. import jax.random as jr
  19. import distrax
  20. import optax
  21. import jsl
  22. import ssm_jax
  23. # In[2]:
  24. import inspect
  25. import inspect as py_inspect
  26. import rich
  27. from rich import inspect as r_inspect
  28. from rich import print as r_print
  29. def print_source(fname):
  30. r_print(py_inspect.getsource(fname))
  31. # In[3]:
  32. # meta-data does not work yet in VScode
  33. # https://github.com/microsoft/vscode-jupyter/issues/1121
  34. {
  35. "tags": [
  36. "hide-cell"
  37. ]
  38. }
  39. ### Install necessary libraries
  40. try:
  41. import jax
  42. except:
  43. # For cuda version, see https://github.com/google/jax#installation
  44. get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
  45. import jax
  46. try:
  47. import distrax
  48. except:
  49. get_ipython().run_line_magic('pip', 'install --upgrade distrax')
  50. import distrax
  51. try:
  52. import jsl
  53. except:
  54. get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
  55. import jsl
  56. try:
  57. import rich
  58. except:
  59. get_ipython().run_line_magic('pip', 'install rich')
  60. import rich
  61. # In[ ]:
  62. # In[4]:
  63. {
  64. "tags": [
  65. "hide-cell"
  66. ]
  67. }
  68. ### Import standard libraries
  69. import abc
  70. from dataclasses import dataclass
  71. import functools
  72. import itertools
  73. from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
  74. import matplotlib.pyplot as plt
  75. import numpy as np
  76. import jax
  77. import jax.numpy as jnp
  78. from jax import lax, vmap, jit, grad
  79. from jax.scipy.special import logit
  80. from jax.nn import softmax
  81. from functools import partial
  82. from jax.random import PRNGKey, split
  83. import inspect
  84. import inspect as py_inspect
  85. import rich
  86. from rich import inspect as r_inspect
  87. from rich import print as r_print
  88. def print_source(fname):
  89. r_print(py_inspect.getsource(fname))
  90. # In[5]:
  91. import ssm_jax
  92. from ssm_jax.hmm.models import GaussianHMM
  93. print_source(GaussianHMM)
  94. # In[6]:
  95. # Set dimensions
  96. num_states = 5
  97. emission_dim = 2
  98. # Specify parameters of the HMM
  99. initial_probs = jnp.ones(num_states) / num_states
  100. transition_matrix = 0.95 * jnp.eye(num_states) + 0.05 * jnp.roll(jnp.eye(num_states), 1, axis=1)
  101. emission_means = jnp.column_stack([
  102. jnp.cos(jnp.linspace(0, 2 * jnp.pi, num_states+1))[:-1],
  103. jnp.sin(jnp.linspace(0, 2 * jnp.pi, num_states+1))[:-1]
  104. ])
  105. emission_covs = jnp.tile(0.1**2 * jnp.eye(emission_dim), (num_states, 1, 1))
  106. hmm = GaussianHMM(initial_probs,
  107. transition_matrix,
  108. emission_means,
  109. emission_covs)
  110. print_source(hmm.sample)
  111. # In[7]:
  112. import distrax
  113. from distrax import HMM
  114. A = np.array([
  115. [0.95, 0.05],
  116. [0.10, 0.90]
  117. ])
  118. # observation matrix
  119. B = np.array([
  120. [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
  121. [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
  122. ])
  123. pi = np.array([0.5, 0.5])
  124. (nstates, nobs) = np.shape(B)
  125. hmm = HMM(trans_dist=distrax.Categorical(probs=A),
  126. init_dist=distrax.Categorical(probs=pi),
  127. obs_dist=distrax.Categorical(probs=B))
  128. print(hmm)
  129. # In[8]:
  130. print_source(hmm.sample)
  131. # In[ ]: