123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208 |
- #!/usr/bin/env python
- # coding: utf-8
- # In[1]:
- ### Import standard libraries
- import abc
- from dataclasses import dataclass
- import functools
- from functools import partial
- import itertools
- import matplotlib.pyplot as plt
- import numpy as np
- from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
- import jax
- import jax.numpy as jnp
- from jax import lax, vmap, jit, grad
- #from jax.scipy.special import logit
- #from jax.nn import softmax
- import jax.random as jr
- import distrax
- import optax
- import jsl
- import ssm_jax
- # In[2]:
- import inspect
- import inspect as py_inspect
- import rich
- from rich import inspect as r_inspect
- from rich import print as r_print
def print_source(fname):
    """Pretty-print the source code of a Python object with rich.

    Args:
        fname: the object whose source should be shown -- despite the name,
            this is the object itself (a module, class, method or function,
            e.g. ``GaussianHMM`` or ``hmm.sample``), not a string.

    Raises:
        OSError / TypeError: propagated from ``inspect.getsource`` when the
            source cannot be retrieved (e.g. for built-in objects).
    """
    # Local import keeps this helper self-contained; rich is already a
    # dependency of this script.
    from rich.markup import escape

    source = py_inspect.getsource(fname)
    # rich's print interprets "[...]" as console markup, so code such as
    # ``x[i]`` or ``params[b]`` would be rendered as styles and the brackets
    # would silently vanish. Escaping makes the source print verbatim.
    r_print(escape(source))
- # In[3]:
# Install necessary libraries (notebook-only: the fallback uses IPython's
# get_ipython() to run %pip).
#
# The original notebook cell is tagged {"tags": ["hide-cell"]} (presumably so
# documentation tooling collapses it). VSCode cannot edit cell metadata yet
# (https://github.com/microsoft/vscode-jupyter/issues/1121), so the tag had
# been pasted here as a bare dict literal -- a no-op expression statement in
# this script. It is recorded in this comment instead.
#
# NOTE(review): in the exported script, jax, distrax, jsl and rich are already
# imported unconditionally near the top, before this fallback runs, so a
# missing package fails there first. Move those imports below this cell for
# the fallback to take effect.
try:
    import jax
except ImportError:
    # Only a missing package should trigger an install; any other error raised
    # while importing must surface instead of being masked by a reinstall.
    # For cuda version, see https://github.com/google/jax#installation
    get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
    import jax

try:
    import distrax
except ImportError:
    get_ipython().run_line_magic('pip', 'install --upgrade distrax')
    import distrax

try:
    import jsl
except ImportError:
    get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
    import jsl

try:
    import rich
except ImportError:
    get_ipython().run_line_magic('pip', 'install rich')
    import rich
- # In[ ]:
- # In[4]:
# Cell metadata for this notebook cell: {"tags": ["hide-cell"]}
# Kept as a comment: it was previously written as a bare dict literal, which
# is a no-op expression statement when the notebook runs as a script. The tag
# presumably belongs in the notebook's cell metadata, which VSCode could not
# edit at the time.
- ### Import standard libraries
- import abc
- from dataclasses import dataclass
- import functools
- import itertools
- from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
- import matplotlib.pyplot as plt
- import numpy as np
- import jax
- import jax.numpy as jnp
- from jax import lax, vmap, jit, grad
- from jax.scipy.special import logit
- from jax.nn import softmax
- from functools import partial
- from jax.random import PRNGKey, split
- import inspect
- import inspect as py_inspect
- import rich
- from rich import inspect as r_inspect
- from rich import print as r_print
def print_source(fname):
    """Pretty-print the source code of a Python object with rich.

    Args:
        fname: the object whose source should be shown -- despite the name,
            this is the object itself (a module, class, method or function,
            e.g. ``GaussianHMM`` or ``hmm.sample``), not a string.

    Raises:
        OSError / TypeError: propagated from ``inspect.getsource`` when the
            source cannot be retrieved (e.g. for built-in objects).
    """
    # Local import keeps this helper self-contained; rich is already a
    # dependency of this script.
    from rich.markup import escape

    source = py_inspect.getsource(fname)
    # rich's print interprets "[...]" as console markup, so code such as
    # ``x[i]`` or ``params[b]`` would be rendered as styles and the brackets
    # would silently vanish. Escaping makes the source print verbatim.
    r_print(escape(source))
- # In[5]:
import ssm_jax
from ssm_jax.hmm.models import GaussianHMM

# Show the implementation of the model class used below.
print_source(GaussianHMM)

# In[6]:

# --- Problem size -----------------------------------------------------------
num_states = 5      # number of hidden states
emission_dim = 2    # dimensionality of each observation

# --- Model parameters -------------------------------------------------------
# Uniform distribution over the starting hidden state.
initial_probs = jnp.ones(num_states) / num_states

# Each state stays put with probability 0.95; otherwise it moves to the next
# state (index + 1, wrapping around), so every row sums to 1.
transition_matrix = (
    0.95 * jnp.eye(num_states)
    + 0.05 * jnp.roll(jnp.eye(num_states), 1, axis=1)
)

# Emission means: num_states points evenly spaced on the unit circle. The
# linspace includes the endpoint 2*pi, which duplicates angle 0, so the last
# sample is dropped. Note the cos/sin construction only matches
# emission_dim == 2.
emission_means = jnp.stack(
    [
        jnp.cos(jnp.linspace(0, 2 * jnp.pi, num_states + 1))[:-1],
        jnp.sin(jnp.linspace(0, 2 * jnp.pi, num_states + 1))[:-1],
    ],
    axis=1,
)

# The same isotropic covariance (0.1**2 * I) for every state;
# shape (num_states, emission_dim, emission_dim).
emission_covs = jnp.tile(0.1**2 * jnp.eye(emission_dim), (num_states, 1, 1))

hmm = GaussianHMM(
    initial_probs,
    transition_matrix,
    emission_means,
    emission_covs,
)

# Show how sampling is implemented for this model.
print_source(hmm.sample)
- # In[7]:
import distrax
from distrax import HMM

# Two hidden states: 0 = fair die, 1 = loaded die.

# Hidden-state transition matrix; row i holds the next-state probabilities
# when the current state is i.
A = np.array([[0.95, 0.05],
              [0.10, 0.90]])

# observation matrix: row i holds the probabilities of the six faces for die i
B = np.array([
    [1/6] * 6,            # fair die
    [1/10] * 5 + [5/10],  # loaded die (the sixth face comes up half the time)
])

# Initial hidden-state distribution: either die equally likely.
pi = np.array([0.5, 0.5])

# (number of hidden states, number of observable symbols)
nstates, nobs = B.shape

hmm = HMM(
    trans_dist=distrax.Categorical(probs=A),
    init_dist=distrax.Categorical(probs=pi),
    obs_dist=distrax.Categorical(probs=B),
)
print(hmm)

# In[8]:

# Show how distrax implements sampling for this HMM.
print_source(hmm.sample)
- # In[ ]:
|