@@ -24,6 +24,12 @@ except:
     import jax
 
 try:
+    import distrax
+except:
+    get_ipython().run_line_magic('pip', 'install --upgrade distrax')
+    import distrax
+
+try:
     import jsl
 except:
     get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
@@ -71,7 +77,6 @@ from jax.random import PRNGKey, split
 import inspect
 import inspect as py_inspect
 import rich
-
 from rich import inspect as r_inspect
 from rich import print as r_print
 
@@ -209,7 +214,9 @@ def print_source(fname):
 #
 # %%%%
 # \newcommand{\hidden}{\vz}
-# \newcommand{\obs}{\vy}
+# \newcommand{\hid}{\hidden}
+# \newcommand{\observed}{\vy}
+# \newcommand{\obs}{\observed}
 # \newcommand{\inputs}{\vu}
 # \newcommand{\input}{\inputs}
 #
@@ -279,10 +286,6 @@ def print_source(fname):
 # Illustration of an SSM as a graphical model.
 # ```
 #
-#
-#
-#
-
 # We often consider a simpler setting in which there
 # are no external inputs,
 # and the observations are conditionally independent of each other
@@ -306,6 +309,8 @@ def print_source(fname):
 #
 # Illustration of a simplified SSM.
 # ```
+#
+#
 
 # (sec:hmm-intro)=
 # # Hidden Markov Models
@@ -325,7 +330,7 @@ def print_source(fname):
 # For an interactive introduction,
 # see https://nipunbatra.github.io/hmm/.
 
-# (sec:casino-ex)=
+# (sec:casino)=
 # ### Example: Casino HMM
 #
 # To illustrate HMMs with categorical observation model,
@@ -342,7 +347,9 @@ def print_source(fname):
 # meaning each row sums to one.
 # We can visualize
 # the non-zero entries in the transition matrix by creating a state transition diagram,
-# as shown in {ref}`casino-fig`.
+# as shown in
+# {numref}`Figure %s <casino-fig>`.
+# %{ref}`casino-fig`.
 #
 # ```{figure} /figures/casino.png
 # :scale: 50%
@@ -352,7 +359,7 @@ def print_source(fname):
 # ```
 #
 # The observation model
-# $p(\obs_t|\hiddden_t=j)$ has the form
+# $p(\obs_t|\hidden_t=j)$ has the form
 # ```{math}
 # p(\obs_t=k|\hidden_t=j) = \hmmObs_{jk}
 # ```
@@ -367,7 +374,7 @@ def print_source(fname):
 #
 # Collectively we denote all the parameters by $\vtheta=(\hmmTrans, \hmmObs, \hmmInit)$.
 #
-# Now let us implement this model code.
+# Now let us implement this model in code.
 
 # In[3]:
 
@@ -388,3 +395,230 @@ pi = np.array([0.5, 0.5])
 
 (nstates, nobs) = np.shape(B)
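+
+# Sanity check: the rows of A and B are probability distributions,
+# so each row must sum to one.
+assert np.allclose(A.sum(axis=1), 1.0)
+assert np.allclose(B.sum(axis=1), 1.0)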
+
+# In[4]:
+
+
+import distrax
+from distrax import HMM
+
+hmm = HMM(trans_dist=distrax.Categorical(probs=A),
+          init_dist=distrax.Categorical(probs=pi),
+          obs_dist=distrax.Categorical(probs=B))
+
+print(hmm)
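+
+# Note: passing a 2d array of probabilities to Categorical creates a batch of
+# distributions, one per row, so trans_dist and obs_dist each hold one
+# conditional distribution per hidden state.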
+
+#
+# Let's sample from the model. We will generate a sequence of latent states, $\hid_{1:T}$,
+# which we then convert to a sequence of observations, $\obs_{1:T}$.
+
+# In[5]:
+
+
+seed = 314
+n_samples = 300
+z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)
+
+z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60]
+x_hist_str = "".join((np.array(x_hist) + 1).astype(str))[:60]
+
+print("Printing sample observed/latent...")
+print(f"x: {x_hist_str}")
+print(f"z: {z_hist_str}")
+
+
+# In[6]:
+
+
+# Here is the source code for the sampling algorithm.
+print_source(hmm.sample)
+
+
+# Our primary goal will be to infer the latent state from the observations,
+# so we can detect if the casino is being dishonest or not. This will
+# affect how we choose to gamble our money.
+# We discuss various ways to perform this inference below.
+
+# # Linear Gaussian SSMs
+#
+# Blah blah
+
+# (sec:tracking-lds)=
+# ## Example: model for 2d tracking
+#
+# Blah blah
+
+# (sec:inference)=
+# # Inferential goals
+#
+# ```{figure} /figures/dbn-inference-problems.png
+# :scale: 100%
+# :name: dbn-inference
+#
+# Illustration of the different kinds of inference in an SSM.
+# The shaded region is the interval for which we have data.
+# The arrow represents the time step at which we want to perform inference.
+# $t$ is the current time, $T$ is the sequence length,
+# $\ell$ is the lag and $h$ is the prediction horizon.
+# ```
+#
+# Given the sequence of observations and a known model,
+# one of the main tasks with SSMs is to perform posterior inference
+# about the hidden states; this is also called
+# state estimation.
+# At each time step $t$,
+# there are multiple forms of posterior we may be interested in computing,
+# including the following:
+# - the filtering distribution
+# $p(\hmmhid_t|\hmmobs_{1:t})$
+# - the smoothing distribution
+# $p(\hmmhid_t|\hmmobs_{1:T})$ (note that this conditions on future data $T>t$)
+# - the fixed-lag smoothing distribution
+# $p(\hmmhid_{t-\ell}|\hmmobs_{1:t})$ (note that this
+# infers $\ell$ steps in the past given data up to the present).
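+#
+# For example, the filtering distribution can be computed recursively,
+# using the standard predict-update decomposition:
+# ```{math}
+# p(\hmmhid_t|\hmmobs_{1:t}) \propto p(\hmmobs_t|\hmmhid_t)
+# \sum_{\hmmhid_{t-1}} p(\hmmhid_t|\hmmhid_{t-1}) p(\hmmhid_{t-1}|\hmmobs_{1:t-1})
+# ```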
+#
+# We may also want to compute the
+# predictive distribution $h$ steps into the future:
+# \begin{align}
+# p(\hmmobs_{t+h}|\hmmobs_{1:t})
+# &= \sum_{\hmmhid_{t+h}} p(\hmmobs_{t+h}|\hmmhid_{t+h}) p(\hmmhid_{t+h}|\hmmobs_{1:t})
+# \end{align}
+# where the hidden state predictive distribution is
+# \begin{align}
+# p(\hmmhid_{t+h}|\hmmobs_{1:t})
+# &= \sum_{\hmmhid_{t:t+h-1}}
+# p(\hmmhid_t|\hmmobs_{1:t})
+# p(\hmmhid_{t+1}|\hmmhid_{t})
+# p(\hmmhid_{t+2}|\hmmhid_{t+1})
+# \cdots
+# p(\hmmhid_{t+h}|\hmmhid_{t+h-1})
+# \end{align}
+# See {numref}`Figure %s <dbn-inference>` for a summary of these distributions.
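+#
+# As a concrete sketch (an illustrative helper, not a library function), the
+# $h$-step-ahead predictive for a discrete-state model can be computed by
+# pushing the filtered belief at time $t$ through powers of the transition matrix:
+# ```python
+# def predict_obs(filtered_t, A, B, h):
+#     # p(hid_{t+h} | obs_{1:t}) = filtered belief vector times A^h
+#     pred_hid = filtered_t @ np.linalg.matrix_power(A, h)
+#     # marginalize out the hidden state to get p(obs_{t+h} | obs_{1:t})
+#     return pred_hid @ B
+# ```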
+#
+# In addition to computing posterior marginals,
+# we may want to compute the most probable hidden sequence,
+# i.e., the joint MAP estimate
+# ```{math}
+# \arg \max_{\hmmhid_{1:T}} p(\hmmhid_{1:T}|\hmmobs_{1:T})
+# ```
+# or sample sequences from the posterior
+# ```{math}
+# \hmmhid_{1:T} \sim p(\hmmhid_{1:T}|\hmmobs_{1:T})
+# ```
+#
+# Algorithms for all these tasks are discussed in the following chapters,
+# since the details depend on the form of the SSM.
+
+# ## Example: inference in the casino HMM
+#
+# We now illustrate filtering, smoothing and MAP decoding applied
+# to the casino HMM from {ref}`sec:casino`.
+
+# In[7]:
+
+
+# Call the inference engine
+filtered_dist, _, smoothed_dist, loglik = hmm.forward_backward(x_hist)
+map_path = hmm.viterbi(x_hist)
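+
+# Quick shape check: we expect one row per time step and one column per state,
+# i.e. (n_samples, nstates), for both the filtered and smoothed marginals.
+print(filtered_dist.shape, smoothed_dist.shape)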
+
+
+# In[8]:
+
+
+# Find the spans of timesteps during which the simulated system is in state 1
+def find_dishonest_intervals(z_hist):
+    spans = []
+    x_init = 0
+    for t, _ in enumerate(z_hist[:-1]):
+        if z_hist[t + 1] == 0 and z_hist[t] == 1:
+            x_end = t
+            spans.append((x_init, x_end))
+        elif z_hist[t + 1] == 1 and z_hist[t] == 0:
+            x_init = t + 1
+    return spans
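+
+
+# A quick illustrative check: in this toy sequence, state 1 is entered at
+# t=2 and left after t=4, so we expect the single span (2, 4).
+assert find_dishonest_intervals([0, 0, 1, 1, 1, 0, 0]) == [(2, 4)]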
+
+
+# In[9]:
+
+
+# Plot the posterior over the hidden state, shading the true dishonest intervals
+def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False):
+    n_samples = len(inference_values)
+    xspan = np.arange(1, n_samples + 1)
+    spans = find_dishonest_intervals(z_hist)
+    if map_estimate:
+        ax.step(xspan, inference_values, where="post")
+    else:
+        ax.plot(xspan, inference_values[:, state])
+
+    for span in spans:
+        ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none")
+    ax.set_xlim(1, n_samples)
+    # ax.set_ylim(0, 1)
+    ax.set_ylim(-0.1, 1.1)
+    ax.set_xlabel("Observation number")
+
+
+# In[10]:
+
+
+# Filtering
+fig, ax = plt.subplots()
+plot_inference(filtered_dist, z_hist, ax)
+ax.set_ylabel("p(loaded)")
+ax.set_title("Filtered")
+
+
+# In[11]:
+
+
+# Smoothing
+fig, ax = plt.subplots()
+plot_inference(smoothed_dist, z_hist, ax)
+ax.set_ylabel("p(loaded)")
+ax.set_title("Smoothed")
+
+
+# In[12]:
+
+
+# MAP estimation
+fig, ax = plt.subplots()
+plot_inference(map_path, z_hist, ax, map_estimate=True)
+ax.set_ylabel("MAP state")
+ax.set_title("Viterbi")
+
+
+# In[13]:
+
+
+# TODO: posterior samples
+
+
+# ## Example: inference in the tracking SSM
+#
+# We now illustrate filtering, smoothing and MAP decoding applied
+# to the 2d tracking SSM from {ref}`sec:tracking-lds`.
|