#!/usr/bin/env python
# coding: utf-8

# Notebook exported as a script. It prints the source of HMM model/sampling
# code from ssm_jax (GaussianHMM) and distrax (HMM).


# In[1]:

### Install necessary libraries
# This must run BEFORE the third-party imports below; otherwise a missing
# package fails at import time and the %pip fallback is never reached.
# Cell tags: hide-cell (metadata does not work yet in VS Code:
# https://github.com/microsoft/vscode-jupyter/issues/1121)

import importlib


def _ensure_installed(module_name, pip_args):
    """Import *module_name*, installing it with ``%pip`` if it is missing.

    Args:
        module_name: importable module name, e.g. ``"jax"``.
        pip_args: argument string passed to the ``%pip`` line magic,
            e.g. ``"install rich"``.

    Returns:
        The imported module.

    Raises:
        ImportError: if the module is missing and we are not running under
            IPython/Jupyter, so ``%pip`` is unavailable.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError:
        pass
    try:
        # get_ipython is injected into builtins only inside IPython/Jupyter.
        shell = get_ipython()  # noqa: F821
    except NameError:
        raise ImportError(
            f"{module_name!r} is not installed; run `pip {pip_args}` first"
        ) from None
    shell.run_line_magic("pip", pip_args)
    # Make the freshly installed package visible to the import system.
    importlib.invalidate_caches()
    return importlib.import_module(module_name)


# For the CUDA version of jax, see https://github.com/google/jax#installation
_ensure_installed("jax", 'install --upgrade "jax[cpu]"')
_ensure_installed("distrax", "install --upgrade distrax")
_ensure_installed("jsl", "install git+https://github.com/probml/jsl")
_ensure_installed("rich", "install rich")


# In[2]:

### Import standard libraries
# (Imports are placed after the install step on purpose; see In[1].)
import abc
import functools
import inspect
import inspect as py_inspect
import itertools
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, NamedTuple, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
import jax.random as jr
from jax import grad, jit, lax, vmap
from jax.nn import softmax
from jax.random import PRNGKey, split
from jax.scipy.special import logit

import distrax
import optax
import rich
from distrax import HMM
from rich import inspect as r_inspect
from rich import print as r_print

import jsl
import ssm_jax
from ssm_jax.hmm.models import GaussianHMM


def print_source(fname):
    """Pretty-print the source code of *fname* using ``rich``.

    Args:
        fname: any object accepted by ``inspect.getsource`` (module, class,
            function, bound method, ...).

    Raises:
        OSError / TypeError: propagated from ``inspect.getsource`` when the
            source cannot be retrieved (e.g. for built-ins).
    """
    r_print(py_inspect.getsource(fname))


# In[3]:

print_source(GaussianHMM)


# In[4]:

# Set dimensions
num_states = 5
emission_dim = 2

# Specify parameters of the HMM
initial_probs = jnp.ones(num_states) / num_states  # uniform initial state
# Stay in the current state w.p. 0.95, else move to the next state
# (cyclically, via roll along axis 1); each row sums to 1.
transition_matrix = 0.95 * jnp.eye(num_states) + 0.05 * jnp.roll(
    jnp.eye(num_states), 1, axis=1
)
# Emission means evenly spaced on the unit circle; the endpoint 2*pi is
# dropped so that angle 0 is not duplicated.
angles = jnp.linspace(0, 2 * jnp.pi, num_states + 1)[:-1]
emission_means = jnp.column_stack([jnp.cos(angles), jnp.sin(angles)])
# Same isotropic covariance (std 0.1 per dimension) for every state.
emission_covs = jnp.tile(0.1**2 * jnp.eye(emission_dim), (num_states, 1, 1))

hmm = GaussianHMM(initial_probs, transition_matrix, emission_means, emission_covs)
print_source(hmm.sample)


# In[5]:

# Two-state discrete HMM: a fair die and a loaded die.
# Transition matrix: row i is the distribution of the next state given state i.
A = np.array([
    [0.95, 0.05],
    [0.10, 0.90],
])
# Observation matrix: row k is the distribution over the 6 die faces in state k.
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6],        # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10],  # loaded die
])
pi = np.array([0.5, 0.5])  # initial state distribution
(nstates, nobs) = np.shape(B)

hmm = HMM(
    trans_dist=distrax.Categorical(probs=A),
    init_dist=distrax.Categorical(probs=pi),
    obs_dist=distrax.Categorical(probs=B),
)
print(hmm)


# In[6]:

print_source(hmm.sample)