### Install necessary libraries
try:
    import jax
except ModuleNotFoundError:
    # For the CUDA version, see https://github.com/google/jax#installation
    %pip install --upgrade "jax[cpu]"
    import jax

try:
    import distrax
except ModuleNotFoundError:
    %pip install --upgrade distrax
    import distrax

try:
    import jsl
except ModuleNotFoundError:
    %pip install git+https://github.com/probml/jsl
    import jsl

try:
    import rich
except ModuleNotFoundError:
    %pip install rich
    import rich
### Import standard libraries
import abc
from dataclasses import dataclass
import functools
import itertools
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.scipy.special import logit
from jax.nn import softmax
from functools import partial
from jax.random import PRNGKey, split
import inspect
import inspect as py_inspect
import rich
from rich import inspect as r_inspect
from rich import print as r_print
def print_source(fname):
    r_print(py_inspect.getsource(fname))
import ssm_jax
from ssm_jax.hmm.models import GaussianHMM
print_source(GaussianHMM)
class GaussianHMM(BaseHMM):
    def __init__(self,
                 initial_probabilities,
                 transition_matrix,
                 emission_means,
                 emission_covariance_matrices):
        """_summary_

        Args:
            initial_probabilities (_type_): _description_
            transition_matrix (_type_): _description_
            emission_means (_type_): _description_
            emission_covariance_matrices (_type_): _description_
        """
        super().__init__(initial_probabilities, transition_matrix)
        self._emission_distribution = tfd.MultivariateNormalFullCovariance(
            emission_means, emission_covariance_matrices)

    @classmethod
    def random_initialization(cls, key, num_states, emission_dim):
        key1, key2, key3 = jr.split(key, 3)
        initial_probs = jr.dirichlet(key1, jnp.ones(num_states))
        transition_matrix = jr.dirichlet(key2, jnp.ones(num_states), (num_states,))
        emission_means = jr.normal(key3, (num_states, emission_dim))
        emission_covs = jnp.tile(jnp.eye(emission_dim), (num_states, 1, 1))
        return cls(initial_probs, transition_matrix, emission_means, emission_covs)

    # Properties to get various parameters of the model
    @property
    def emission_distribution(self):
        return self._emission_distribution

    @property
    def emission_means(self):
        return self.emission_distribution.mean()

    @property
    def emission_covariance_matrices(self):
        return self.emission_distribution.covariance()

    @property
    def unconstrained_params(self):
        """Helper property to get a PyTree of unconstrained parameters."""
        return tfb.SoftmaxCentered().inverse(self.initial_probabilities), \
               tfb.SoftmaxCentered().inverse(self.transition_matrix), \
               self.emission_means, \
               PSDToRealBijector.forward(self.emission_covariance_matrices)

    @classmethod
    def from_unconstrained_params(cls, unconstrained_params, hypers):
        initial_probabilities = tfb.SoftmaxCentered().forward(unconstrained_params[0])
        transition_matrix = tfb.SoftmaxCentered().forward(unconstrained_params[1])
        emission_means = unconstrained_params[2]
        emission_covs = PSDToRealBijector.inverse(unconstrained_params[3])
        return cls(initial_probabilities, transition_matrix,
                   emission_means, emission_covs, *hypers)
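The printed source also exposes a `random_initialization` class method. As an aside, a minimal sketch of using it (added here, not part of the original notebook) might look like this; below, however, the parameters are specified by hand.
# Added aside: build a randomly initialized 5-state model with 2-D emissions,
# using the random_initialization class method shown in the source above.
random_hmm = GaussianHMM.random_initialization(PRNGKey(1), num_states=5, emission_dim=2)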
# Set dimensions
num_states = 5
emission_dim = 2
# Specify parameters of the HMM
initial_probs = jnp.ones(num_states) / num_states
# Mostly self-transitions, with probability 0.05 of advancing to the next state (cyclically)
transition_matrix = 0.95 * jnp.eye(num_states) + 0.05 * jnp.roll(jnp.eye(num_states), 1, axis=1)
# Place the emission means evenly around the unit circle
emission_means = jnp.column_stack([
    jnp.cos(jnp.linspace(0, 2 * jnp.pi, num_states+1))[:-1],
    jnp.sin(jnp.linspace(0, 2 * jnp.pi, num_states+1))[:-1]
])
# Isotropic emission noise with standard deviation 0.1 in every state
emission_covs = jnp.tile(0.1**2 * jnp.eye(emission_dim), (num_states, 1, 1))
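A small sanity check (added here, not part of the original notebook) confirms that each row of the transition matrix sums to one and that the emission means lie on the unit circle:
# Added sanity check: rows of a stochastic matrix must sum to 1,
# and the means were constructed as points on the unit circle.
print(transition_matrix.sum(axis=1))            # expect all ones
print(jnp.linalg.norm(emission_means, axis=1))  # expect all ones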
hmm = GaussianHMM(initial_probs,
transition_matrix,
emission_means,
emission_covs)
print_source(hmm.sample)
def sample(self, key, num_timesteps):
    """Sample a sequence of latent states and emissions.

    Args:
        key (_type_): _description_
        num_timesteps (_type_): _description_
    """
    def _step(state, key):
        key1, key2 = jr.split(key, 2)
        emission = self.emission_distribution.sample(seed=key1)
        next_state = self.transition_distribution.sample(seed=key2)
        return next_state, (state, emission)

    # Sample the initial state
    key1, key = jr.split(key, 2)
    initial_state = self.initial_distribution.sample(seed=key1)

    # Sample the remaining emissions and states
    keys = jr.split(key, num_timesteps)
    _, (states, emissions) = lax.scan(_step, initial_state, keys)
    return states, emissions
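As a quick smoke test, here is a hedged usage sketch (added, not from the original notebook) that draws a trajectory from the `hmm` object built above. It assumes the `sample(key, num_timesteps)` signature shown in the printed source and that one latent state and one 2-dimensional emission are returned per time step.
# Added sketch: sample from the Gaussian HMM and plot the emissions coloured by state.
# Assumed shapes: states (num_timesteps,), emissions (num_timesteps, emission_dim).
states, emissions = hmm.sample(PRNGKey(0), 200)
plt.scatter(emissions[:, 0], emissions[:, 1], c=states, s=10, cmap="tab10")
plt.axis("equal")
plt.title("200 samples from the 5-state Gaussian HMM")
plt.show()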
import distrax
from distrax import HMM
# transition matrix
A = np.array([
[0.95, 0.05],
[0.10, 0.90]
])
# observation matrix
B = np.array([
[1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
[1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
])
# initial state distribution
pi = np.array([0.5, 0.5])
(nstates, nobs) = np.shape(B)
hmm = HMM(trans_dist=distrax.Categorical(probs=A),
init_dist=distrax.Categorical(probs=pi),
obs_dist=distrax.Categorical(probs=B))
print(hmm)
<distrax._src.utils.hmm.HMM object at 0x7fde82c856d0>
print_source(hmm.sample)
def sample(self,
           *,
           seed: chex.PRNGKey,
           seq_len: chex.Array) -> Tuple:
  """Sample from this HMM.

  Samples an observation of given length according to this
  Hidden Markov Model and gives the sequence of the hidden states
  as well as the observation.

  Args:
    seed: Random key of shape (2,) and dtype uint32.
    seq_len: The length of the observation sequence.

  Returns:
    Tuple of hidden state sequence, and observation sequence.
  """
  rng_key, rng_init = jax.random.split(seed)
  initial_state = self._init_dist.sample(seed=rng_init)

  def draw_state(prev_state, key):
    state = self._trans_dist.sample(seed=key)
    return state, state

  rng_state, rng_obs = jax.random.split(rng_key)
  keys = jax.random.split(rng_state, seq_len - 1)
  _, states = jax.lax.scan(draw_state, initial_state, keys)
  states = jnp.append(initial_state, states)

  def draw_obs(state, key):
    return self._obs_dist.sample(seed=key)

  keys = jax.random.split(rng_obs, seq_len)
  obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(states, keys)

  return states, obs_seq
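For completeness, a brief hedged usage sketch (added, not in the original notebook) samples from the occasionally dishonest casino model defined above, using the keyword-only `seed` and `seq_len` arguments shown in the printed source:
# Added sketch: sample a sequence of hidden dice and observed rolls.
states, obs = hmm.sample(seed=PRNGKey(42), seq_len=300)
print(states[:20])   # 0 = fair die, 1 = loaded die
print(obs[:20])      # observed faces, encoded 0..5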