{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "# meta-data does not work yet in VScode\n", "# https://github.com/microsoft/vscode-jupyter/issues/1121\n", "\n", "{\n", " \"tags\": [\n", " \"hide-cell\"\n", " ]\n", "}\n", "\n", "\n", "### Install necessary libraries\n", "# This cell must run before the imports cell below, so that a fresh\n", "# environment gets the packages before they are imported.\n", "\n", "try:\n", " import jax\n", "except ImportError:\n", " # For cuda version, see https://github.com/google/jax#installation\n", " %pip install --upgrade \"jax[cpu]\" \n", " import jax\n", "\n", "try:\n", " import distrax\n", "except ImportError:\n", " %pip install --upgrade distrax\n", " import distrax\n", "\n", "try:\n", " import optax\n", "except ImportError:\n", " %pip install --upgrade optax\n", " import optax\n", "\n", "try:\n", " import jsl\n", "except ImportError:\n", " %pip install git+https://github.com/probml/jsl\n", " import jsl\n", "\n", "try:\n", " import rich\n", "except ImportError:\n", " %pip install rich\n", " import rich\n", "\n", "# TODO: ssm_jax is imported in the next cell but is not installed here.\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "### Import standard libraries\n", "\n", "import abc\n", "from dataclasses import dataclass\n", "import functools\n", "from functools import partial\n", "import itertools\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import lax, vmap, jit, grad\n", "#from jax.scipy.special import logit\n", "#from jax.nn import softmax\n", "import jax.random as jr\n", "\n", "\n", "\n", "import distrax\n", "import optax\n", "\n", "import jsl\n", "import ssm_jax\n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import inspect\n", "import inspect as py_inspect\n", "import rich\n", "from rich import inspect as r_inspect\n", "from rich import print as r_print\n", "\n", "def print_source(fname):\n", " r_print(py_inspect.getsource(fname))" ] }, { "cell_type": "code", "execution_count": 4, 
"metadata": { "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "{\n", " \"tags\": [\n", " \"hide-cell\"\n", " ]\n", "}\n", "\n", "\n", "### Import standard libraries\n", "\n", "import abc\n", "from dataclasses import dataclass\n", "import functools\n", "import itertools\n", "\n", "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import lax, vmap, jit, grad\n", "from jax.scipy.special import logit\n", "from jax.nn import softmax\n", "from functools import partial\n", "from jax.random import PRNGKey, split\n", "\n", "import inspect\n", "import inspect as py_inspect\n", "import rich\n", "from rich import inspect as r_inspect\n", "from rich import print as r_print\n", "\n", "def print_source(fname):\n", " r_print(py_inspect.getsource(fname))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
class GaussianHMM(BaseHMM):\n", " def __init__(self,\n", " initial_probabilities,\n", " transition_matrix,\n", " emission_means,\n", " emission_covariance_matrices):\n", " \"\"\"_summary_\n", "\n", " Args:\n", " initial_probabilities (_type_): _description_\n", " transition_matrix (_type_): _description_\n", " emission_means (_type_): _description_\n", " emission_covariance_matrices (_type_): _description_\n", " \"\"\"\n", " super().__init__(initial_probabilities,\n", " transition_matrix)\n", "\n", " self._emission_distribution = tfd.MultivariateNormalFullCovariance(\n", " emission_means, emission_covariance_matrices)\n", "\n", " @classmethod\n", " def random_initialization(cls, key, num_states, emission_dim):\n", " key1, key2, key3 = jr.split(key, 3)\n", " initial_probs = jr.dirichlet(key1, jnp.ones(num_states))\n", " transition_matrix = jr.dirichlet(key2, jnp.ones(num_states), (num_states,))\n", " emission_means = jr.normal(key3, (num_states, emission_dim))\n", " emission_covs = jnp.tile(jnp.eye(emission_dim), (num_states, 1, 1))\n", " return cls(initial_probs, transition_matrix, emission_means, emission_covs)\n", "\n", " # Properties to get various parameters of the model\n", " @property\n", " def emission_distribution(self):\n", " return self._emission_distribution\n", "\n", " @property\n", " def emission_means(self):\n", " return self.emission_distribution.mean()\n", "\n", " @property\n", " def emission_covariance_matrices(self):\n", " return self.emission_distribution.covariance()\n", "\n", " @property\n", " def unconstrained_params(self):\n", " \"\"\"Helper property to get a PyTree of unconstrained parameters.\n", " \"\"\"\n", " return tfb.SoftmaxCentered().inverse(self.initial_probabilities), \\\n", " tfb.SoftmaxCentered().inverse(self.transition_matrix), \\\n", " self.emission_means, \\\n", " PSDToRealBijector.forward(self.emission_covariance_matrices)\n", "\n", " @classmethod\n", " def from_unconstrained_params(cls, unconstrained_params, hypers):\n", " 
initial_probabilities = tfb.SoftmaxCentered().forward(unconstrained_params[0])\n", " transition_matrix = tfb.SoftmaxCentered().forward(unconstrained_params[1])\n", " emission_means = unconstrained_params[2]\n", " emission_covs = PSDToRealBijector.inverse(unconstrained_params[3])\n", " return cls(initial_probabilities, transition_matrix, emission_means, emission_covs, \n", "*hypers)\n", "\n", "\n" ], "text/plain": [ "
def sample(self, key, num_timesteps):\n", " \"\"\"Sample a sequence of latent states and emissions.\n", "\n", " Args:\n", " key (_type_): _description_\n", " num_timesteps (_type_): _description_\n", " \"\"\"\n", " def _step(state, key):\n", " key1, key2 = jr.split(key, 2)\n", " emission = self.emission_distribution.sample(seed=key1)\n", " next_state = self.transition_distribution.sample(seed=key2)\n", " return next_state, (state, emission)\n", "\n", " # Sample the initial state\n", " key1, key = jr.split(key, 2)\n", " initial_state = self.initial_distribution.sample(seed=key1)\n", "\n", " # Sample the remaining emissions and states\n", " keys = jr.split(key, num_timesteps)\n", " _, (states, emissions) = lax.scan(_step, initial_state, keys)\n", " return states, emissions\n", "\n", "\n" ], "text/plain": [ "
def sample(self,\n", " *,\n", " seed: chex.PRNGKey,\n", " seq_len: chex.Array) -> Tuple:\n", " \"\"\"Sample from this HMM.\n", "\n", " Samples an observation of given length according to this\n", " Hidden Markov Model and gives the sequence of the hidden states\n", " as well as the observation.\n", "\n", " Args:\n", " seed: Random key of shape (2,) and dtype uint32.\n", " seq_len: The length of the observation sequence.\n", "\n", " Returns:\n", " Tuple of hidden state sequence, and observation sequence.\n", " \"\"\"\n", " rng_key, rng_init = jax.random.split(seed)\n", " initial_state = self._init_dist.sample(seed=rng_init)\n", "\n", " def draw_state(prev_state, key):\n", " state = self._trans_dist.sample(seed=key)\n", " return state, state\n", "\n", " rng_state, rng_obs = jax.random.split(rng_key)\n", " keys = jax.random.split(rng_state, seq_len - 1)\n", " _, states = jax.lax.scan(draw_state, initial_state, keys)\n", " states = jnp.append(initial_state, states)\n", "\n", " def draw_obs(state, key):\n", " return self._obs_dist.sample(seed=key)\n", "\n", " keys = jax.random.split(rng_obs, seq_len)\n", " obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(states, keys)\n", "\n", " return states, obs_seq\n", "\n", "\n" ], "text/plain": [ "