Kevin P Murphy 3 years ago
parent
commit
0352d267e0

BIN
_build/.doctrees/chapters/ssm/ssm.doctree


BIN
_build/.doctrees/environment.pickle


BIN
_build/html/_images/dbn-inference-problems.png


BIN
_build/html/_images/ssm_18_1.png


BIN
_build/html/_images/ssm_19_1.png


BIN
_build/html/_images/ssm_20_1.png


BIN
_build/html/_images/ssm_21_1.png


File diff suppressed because it is too large
+ 509 - 118
_build/html/_sources/chapters/ssm/ssm.ipynb


+ 1 - 1
_build/html/chapters/hmm/hmm.html

@@ -646,7 +646,7 @@ const thebe_selector_output = ".output, .cell_output"
 <p>We first create the “Ocassionally dishonest casino” model from <span id="id1">[<a class="reference internal" href="../../bib.html#id3" title="R. Durbin, S. Eddy, A. Krogh, and G. Mitchison. Biological Sequence Analysis: Probabilistic Models of Proteins and Nucleic Acids. Cambridge University Press, 1998.">DEKM98</a>]</span>.</p>
 <div class="figure align-default" id="casino-fig">
 <a class="reference internal image-reference" href="../../_images/casino.png"><img alt="../../_images/casino.png" src="../../_images/casino.png" style="width: 208.5px; height: 142.5px;" /></a>
-<p class="caption"><span class="caption-number">Fig. 6 </span><span class="caption-text">Illustration of the casino HMM.</span><a class="headerlink" href="#casino-fig" title="Permalink to this image">¶</a></p>
+<p class="caption"><span class="caption-number">Fig. 7 </span><span class="caption-text">Illustration of the casino HMM.</span><a class="headerlink" href="#casino-fig" title="Permalink to this image">¶</a></p>
 </div>
 <p>There are 2 hidden states, each of which emit 6 possible observations.</p>
 <div class="cell docutils container">

File diff suppressed because it is too large
+ 367 - 5
_build/html/chapters/ssm/ssm.html


BIN
_build/html/objects.inv


+ 14 - 26
_build/html/reports/ssm.log

@@ -17,35 +17,23 @@ Traceback (most recent call last):
     raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
 nbclient.exceptions.CellExecutionError: An error occurred while executing the following cell:
 ------------------
-# state transition matrix
-A = np.array([
-    [0.95, 0.05],
-    [0.10, 0.90]
-])
-
-# observation matrix
-B = np.array([
-    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
-    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
-])
-
-pi, _ = normalize(np.array([1, 1]))
-pi = np.array(pi)
-
-
-(nstates, nobs) = np.shape(B)
+# MAP estimation
+fig, ax = plt.subplots()
+plot_inference(z_map, z_hist, ax, map_estimate=True)
+ax.set_ylabel("MAP state")
+ax.set_title("Viterbi")
 
 ------------------
 
 ---------------------------------------------------------------------------
 NameError                                 Traceback (most recent call last)
-<ipython-input-3-2f308bef5393> in <module>
-     11 ])
-     12 
----> 13 pi, _ = normalize(np.array([1, 1]))
-     14 pi = np.array(pi)
-     15 
-
-NameError: name 'normalize' is not defined
-NameError: name 'normalize' is not defined
+<ipython-input-12-d20416120056> in <module>
+      1 # MAP estimation
+      2 fig, ax = plt.subplots()
+----> 3 plot_inference(z_map, z_hist, ax, map_estimate=True)
+      4 ax.set_ylabel("MAP state")
+      5 ax.set_title("Viterbi")
+
+NameError: name 'z_map' is not defined
+NameError: name 'z_map' is not defined
 

+ 2 - 0
_build/html/root.html

@@ -448,6 +448,8 @@ in automatic differentiation and parallel computing.</p>
 <li class="toctree-l1"><a class="reference internal" href="chapters/scratch.html">Scratchpad</a></li>
 <li class="toctree-l1"><a class="reference internal" href="chapters/ssm/ssm.html">What are State Space Models?</a></li>
 <li class="toctree-l1"><a class="reference internal" href="chapters/ssm/ssm.html#hidden-markov-models">Hidden Markov Models</a></li>
+<li class="toctree-l1"><a class="reference internal" href="chapters/ssm/ssm.html#linear-gaussian-ssms">Linear Gaussian SSMs</a></li>
+<li class="toctree-l1"><a class="reference internal" href="chapters/ssm/ssm.html#inferential-goals">Inferential goals</a></li>
 <li class="toctree-l1"><a class="reference internal" href="chapters/hmm/hmm_index.html">Inference in discrete SSMs</a><ul>
 <li class="toctree-l2"><a class="reference internal" href="chapters/hmm/hmm.html">Hidden Markov Models</a></li>
 <li class="toctree-l2"><a class="reference internal" href="chapters/hmm/hmm_filter.html">HMM filtering (forwards algorithm)</a></li>

File diff suppressed because it is too large
+ 1 - 1
_build/html/searchindex.js


File diff suppressed because it is too large
+ 469 - 15
_build/jupyter_execute/chapters/ssm/ssm.ipynb


+ 244 - 10
_build/jupyter_execute/chapters/ssm/ssm.py

@@ -24,6 +24,12 @@ except:
     import jax
 
 try:
+    import distrax
+except:
+    get_ipython().run_line_magic('pip', 'install --upgrade  distrax')
+    import distrax
+
+try:
     import jsl
 except:
     get_ipython().run_line_magic('pip', 'install git+https://github.com/probml/jsl')
@@ -71,7 +77,6 @@ from jax.random import PRNGKey, split
 import inspect
 import inspect as py_inspect
 import rich
-
 from rich import inspect as r_inspect
 from rich import print as r_print
 
@@ -209,7 +214,9 @@ def print_source(fname):
 # 
 # %%%%
 # \newcommand{\hidden}{\vz}
-# \newcommand{\obs}{\vy}
+# \newcommand{\hid}{\hidden}
+# \newcommand{\observed}{\vy}
+# \newcommand{\obs}{\observed}
 # \newcommand{\inputs}{\vu}
 # \newcommand{\input}{\inputs}
 # 
@@ -279,10 +286,6 @@ def print_source(fname):
 # Illustration of an SSM as a graphical model.
 # ```
 # 
-# 
-# 
-# 
-
 # We often consider a simpler setting in which there
 # are no external inputs,
 # and the observations are conditionally independent of each other
@@ -306,6 +309,8 @@ def print_source(fname):
 # 
 # Illustration of a simplified SSM.
 # ```
+# 
+# 
 
 # (sec:hmm-intro)=
 # # Hidden Markov Models
@@ -325,7 +330,7 @@ def print_source(fname):
 # For an interactive introduction,
 # see https://nipunbatra.github.io/hmm/.
 
-# (sec:casino-ex)=
+# (sec:casino)=
 # ### Example: Casino HMM
 # 
 # To illustrate HMMs with categorical observation model,
@@ -342,7 +347,9 @@ def print_source(fname):
 # meaning each row sums to one.
 # We can visualize
 # the non-zero entries in the transition matrix by creating a state transition diagram,
-# as shown in {ref}`casino-fig`.
+# as shown in 
+# {numref}`Figure %s <casino-fig>`
+# %{ref}`casino-fig`.
 # 
 # ```{figure} /figures/casino.png
 # :scale: 50%
@@ -352,7 +359,7 @@ def print_source(fname):
 # ```
 # 
 # The  observation model
-# $p(\obs_t|\hiddden_t=j)$ has the form
+# $p(\obs_t|\hidden_t=j)$ has the form
 # ```{math}
 # p(\obs_t=k|\hidden_t=j) = \hmmObs_{jk} 
 # ```
@@ -367,7 +374,7 @@ def print_source(fname):
 # 
 # Collectively we denote all the parameters by $\vtheta=(\hmmTrans, \hmmObs, \hmmInit)$.
 # 
-# Now let us implement this model code.
+# Now let us implement this model in code.
 
 # In[3]:
 
@@ -388,3 +395,230 @@ pi = np.array([0.5, 0.5])
 
 (nstates, nobs) = np.shape(B)
 
+
+# In[4]:
+
+
+import distrax
+from distrax import HMM
+
+
+hmm = HMM(trans_dist=distrax.Categorical(probs=A),
+            init_dist=distrax.Categorical(probs=pi),
+            obs_dist=distrax.Categorical(probs=B))
+
+print(hmm)
+
+
+# 
+# Let's sample from the model. We will generate a sequence of latent states, $\hid_{1:T}$,
+# which we then convert to a sequence of observations, $\obs_{1:T}$.
+
+# In[5]:
+
+
+
+
+
+seed = 314
+n_samples = 300
+z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)
+
+z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60]
+x_hist_str = "".join((np.array(x_hist) + 1).astype(str))[:60]
+
+print("Printing sample observed/latent...")
+print(f"x: {x_hist_str}")
+print(f"z: {z_hist_str}")
+
+
+# In[6]:
+
+
+# Here is the source code for the sampling algorithm.
+
+print_source(hmm.sample)
+
+
+# Our primary goal will be to infer the latent state from the observations,
+# so we can detect if the casino is being dishonest or not. This will
+# affect how we choose to gamble our money.
+# We discuss various ways to perform this inference below.
+
+# # Linear Gaussian SSMs
+# 
+# Blah blah
+
+# (sec:tracking-lds)=
+# ## Example: model for 2d tracking
+# 
+# Blah blah
+
+# (sec:inference)=
+# # Inferential goals
+# 
+# ```{figure} /figures/dbn-inference-problems.png
+# :scale: 100%
+# :name: dbn-inference
+# 
+# Illustration of the different kinds of inference in an SSM.
+#  The main kinds of inference for state-space models.
+#     The shaded region is the interval for which we have data.
+#     The arrow represents the time step at which we want to perform inference.
+#     $t$ is the current time,  $T$ is the sequence length,
+# $\ell$ is the lag and $h$ is the prediction horizon.
+# ```
+# 
+# 
+# 
+# Given the sequence of observations, and a known model,
+# one of the main tasks with SSMs
+# to perform posterior inference,
+# about the hidden states; this is also called
+# state estimation.
+# At each time step $t$,
+# there are multiple forms of posterior we may be interested in computing,
+# including the following:
+# - the filtering distribution
+# $p(\hmmhid_t|\hmmobs_{1:t})$
+# - the smoothing distribution
+# $p(\hmmhid_t|\hmmobs_{1:T})$ (note that this conditions on future data $T>t$)
+# - the fixed-lag smoothing distribution
+# $p(\hmmhid_{t-\ell}|\hmmobs_{1:t})$ (note that this
+# infers $\ell$ steps in the past given data up to the present).
+# 
+# We may also want to compute the
+# predictive distribution $h$ steps into the future:
+# \begin{align}
+# p(\hmmobs_{t+h}|\hmmobs_{1:t})
+# &= \sum_{\hmmhid_{t+h}} p(\hmmobs_{t+h}|\hmmhid_{t+h}) p(\hmmhid_{t+h}|\hmmobs_{1:t})
+# \end{align}
+# where the hidden state predictive distribution is
+# \begin{align}
+# p(\hmmhid_{t+h}|\hmmobs_{1:t})
+# &= \sum_{\hmmhid_{t:t+h-1}}
+#  p(\hmmhid_t|\hmmobs_{1:t}) 
+#  p(\hmmhid_{t+1}|\hmmhid_{t})
+#  p(\hmmhid_{t+2}|\hmmhid_{t+1})
+# \cdots
+#  p(\hmmhid_{t+h}|\hmmhid_{t+h-1})
+# \end{align}
+# See \cref{fig:dbn_inf_problems} for a summary of these distributions.
+# 
+# In addition  to comuting posterior marginals,
+# we may want to compute the most probable hidden sequence,
+# i.e., the joint MAP estimate
+# ```{math}
+# \arg \max_{\hmmhid_{1:T}} p(\hmmhid_{1:T}|\hmmobs_{1:T})
+# ```
+# or sample sequences from the posterior
+# ```{math}
+# \hmmhid_{1:T} \sim p(\hmmhid_{1:T}|\hmmobs_{1:T})
+# ```
+# 
+# Algorithms for all these task are discussed in the following chapters,
+# since the details depend on the form of the SSM.
+# 
+# 
+# 
+# 
+# 
+
+# ## Example: inference in the casino HMM
+# 
+# We now illustrate filtering, smoothing and MAP decoding applied
+# to the casino HMM from {ref}`sec:casino`. 
+# 
+
+# In[7]:
+
+
+# Call inference engine
+
+filtered_dist, _, smoothed_dist, loglik = hmm.forward_backward(x_hist)
+map_path = hmm.viterbi(x_hist)
+
+
+# In[8]:
+
+
+# Find the span of timesteps that the    simulated systems turns to be in state 1
+def find_dishonest_intervals(z_hist):
+    spans = []
+    x_init = 0
+    for t, _ in enumerate(z_hist[:-1]):
+        if z_hist[t + 1] == 0 and z_hist[t] == 1:
+            x_end = t
+            spans.append((x_init, x_end))
+        elif z_hist[t + 1] == 1 and z_hist[t] == 0:
+            x_init = t + 1
+    return spans
+
+
+# In[9]:
+
+
+# Plot posterior
+def plot_inference(inference_values, z_hist, ax, state=1, map_estimate=False):
+    n_samples = len(inference_values)
+    xspan = np.arange(1, n_samples + 1)
+    spans = find_dishonest_intervals(z_hist)
+    if map_estimate:
+        ax.step(xspan, inference_values, where="post")
+    else:
+        ax.plot(xspan, inference_values[:, state])
+
+    for span in spans:
+        ax.axvspan(*span, alpha=0.5, facecolor="tab:gray", edgecolor="none")
+    ax.set_xlim(1, n_samples)
+    # ax.set_ylim(0, 1)
+    ax.set_ylim(-0.1, 1.1)
+    ax.set_xlabel("Observation number")
+
+
+# In[10]:
+
+
+# Filtering
+fig, ax = plt.subplots()
+plot_inference(filtered_dist, z_hist, ax)
+ax.set_ylabel("p(loaded)")
+ax.set_title("Filtered")
+
+
+ 
+
+
+# In[11]:
+
+
+# Smoothing
+fig, ax = plt.subplots()
+plot_inference(smoothed_dist, z_hist, ax)
+ax.set_ylabel("p(loaded)")
+ax.set_title("Smoothed")
+
+
+ 
+
+
+# In[12]:
+
+
+# MAP estimation
+fig, ax = plt.subplots()
+plot_inference(map_path, z_hist, ax, map_estimate=True)
+ax.set_ylabel("MAP state")
+ax.set_title("Viterbi")
+
+
+# In[13]:
+
+
+# TODO: posterior samples
+
+
+# ## Example: inference in the tracking SSM
+# 
+# We now illustrate filtering, smoothing and MAP decoding applied
+# to the 2d tracking HMM from {ref}`sec:tracking-lds`. 

BIN
_build/jupyter_execute/chapters/ssm/ssm_18_1.png


BIN
_build/jupyter_execute/chapters/ssm/ssm_19_1.png


BIN
_build/jupyter_execute/chapters/ssm/ssm_20_1.png


BIN
_build/jupyter_execute/chapters/ssm/ssm_21_1.png


File diff suppressed because it is too large
+ 371 - 18
chapters/ssm/ssm.ipynb


BIN
figures/dbn-inference-problems.png