| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- #!/usr/bin/env python
- # coding: utf-8
- # In[1]:
# Notebook cell metadata: tags = ["hide-cell"].
# (Kept as a comment: as a bare dict literal it was a no-op statement.)
### Install necessary libraries
# Each dependency is imported if available; otherwise it is installed into the
# running kernel via the %pip magic and imported again. get_ipython() only
# exists inside IPython/Jupyter, so this cell must run there.
# Only ImportError is caught: a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and unrelated errors raised during import.
try:
    import jax
except ImportError:
    # For cuda version, see https://github.com/google/jax#installation
    get_ipython().run_line_magic('pip', 'install --upgrade "jax[cpu]"')
    import jax
try:
    import jsl
except ImportError:
    get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
    import jsl
try:
    import rich
except ImportError:
    get_ipython().run_line_magic('pip', 'install rich')
    import rich
- # In[2]:
# Notebook cell metadata: tags = ["hide-cell"].
# (Kept as a comment: as a bare dict literal it was a no-op statement.)
- ### Import standard libraries
- import abc
- from dataclasses import dataclass
- import functools
- import itertools
- from typing import Any, Callable, NamedTuple, Optional, Union, Tuple
- import matplotlib.pyplot as plt
- import numpy as np
- import jax
- import jax.numpy as jnp
- from jax import lax, vmap, jit, grad
- from jax.scipy.special import logit
- from jax.nn import softmax
- from functools import partial
- from jax.random import PRNGKey, split
- import inspect
- import inspect as py_inspect
- from rich import inspect as r_inspect
- from rich import print as r_print
def print_source(fname):
    """Pretty-print the source code of ``fname`` using rich's ``print``.

    Despite its name, ``fname`` is the object itself (function, class,
    method, module, ...), i.e. anything ``inspect.getsource`` accepts.
    Returns None.
    """
    source_code = py_inspect.getsource(fname)
    r_print(source_code)
- # (sec:ssm-intro)=
- # # What are State Space Models?
- #
- #
# A state space model (SSM)
# is a partially observed Markov model,
# in which the hidden state, $z_t$,
# evolves over time according to a Markov process
# and is seen only indirectly, through observations that depend on it.
- #
- #
- # ```{figure} /figures/SSM-AR-inputs.png
- # :scale: 100%
- # :name: ssm-ar
- #
- # Illustration of an SSM as a graphical model.
- # ```
- #
- # ```{figure} /figures/SSM-simplified.png
- # :scale: 100%
- # :name: ssm-simplifed
- #
- # Illustration of a simplified SSM.
- # ```
- # (sec:casino-ex)=
- # ## Example: Casino HMM
- #
# We first create the "Occasionally dishonest casino" model from {cite}`Durbin98`.
- #
- #
- #
# There are 2 hidden states (a fair die and a loaded die), each of which emits one of 6 possible observations (the die faces).
- # In[3]:
# Parameters of the "occasionally dishonest casino" HMM.
# Hidden state 0 = fair die, hidden state 1 = loaded die.

# State transition matrix: row i is the distribution of the next state given
# current state i, so each row sums to 1.
A = np.array([
    [0.95, 0.05],
    [0.10, 0.90],
])

# Observation matrix: row i is the distribution over the 6 die faces
# (columns 0..5 = faces 1..6) emitted in state i; each row sums to 1.
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6],  # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10],  # loaded die
])

(nstates, nobs) = np.shape(B)

# Initial state distribution: uniform over the hidden states.
# The previous version called `normalize(np.array([1, 1]))`, but `normalize`
# is never imported or defined before this cell (NameError). Normalizing a
# vector of ones explicitly gives the same result, [0.5, 0.5], as a NumPy array.
pi = np.ones(nstates) / nstates
- # In[ ]:
|