|
- <!DOCTYPE html>
- <html>
- <head>
- <meta charset="utf-8" />
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
- <title>What are State Space Models? — State Space Models: A Modern Approach</title>
-
- <link href="../_static/css/theme.css" rel="stylesheet">
- <link href="../_static/css/index.ff1ffe594081f20da1ef19478df9384b.css" rel="stylesheet">
-
- <link rel="stylesheet"
- href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
- <link rel="preload" as="font" type="font/woff2" crossorigin
- href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
- <link rel="preload" as="font" type="font/woff2" crossorigin
- href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
-
-
-
- <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
- <link rel="stylesheet" type="text/css" href="../_static/sphinx-book-theme.css?digest=c3fdc42140077d1ad13ad2f1588a4309" />
- <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
- <link rel="stylesheet" type="text/css" href="../_static/copybutton.css" />
- <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
- <link rel="stylesheet" type="text/css" href="../_static/sphinx-thebe.css" />
- <link rel="stylesheet" type="text/css" href="../_static/panels-main.c949a650a448cc0ae9fd3441c0e17fb0.css" />
- <link rel="stylesheet" type="text/css" href="../_static/panels-variables.06eb56fa6e07937060861dad626602ad.css" />
-
- <link rel="preload" as="script" href="../_static/js/index.be7d3bbb2ef33a8344ce.js">
- <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
- <script src="../_static/jquery.js"></script>
- <script src="../_static/underscore.js"></script>
- <script src="../_static/doctools.js"></script>
- <script src="../_static/clipboard.min.js"></script>
- <script src="../_static/copybutton.js"></script>
- <script>let toggleHintShow = 'Click to show';</script>
- <script>let toggleHintHide = 'Click to hide';</script>
- <script>let toggleOpenOnPrint = 'true';</script>
- <script src="../_static/togglebutton.js"></script>
- <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
- <script src="../_static/sphinx-book-theme.d59cb220de22ca1c485ebbdc042f0030.js"></script>
- <script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"
- const thebe_selector = ".thebe,.cell"
- const thebe_selector_input = "pre"
- const thebe_selector_output = ".output, .cell_output"
- </script>
- <script async="async" src="../_static/sphinx-thebe.js"></script>
- <script>window.MathJax = {"options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
- <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
- <link rel="index" title="Index" href="../genindex.html" />
- <link rel="search" title="Search" href="../search.html" />
- <meta name="viewport" content="width=device-width, initial-scale=1" />
- <meta name="docsearch:language" content="None">
-
- <!-- Google Analytics -->
-
- </head>
- <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80">
-
- <div class="container-fluid" id="banner"></div>
-
- <div class="container-xl">
- <div class="row">
-
-
-
- <main class="col py-md-3 pl-md-4 bd-content overflow-auto" role="main">
-
- <div class="topbar container-xl fixed-top">
- <div class="topbar-contents row">
- <div class="col-12 col-md-3 bd-topbar-whitespace site-navigation show"></div>
- <div class="col pl-md-4 topbar-main">
-
- </div>
- <!-- Table of contents -->
- <div class="d-none d-md-block col-md-2 bd-toc show noprint">
-
- <div class="tocsection onthispage pt-5 pb-3">
- <i class="fas fa-list"></i> Contents
- </div>
- <nav id="bd-toc-nav" aria-label="Page">
- <ul class="visible nav section-nav flex-column">
- <li class="toc-h1 nav-item toc-entry">
- <a class="reference internal nav-link" href="#">
- What are State Space Models?
- </a>
- </li>
- <li class="toc-h1 nav-item toc-entry">
- <a class="reference internal nav-link" href="#hidden-markov-models">
- Hidden Markov Models
- </a>
- <ul class="visible nav section-nav flex-column">
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#example-casino-hmm">
- Example: Casino HMM
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#example-lillypad-hmm">
- Example: Lillypad HMM
- </a>
- </li>
- </ul>
- </li>
- <li class="toc-h1 nav-item toc-entry">
- <a class="reference internal nav-link" href="#linear-gaussian-ssms">
- Linear Gaussian SSMs
- </a>
- <ul class="visible nav section-nav flex-column">
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#example-tracking-a-2d-point">
- Example: tracking a 2d point
- </a>
- </li>
- </ul>
- </li>
- <li class="toc-h1 nav-item toc-entry">
- <a class="reference internal nav-link" href="#nonlinear-gaussian-ssms">
- Nonlinear Gaussian SSMs
- </a>
- <ul class="visible nav section-nav flex-column">
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#example-tracking-a-1d-pendulum">
- Example: tracking a 1d pendulum
- </a>
- </li>
- </ul>
- </li>
- <li class="toc-h1 nav-item toc-entry">
- <a class="reference internal nav-link" href="#inferential-goals">
- Inferential goals
- </a>
- <ul class="visible nav section-nav flex-column">
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#example-inference-in-the-casino-hmm">
- Example: inference in the casino HMM
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#example-inference-in-the-tracking-ssm">
- Example: inference in the tracking SSM
- </a>
- </li>
- </ul>
- </li>
- </ul>
- </nav>
- </div>
- </div>
- </div>
- <div id="main-content" class="row">
- <div class="col-12 col-md-9 pl-md-3 pr-md-0">
-
- <div>
-
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># meta-data does not work yet in VScode</span>
- <span class="c1"># https://github.com/microsoft/vscode-jupyter/issues/1121</span>
- <span class="p">{</span>
- <span class="s2">"tags"</span><span class="p">:</span> <span class="p">[</span>
- <span class="s2">"hide-cell"</span>
- <span class="p">]</span>
- <span class="p">}</span>
- <span class="c1">### Install necessary libraries</span>
- <span class="k">try</span><span class="p">:</span>
-     <span class="kn">import</span> <span class="nn">jax</span>
- <span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
-     <span class="c1"># For cuda version, see https://github.com/google/jax#installation</span>
-     <span class="o">%</span><span class="k">pip</span> install --upgrade "jax[cpu]"
-     <span class="kn">import</span> <span class="nn">jax</span>
- <span class="k">try</span><span class="p">:</span>
-     <span class="kn">import</span> <span class="nn">distrax</span>
- <span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
-     <span class="o">%</span><span class="k">pip</span> install --upgrade distrax
-     <span class="kn">import</span> <span class="nn">distrax</span>
- <span class="k">try</span><span class="p">:</span>
-     <span class="kn">import</span> <span class="nn">jsl</span>
- <span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
-     <span class="o">%</span><span class="k">pip</span> install git+https://github.com/probml/jsl
-     <span class="kn">import</span> <span class="nn">jsl</span>
- <span class="k">try</span><span class="p">:</span>
-     <span class="kn">import</span> <span class="nn">rich</span>
- <span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
-     <span class="o">%</span><span class="k">pip</span> install rich
-     <span class="kn">import</span> <span class="nn">rich</span>
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="p">{</span>
- <span class="s2">"tags"</span><span class="p">:</span> <span class="p">[</span>
- <span class="s2">"hide-cell"</span>
- <span class="p">]</span>
- <span class="p">}</span>
- <span class="c1">### Import standard libraries</span>
- <span class="kn">import</span> <span class="nn">abc</span>
- <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
- <span class="kn">import</span> <span class="nn">functools</span>
- <span class="kn">import</span> <span class="nn">itertools</span>
- <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
- <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
- <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
- <span class="kn">import</span> <span class="nn">jax</span>
- <span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
- <span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">lax</span><span class="p">,</span> <span class="n">vmap</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">grad</span>
- <span class="kn">from</span> <span class="nn">jax.scipy.special</span> <span class="kn">import</span> <span class="n">logit</span>
- <span class="kn">from</span> <span class="nn">jax.nn</span> <span class="kn">import</span> <span class="n">softmax</span>
- <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
- <span class="kn">from</span> <span class="nn">jax.random</span> <span class="kn">import</span> <span class="n">PRNGKey</span><span class="p">,</span> <span class="n">split</span>
- <span class="kn">import</span> <span class="nn">inspect</span>
- <span class="kn">import</span> <span class="nn">inspect</span> <span class="k">as</span> <span class="nn">py_inspect</span>
- <span class="kn">import</span> <span class="nn">rich</span>
- <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="n">inspect</span> <span class="k">as</span> <span class="n">r_inspect</span>
- <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="nb">print</span> <span class="k">as</span> <span class="n">r_print</span>
- <span class="k">def</span> <span class="nf">print_source</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span>
-     <span class="n">r_print</span><span class="p">(</span><span class="n">py_inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">fname</span><span class="p">))</span>
- </pre></div>
- </div>
- </div>
- </div>
- <div class="math notranslate nohighlight">
- \[ \begin{align}\begin{aligned}\newcommand\floor[1]{\lfloor#1\rfloor}\\\newcommand{\real}{\mathbb{R}}\\% Numbers
- \newcommand{\vzero}{\boldsymbol{0}}
- \newcommand{\vone}{\boldsymbol{1}}\\% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/
- \newcommand{\valpha}{\boldsymbol{\alpha}}
- \newcommand{\vbeta}{\boldsymbol{\beta}}
- \newcommand{\vchi}{\boldsymbol{\chi}}
- \newcommand{\vdelta}{\boldsymbol{\delta}}
- \newcommand{\vDelta}{\boldsymbol{\Delta}}
- \newcommand{\vepsilon}{\boldsymbol{\epsilon}}
- \newcommand{\vzeta}{\boldsymbol{\zeta}}
- \newcommand{\vXi}{\boldsymbol{\Xi}}
- \newcommand{\vell}{\boldsymbol{\ell}}
- \newcommand{\veta}{\boldsymbol{\eta}}
- %\newcommand{\vEta}{\boldsymbol{\Eta}}
- \newcommand{\vgamma}{\boldsymbol{\gamma}}
- \newcommand{\vGamma}{\boldsymbol{\Gamma}}
- \newcommand{\vmu}{\boldsymbol{\mu}}
- \newcommand{\vmut}{\boldsymbol{\tilde{\mu}}}
- \newcommand{\vnu}{\boldsymbol{\nu}}
- \newcommand{\vkappa}{\boldsymbol{\kappa}}
- \newcommand{\vlambda}{\boldsymbol{\lambda}}
- \newcommand{\vLambda}{\boldsymbol{\Lambda}}
- \newcommand{\vLambdaBar}{\overline{\vLambda}}
- %\newcommand{\vnu}{\boldsymbol{\nu}}
- \newcommand{\vomega}{\boldsymbol{\omega}}
- \newcommand{\vOmega}{\boldsymbol{\Omega}}
- \newcommand{\vphi}{\boldsymbol{\phi}}
- \newcommand{\vvarphi}{\boldsymbol{\varphi}}
- \newcommand{\vPhi}{\boldsymbol{\Phi}}
- \newcommand{\vpi}{\boldsymbol{\pi}}
- \newcommand{\vPi}{\boldsymbol{\Pi}}
- \newcommand{\vpsi}{\boldsymbol{\psi}}
- \newcommand{\vPsi}{\boldsymbol{\Psi}}
- \newcommand{\vrho}{\boldsymbol{\rho}}
- \newcommand{\vtheta}{\boldsymbol{\theta}}
- \newcommand{\vthetat}{\boldsymbol{\tilde{\theta}}}
- \newcommand{\vTheta}{\boldsymbol{\Theta}}
- \newcommand{\vsigma}{\boldsymbol{\sigma}}
- \newcommand{\vSigma}{\boldsymbol{\Sigma}}
- \newcommand{\vSigmat}{\boldsymbol{\tilde{\Sigma}}}
- \newcommand{\vsigmoid}{\vsigma}
- \newcommand{\vtau}{\boldsymbol{\tau}}
- \newcommand{\vxi}{\boldsymbol{\xi}}\\
- % Lower Roman (Vectors)
- \newcommand{\va}{\mathbf{a}}
- \newcommand{\vb}{\mathbf{b}}
- \newcommand{\vBt}{\mathbf{\tilde{B}}}
- \newcommand{\vc}{\mathbf{c}}
- \newcommand{\vct}{\mathbf{\tilde{c}}}
- \newcommand{\vd}{\mathbf{d}}
- \newcommand{\ve}{\mathbf{e}}
- \newcommand{\vf}{\mathbf{f}}
- \newcommand{\vg}{\mathbf{g}}
- \newcommand{\vh}{\mathbf{h}}
- %\newcommand{\myvh}{\mathbf{h}}
- \newcommand{\vi}{\mathbf{i}}
- \newcommand{\vj}{\mathbf{j}}
- \newcommand{\vk}{\mathbf{k}}
- \newcommand{\vl}{\mathbf{l}}
- \newcommand{\vm}{\mathbf{m}}
- \newcommand{\vn}{\mathbf{n}}
- \newcommand{\vo}{\mathbf{o}}
- \newcommand{\vp}{\mathbf{p}}
- \newcommand{\vq}{\mathbf{q}}
- \newcommand{\vr}{\mathbf{r}}
- \newcommand{\vs}{\mathbf{s}}
- \newcommand{\vt}{\mathbf{t}}
- \newcommand{\vu}{\mathbf{u}}
- \newcommand{\vv}{\mathbf{v}}
- \newcommand{\vw}{\mathbf{w}}
- \newcommand{\vws}{\vw_s}
- \newcommand{\vwt}{\mathbf{\tilde{w}}}
- \newcommand{\vWt}{\mathbf{\tilde{W}}}
- \newcommand{\vwh}{\hat{\vw}}
- \newcommand{\vx}{\mathbf{x}}
- %\newcommand{\vx}{\mathbf{x}}
- \newcommand{\vxt}{\mathbf{\tilde{x}}}
- \newcommand{\vy}{\mathbf{y}}
- \newcommand{\vyt}{\mathbf{\tilde{y}}}
- \newcommand{\vz}{\mathbf{z}}
- %\newcommand{\vzt}{\mathbf{\tilde{z}}}\\
- % Upper Roman (Matrices)
- \newcommand{\vA}{\mathbf{A}}
- \newcommand{\vB}{\mathbf{B}}
- \newcommand{\vC}{\mathbf{C}}
- \newcommand{\vD}{\mathbf{D}}
- \newcommand{\vE}{\mathbf{E}}
- \newcommand{\vF}{\mathbf{F}}
- \newcommand{\vG}{\mathbf{G}}
- \newcommand{\vH}{\mathbf{H}}
- \newcommand{\vI}{\mathbf{I}}
- \newcommand{\vJ}{\mathbf{J}}
- \newcommand{\vK}{\mathbf{K}}
- \newcommand{\vL}{\mathbf{L}}
- \newcommand{\vM}{\mathbf{M}}
- \newcommand{\vMt}{\mathbf{\tilde{M}}}
- \newcommand{\vN}{\mathbf{N}}
- \newcommand{\vO}{\mathbf{O}}
- \newcommand{\vP}{\mathbf{P}}
- \newcommand{\vQ}{\mathbf{Q}}
- \newcommand{\vR}{\mathbf{R}}
- \newcommand{\vS}{\mathbf{S}}
- \newcommand{\vT}{\mathbf{T}}
- \newcommand{\vU}{\mathbf{U}}
- \newcommand{\vV}{\mathbf{V}}
- \newcommand{\vW}{\mathbf{W}}
- \newcommand{\vX}{\mathbf{X}}
- %\newcommand{\vXs}{\vX_{\vs}}
- \newcommand{\vXs}{\vX_{s}}
- \newcommand{\vXt}{\mathbf{\tilde{X}}}
- \newcommand{\vY}{\mathbf{Y}}
- \newcommand{\vZ}{\mathbf{Z}}
- \newcommand{\vZt}{\mathbf{\tilde{Z}}}
- \newcommand{\vzt}{\mathbf{\tilde{z}}}\\
- %%%%
- \newcommand{\hidden}{\vz}
- \newcommand{\hid}{\hidden}
- \newcommand{\observed}{\vy}
- \newcommand{\obs}{\observed}
- \newcommand{\inputs}{\vu}
- \newcommand{\input}{\inputs}\\\newcommand{\hmmTrans}{\vA}
- \newcommand{\hmmObs}{\vB}
- \newcommand{\hmmInit}{\vpi}
- \newcommand{\hmmhid}{\hidden}
- \newcommand{\hmmobs}{\obs}\\\newcommand{\ldsDyn}{\vA}
- \newcommand{\ldsObs}{\vC}
- \newcommand{\ldsDynIn}{\vB}
- \newcommand{\ldsObsIn}{\vD}
- \newcommand{\ldsDynNoise}{\vQ}
- \newcommand{\ldsObsNoise}{\vR}\\\newcommand{\ssmDynFn}{f}
- \newcommand{\ssmObsFn}{h}\\
- %%%
- \newcommand{\gauss}{\mathcal{N}}\\\newcommand{\diag}{\mathrm{diag}}\end{aligned}\end{align} \]</div>
- <div class="tex2jax_ignore mathjax_ignore section" id="what-are-state-space-models">
- <span id="sec-ssm-intro"></span><h1>What are State Space Models?<a class="headerlink" href="#what-are-state-space-models" title="Permalink to this headline">¶</a></h1>
- <p>A state space model (SSM)
- is a partially observed Markov model,
- in which the hidden state, <span class="math notranslate nohighlight">\(\hidden_t\)</span>,
- evolves over time according to a Markov process,
- possibly conditional on external inputs or controls <span class="math notranslate nohighlight">\(\input_t\)</span>,
- and generates an
- observation <span class="math notranslate nohighlight">\(\obs_t\)</span> at each time step.
- (In this book, we mostly focus on discrete time systems,
- although we consider the continuous-time case in XXX.)
- We get to see the observations, but not the hidden state.
- Our main goal is to infer the hidden state given the observations.
- However, we can also use the model to predict future observations,
- by first predicting future hidden states, and then predicting
- what observations they might generate.
- By using a hidden state <span class="math notranslate nohighlight">\(\hidden_t\)</span>
- to represent the past observations, <span class="math notranslate nohighlight">\(\obs_{1:t-1}\)</span>,
- the model can have “infinite” memory,
- unlike a standard Markov model.</p>
- <p>Formally we can define an SSM
- as the following joint distribution:</p>
- <div class="math notranslate nohighlight" id="equation-eq-ssm-ar">
- <span class="eqno">()<a class="headerlink" href="#equation-eq-ssm-ar" title="Permalink to this equation">¶</a></span>\[p(\hmmobs_{1:T},\hmmhid_{1:T}|\inputs_{1:T})
- = \left[ p(\hmmhid_1|\inputs_1) \prod_{t=2}^{T}
- p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) \right]
- \left[ \prod_{t=1}^T p(\hmmobs_t|\hmmhid_t, \inputs_t, \hmmobs_{t-1}) \right]\]</div>
- <p>where <span class="math notranslate nohighlight">\(p(\hmmhid_t|\hmmhid_{t-1},\inputs_t)\)</span> is the
- transition model,
- <span class="math notranslate nohighlight">\(p(\hmmobs_t|\hmmhid_t, \inputs_t, \hmmobs_{t-1})\)</span> is the
- observation model,
- and <span class="math notranslate nohighlight">\(\inputs_{t}\)</span> is an optional input or action.
- See <a class="reference internal" href="#ssm-ar">the figure below</a>
- for an illustration of the corresponding graphical model.</p>
- <div class="figure align-default" id="ssm-ar">
- <a class="reference internal image-reference" href="../_images/SSM-AR-inputs.png"><img alt="../_images/SSM-AR-inputs.png" src="../_images/SSM-AR-inputs.png" style="width: 152.0px; height: 165.0px;" /></a>
- <p class="caption"><span class="caption-text">Illustration of an SSM as a graphical model.</span><a class="headerlink" href="#ssm-ar" title="Permalink to this image">¶</a></p>
- </div>
- <p>We often consider a simpler setting in which the
- observations are conditionally independent of each other
- (rather than having Markovian dependencies) given the hidden state.
- In this case the joint simplifies to</p>
- <div class="math notranslate nohighlight" id="equation-eq-ssm-input">
- <span class="eqno">()<a class="headerlink" href="#equation-eq-ssm-input" title="Permalink to this equation">¶</a></span>\[p(\hmmobs_{1:T},\hmmhid_{1:T}|\inputs_{1:T})
- = \left[ p(\hmmhid_1|\inputs_1) \prod_{t=2}^{T}
- p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) \right]
- \left[ \prod_{t=1}^T p(\hmmobs_t|\hmmhid_t, \inputs_t) \right]\]</div>
- <p>Sometimes there are no external inputs, so the model further
- simplifies to the following unconditional generative model:</p>
- <div class="math notranslate nohighlight" id="equation-eq-ssm-no-input">
- <span class="eqno">()<a class="headerlink" href="#equation-eq-ssm-no-input" title="Permalink to this equation">¶</a></span>\[p(\hmmobs_{1:T},\hmmhid_{1:T})
- = \left[ p(\hmmhid_1) \prod_{t=2}^{T}
- p(\hmmhid_t|\hmmhid_{t-1}) \right]
- \left[ \prod_{t=1}^T p(\hmmobs_t|\hmmhid_t) \right]\]</div>
- <p>See <a class="reference internal" href="#ssm-simplified">the figure below</a>
- for an illustration of the corresponding graphical model.</p>
- <div class="figure align-default" id="ssm-simplified">
- <a class="reference internal image-reference" href="../_images/SSM-simplified.png"><img alt="../_images/SSM-simplified.png" src="../_images/SSM-simplified.png" style="width: 136.0px; height: 98.0px;" /></a>
- <p class="caption"><span class="caption-text">Illustration of a simplified SSM.</span><a class="headerlink" href="#ssm-simplified" title="Permalink to this image">¶</a></p>
- </div>
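- <p>To make the factorization of the unconditional model concrete, here is a minimal sketch in plain Python that evaluates the log joint probability term by term for a discrete-state, discrete-observation SSM. The function name <code>ssm_log_joint</code> and the two-state parameter values are hypothetical, chosen only for illustration, and states and observations are 0-indexed.</p>

```python
import math


def ssm_log_joint(z, y, init, trans, obs):
    """Return log p(y_{1:T}, z_{1:T}) for a discrete SSM with no inputs."""
    # Initial-state term: log p(z_1)
    logp = math.log(init[z[0]])
    # Transition terms: sum over t = 2..T of log p(z_t | z_{t-1})
    for prev, cur in zip(z[:-1], z[1:]):
        logp += math.log(trans[prev][cur])
    # Observation terms: sum over t = 1..T of log p(y_t | z_t)
    for state, out in zip(z, y):
        logp += math.log(obs[state][out])
    return logp


# Hypothetical two-state, two-symbol model (illustrative values only).
init = [0.5, 0.5]
trans = [[0.9, 0.1], [0.2, 0.8]]
obs = [[0.7, 0.3], [0.1, 0.9]]

z = [0, 0, 1]  # hidden state sequence
y = [0, 1, 1]  # observation sequence
log_joint = ssm_log_joint(z, y, init, trans, obs)
print(math.exp(log_joint))
```

- <p>For this toy sequence the joint probability is the product 0.5 × 0.9 × 0.1 × 0.7 × 0.3 × 0.9, which is approximately 0.008505. Working in log space avoids numerical underflow for long sequences.</p>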
- </div>
- <div class="tex2jax_ignore mathjax_ignore section" id="hidden-markov-models">
- <span id="sec-hmm-intro"></span><h1>Hidden Markov Models<a class="headerlink" href="#hidden-markov-models" title="Permalink to this headline">¶</a></h1>
- <p>In this section, we discuss the
- hidden Markov model or HMM,
- which is a state space model in which the hidden states
- are discrete, so <span class="math notranslate nohighlight">\(\hmmhid_t \in \{1,\ldots, K\}\)</span>.
- The observations may be discrete,
- <span class="math notranslate nohighlight">\(\hmmobs_t \in \{1,\ldots, C\}\)</span>,
- or continuous,
- <span class="math notranslate nohighlight">\(\hmmobs_t \in \real^D\)</span>,
- or some combination,
- as we illustrate below.
- More details can be found in e.g.,
- <span id="id1">[<a class="reference internal" href="../bib.html#id34" title="O. Cappe, E. Moulines, and T. Ryden. Inference in Hidden Markov Models. Springer, 2005.">CMR05</a>, <a class="reference internal" href="../bib.html#id33" title="A. Fraser. Hidden Markov Models and Dynamical Systems. SIAM Press, 2008.">Fra08</a>, <a class="reference internal" href="../bib.html#id32" title="L. R. Rabiner. A tutorial on Hidden Markov Models and selected applications in speech recognition. Proc. of the IEEE, 77(2):257–286, 1989.">Rab89</a>]</span>.
- For an interactive introduction,
- see <a class="reference external" href="https://nipunbatra.github.io/hmm/">https://nipunbatra.github.io/hmm/</a>.</p>
- <div class="section" id="example-casino-hmm">
- <h2>Example: Casino HMM<a class="headerlink" href="#example-casino-hmm" title="Permalink to this headline">¶</a></h2>
- <p>To illustrate HMMs with a categorical observation model,
- we consider the “Occasionally dishonest casino” model from <span id="id2">[<a class="reference internal" href="../bib.html#id3" title="R. Durbin, S. Eddy, A. Krogh, and G. Mitchison. Biological Sequence Analysis: Probabilistic Models of Proteins and Nucleic Acids. Cambridge University Press, 1998.">DEKM98</a>]</span>.
- There are 2 hidden states, representing whether the die being used in the casino is fair or loaded.
- Each state defines a distribution over the 6 possible observations.</p>
- <p>The transition model is denoted by</p>
- <div class="math notranslate nohighlight">
- \[p(z_t=j|z_{t-1}=i) = \hmmTrans_{ij}\]</div>
- <p>Here the <span class="math notranslate nohighlight">\(i\)</span>’th row of <span class="math notranslate nohighlight">\(\vA\)</span> corresponds to the outgoing distribution from state <span class="math notranslate nohighlight">\(i\)</span>.
- This is a row stochastic matrix,
- meaning each row sums to one.
- We can visualize
- the non-zero entries in the transition matrix by creating a state transition diagram,
- as shown in
- <a class="reference internal" href="#casino-fig">the figure below</a>.</p>
- <div class="figure align-default" id="casino-fig">
- <a class="reference internal image-reference" href="../_images/casino.png"><img alt="../_images/casino.png" src="../_images/casino.png" style="width: 208.5px; height: 142.5px;" /></a>
- <p class="caption"><span class="caption-text">Illustration of the casino HMM.</span><a class="headerlink" href="#casino-fig" title="Permalink to this image">¶</a></p>
- </div>
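- <p>The row-stochastic property is easy to check numerically. The following sketch, in plain Python, verifies it for a two-state matrix and also forms the two-step transition matrix, whose entry (i, j) is the probability of being in state j two steps after being in state i. The helper names <code>is_row_stochastic</code> and <code>matmul</code> are hypothetical; the numerical values match the casino transition matrix used in the code for this example.</p>

```python
def is_row_stochastic(A, tol=1e-9):
    """Check that every row is a valid probability distribution."""
    return all(
        all(p >= 0.0 for p in row) and abs(sum(row) - 1.0) < tol
        for row in A
    )


def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [
        [sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
        for i in range(len(A))
    ]


# Rows index the current state, columns index the next state.
A = [[0.95, 0.05],
     [0.10, 0.90]]

print(is_row_stochastic(A))

# Two-step transition probabilities: A2[i][j] = p(z_{t+2} = j | z_t = i).
A2 = matmul(A, A)
print(A2)
print(is_row_stochastic(A2))
```

- <p>The product of two row-stochastic matrices is again row stochastic, so multi-step transition matrices are themselves valid transition models.</p>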
- <p>The observation model
- <span class="math notranslate nohighlight">\(p(\obs_t|\hidden_t=j)\)</span> has the form</p>
- <div class="math notranslate nohighlight">
- \[p(\obs_t=k|\hidden_t=j) = \hmmObs_{jk} \]</div>
- <p>This is represented by the histograms associated with each
- state in <a class="reference internal" href="#casino-fig"><span class="std std-ref">the casino HMM figure</span></a>.</p>
- <p>Finally,
- the initial state distribution is denoted by</p>
- <div class="math notranslate nohighlight">
- \[p(z_1=j) = \hmmInit_j\]</div>
- <p>Collectively we denote all the parameters by <span class="math notranslate nohighlight">\(\vtheta=(\hmmTrans, \hmmObs, \hmmInit)\)</span>.</p>
- <p>Now let us implement this model in code.</p>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># state transition matrix</span>
- <span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
- <span class="p">[</span><span class="mf">0.95</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">],</span>
- <span class="p">[</span><span class="mf">0.10</span><span class="p">,</span> <span class="mf">0.90</span><span class="p">]</span>
- <span class="p">])</span>
- <span class="c1"># observation matrix</span>
- <span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
- <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">],</span> <span class="c1"># fair die</span>
- <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="o">/</span><span class="mi">10</span><span class="p">]</span> <span class="c1"># loaded die</span>
- <span class="p">])</span>
- <span class="n">pi</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">])</span>
- <span class="p">(</span><span class="n">nstates</span><span class="p">,</span> <span class="n">nobs</span><span class="p">)</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">distrax</span>
- <span class="kn">from</span> <span class="nn">distrax</span> <span class="kn">import</span> <span class="n">HMM</span>
- <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
- <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">pi</span><span class="p">),</span>
- <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">B</span><span class="p">))</span>
- <span class="nb">print</span><span class="p">(</span><span class="n">hmm</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>&lt;distrax._src.utils.hmm.HMM object at 0x7feaf905e250&gt;
- </pre></div>
- </div>
- </div>
- </div>
- <p>Let’s sample from the model. We will generate a sequence of latent states, <span class="math notranslate nohighlight">\(\hid_{1:T}\)</span>,
- which we then convert to a sequence of observations, <span class="math notranslate nohighlight">\(\obs_{1:T}\)</span>.</p>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">seed</span> <span class="o">=</span> <span class="mi">314</span>
- <span class="n">n_samples</span> <span class="o">=</span> <span class="mi">300</span>
- <span class="n">z_hist</span><span class="p">,</span> <span class="n">x_hist</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">n_samples</span><span class="p">)</span>
- <span class="n">z_hist_str</span> <span class="o">=</span> <span class="s2">""</span><span class="o">.</span><span class="n">join</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">z_hist</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">str</span><span class="p">))[:</span><span class="mi">60</span><span class="p">]</span>
- <span class="n">x_hist_str</span> <span class="o">=</span> <span class="s2">""</span><span class="o">.</span><span class="n">join</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">str</span><span class="p">))[:</span><span class="mi">60</span><span class="p">]</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"Printing sample observed/latent..."</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"x: </span><span class="si">{</span><span class="n">x_hist_str</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"z: </span><span class="si">{</span><span class="n">z_hist_str</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Printing sample observed/latent...
- x: 633665342652353616444236412331351246651613325161656366246242
- z: 222222211111111111111111111111111111111222111111112222211111
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Here is the source code for the sampling algorithm.</span>
- <span class="n">print_source</span><span class="p">(</span><span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"> def sample<span style="font-weight: bold">(</span>self,
- *,
- seed: chex.PRNGKey,
- seq_len: chex.Array<span style="font-weight: bold">)</span> -> Tuple:
- <span style="color: #008000; text-decoration-color: #008000">""</span>"Sample from this HMM.
- Samples an observation of given length according to this
- Hidden Markov Model and gives the sequence of the hidden states
- as well as the observation.
- Args:
- seed: Random key of shape <span style="font-weight: bold">(</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">2</span>,<span style="font-weight: bold">)</span> and dtype uint32.
- seq_len: The length of the observation sequence.
- Returns:
- Tuple of hidden state sequence, and observation sequence.
- <span style="color: #008000; text-decoration-color: #008000">""</span>"
- rng_key, rng_init = jax.random.split<span style="font-weight: bold">(</span>seed<span style="font-weight: bold">)</span>
- initial_state = self._init_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">rng_init</span><span style="font-weight: bold">)</span>
- def draw_state<span style="font-weight: bold">(</span>prev_state, key<span style="font-weight: bold">)</span>:
- state = self._trans_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key</span><span style="font-weight: bold">)[</span>prev_state<span style="font-weight: bold">]</span>
- return state, state
- rng_state, rng_obs = jax.random.split<span style="font-weight: bold">(</span>rng_key<span style="font-weight: bold">)</span>
- keys = jax.random.split<span style="font-weight: bold">(</span>rng_state, seq_len - <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">1</span><span style="font-weight: bold">)</span>
- _, states = jax.lax.scan<span style="font-weight: bold">(</span>draw_state, initial_state, keys<span style="font-weight: bold">)</span>
- states = jnp.append<span style="font-weight: bold">(</span>initial_state, states<span style="font-weight: bold">)</span>
- def draw_obs<span style="font-weight: bold">(</span>state, key<span style="font-weight: bold">)</span>:
- return self._obs_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key</span><span style="font-weight: bold">)[</span>state<span style="font-weight: bold">]</span>
- keys = jax.random.split<span style="font-weight: bold">(</span>rng_obs, seq_len<span style="font-weight: bold">)</span>
- obs_seq = jax.vmap<span style="font-weight: bold">(</span>draw_obs, <span style="color: #808000; text-decoration-color: #808000">in_axes</span>=<span style="font-weight: bold">(</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">0</span>, <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">0</span><span style="font-weight: bold">))(</span>states, keys<span style="font-weight: bold">)</span>
- return states, obs_seq
- </pre>
- </div></div>
- </div>
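- <p>The same ancestral sampling scheme can be sketched in plain Python, which makes the dependence on the previous state explicit: the next state is drawn from the row of the transition matrix selected by the current state, and each observation is drawn from the row of the observation matrix selected by its state. This is an illustrative stand-in, not the distrax implementation; the function name <code>sample_hmm</code> is hypothetical, and because it uses Python's <code>random</code> module it will not reproduce the JAX samples shown above.</p>

```python
import random


def sample_hmm(init, trans, obs, seq_len, seed=0):
    """Draw (states, observations) of length seq_len by ancestral sampling."""
    rng = random.Random(seed)

    def draw(probs):
        # Sample one index from a categorical distribution.
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]

    # z_1 ~ p(z_1), then z_t ~ p(z_t | z_{t-1}) for t = 2..T.
    states = [draw(init)]
    for _ in range(seq_len - 1):
        states.append(draw(trans[states[-1]]))
    # y_t ~ p(y_t | z_t), independently given the states.
    observations = [draw(obs[z]) for z in states]
    return states, observations


# Casino parameters (0-indexed: state 0 is the fair die, state 1 the loaded die).
A = [[0.95, 0.05], [0.10, 0.90]]
B = [[1 / 6] * 6, [1 / 10] * 5 + [5 / 10]]
pi = [0.5, 0.5]

states, observations = sample_hmm(pi, A, B, seq_len=300, seed=314)
# Add 1 so that die faces print as 1..6 and states as 1..2.
print("x:", "".join(str(x + 1) for x in observations[:60]))
print("z:", "".join(str(z + 1) for z in states[:60]))
```

- <p>Because the observations are conditionally independent given the states, the observation loop could be run in any order or in parallel, whereas the state sequence must be generated sequentially.</p>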
- <p>Our primary goal will be to infer the latent state from the observations,
- so we can detect if the casino is being dishonest or not. This will
- affect how we choose to gamble our money.
- We discuss various ways to perform this inference below.</p>
- </div>
- <div class="section" id="example-lillypad-hmm">
- <span id="sec-lillypad"></span><h2>Example: Lily Pad HMM<a class="headerlink" href="#example-lillypad-hmm" title="Permalink to this headline">¶</a></h2>
- <p>If <span class="math notranslate nohighlight">\(\obs_t\)</span> is continuous, it is common to use a Gaussian
- observation model:</p>
- <div class="math notranslate nohighlight">
- \[p(\obs_t|\hidden_t=j) = \gauss(\obs_t|\vmu_j,\vSigma_j)\]</div>
- <p>This is sometimes called a Gaussian HMM.</p>
- <p>As a simple example, suppose we have an HMM with 3 hidden states,
- each of which generates a 2d Gaussian.
- We can represent these Gaussian distributions as 2d ellipses,
- as we show below.
- We call these “lily pads”, because of their shape.
- We can imagine a frog hopping from one lily pad to another.
- (This analogy is due to the late Sam Roweis.)
- The frog will stay on a pad for a while (corresponding to remaining in the same
- discrete state <span class="math notranslate nohighlight">\(\hidden_t\)</span>), and then jump to a new pad
- (corresponding to a transition to a new state).
- The data we see are just the 2d points (e.g., water droplets)
- coming from near the pad that the frog is currently on.
- Thus this model is like a Gaussian mixture model,
- in that it generates clusters of observations,
- except now there is temporal correlation between the data points.</p>
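- <p>As a small illustration of the Gaussian observation model, the sketch below evaluates the log density of a 2d point under each state's Gaussian and reports which state explains it best. It is written in plain Python using the closed-form inverse of a 2×2 matrix; the function name <code>gauss2d_logpdf</code> and the test point are hypothetical, while the means and covariances are the ones used in the code for this example.</p>

```python
import math


def gauss2d_logpdf(y, mu, Sigma):
    """Log density of a 2d Gaussian N(y | mu, Sigma)."""
    (a, b), (c, d) = Sigma
    det = a * d - b * c
    dx, dy = y[0] - mu[0], y[1] - mu[1]
    # Quadratic form (y - mu)^T Sigma^{-1} (y - mu), using the 2x2 inverse.
    quad = (d * dx * dx - (b + c) * dx * dy + a * dy * dy) / det
    # For D = 2 the normalizer is log(2 * pi) + 0.5 * log(det).
    return -0.5 * quad - math.log(2 * math.pi) - 0.5 * math.log(det)


def scale(S, factor):
    """Multiply every entry of a matrix by a scalar."""
    return [[factor * v for v in row] for row in S]


# One mean and one covariance per hidden state (covariances divided by 60).
means = [[0.3, 0.3], [0.8, 0.5], [0.3, 0.8]]
covs = [
    scale([[1.1, 0.0], [0.0, 0.3]], 1 / 60),
    scale([[0.3, -0.5], [-0.5, 1.3]], 1 / 60),
    scale([[0.8, 0.4], [0.4, 0.5]], 1 / 60),
]

y = [0.3, 0.3]  # hypothetical observation, placed at the first state's mean
log_liks = [gauss2d_logpdf(y, mu, S) for mu, S in zip(means, covs)]
best_state = max(range(len(log_liks)), key=lambda j: log_liks[j])
print(log_liks, best_state)
```

- <p>These per-state likelihoods are exactly what a mixture model would use to assign a point to a cluster. The HMM additionally weighs them against the transition probabilities, so the inferred state at one time step depends on its neighbors.</p>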
- <p>Let us now illustrate this model in code.</p>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let us create the model</span>
- <span class="n">initial_probs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">])</span>
- <span class="c1"># transition matrix</span>
- <span class="n">A</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
- <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
- <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
- <span class="p">[</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]</span>
- <span class="p">])</span>
- <span class="c1"># Observation model</span>
- <span class="n">mu_collection</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
- <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
- <span class="p">[</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">],</span>
- <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">]</span>
- <span class="p">])</span>
- <span class="n">S1</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">1.1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">]])</span>
- <span class="n">S2</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">1.3</span><span class="p">]])</span>
- <span class="n">S3</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]])</span>
- <span class="n">cov_collection</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">S1</span><span class="p">,</span> <span class="n">S2</span><span class="p">,</span> <span class="n">S3</span><span class="p">])</span> <span class="o">/</span> <span class="mi">60</span>
- <span class="kn">import</span> <span class="nn">tensorflow_probability</span> <span class="k">as</span> <span class="nn">tfp</span>
- <span class="k">if</span> <span class="kc">False</span><span class="p">:</span>
-     <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
-         <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">initial_probs</span><span class="p">),</span>
-         <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">MultivariateNormalFullCovariance</span><span class="p">(</span>
-             <span class="n">loc</span><span class="o">=</span><span class="n">mu_collection</span><span class="p">,</span> <span class="n">covariance_matrix</span><span class="o">=</span><span class="n">cov_collection</span><span class="p">))</span>
- <span class="k">else</span><span class="p">:</span>
-     <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
-         <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">initial_probs</span><span class="p">),</span>
-         <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">as_distribution</span><span class="p">(</span>
-             <span class="n">tfp</span><span class="o">.</span><span class="n">substrates</span><span class="o">.</span><span class="n">jax</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">MultivariateNormalFullCovariance</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">mu_collection</span><span class="p">,</span>
-                 <span class="n">covariance_matrix</span><span class="o">=</span><span class="n">cov_collection</span><span class="p">)))</span>
- <span class="nb">print</span><span class="p">(</span><span class="n">hmm</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>&lt;distrax._src.utils.hmm.HMM object at 0x7feb19733220&gt;
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">n_samples</span><span class="p">,</span> <span class="n">seed</span> <span class="o">=</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">10</span>
- <span class="n">samples_state</span><span class="p">,</span> <span class="n">samples_obs</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">n_samples</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="n">samples_state</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>(50,)
- (50, 2)
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let's plot the observed data in 2d</span>
- <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span>
- <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mf">1.2</span>
- <span class="n">colors</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"tab:green"</span><span class="p">,</span> <span class="s2">"tab:blue"</span><span class="p">,</span> <span class="s2">"tab:red"</span><span class="p">]</span>
- <span class="k">def</span> <span class="nf">plot_2dhmm</span><span class="p">(</span><span class="n">hmm</span><span class="p">,</span> <span class="n">samples_obs</span><span class="p">,</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">colors</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">):</span>
-     <span class="n">obs_dist</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">obs_dist</span>
-     <span class="n">color_sample</span> <span class="o">=</span> <span class="p">[</span><span class="n">colors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">samples_state</span><span class="p">]</span>
-     <span class="n">xs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
-     <span class="n">ys</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
-     <span class="n">v_prob</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">obs_dist</span><span class="o">.</span><span class="n">prob</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">])),</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
-     <span class="n">z</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="n">v_prob</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="kc">None</span><span class="p">))(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">)</span>
-     <span class="n">grid</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mgrid</span><span class="p">[</span><span class="n">xmin</span><span class="p">:</span><span class="n">xmax</span><span class="p">:</span><span class="n">step</span><span class="p">,</span> <span class="n">ymin</span><span class="p">:</span><span class="n">ymax</span><span class="p">:</span><span class="n">step</span><span class="p">]</span>
-     <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">color</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">colors</span><span class="p">):</span>
-         <span class="n">ax</span><span class="o">.</span><span class="n">contour</span><span class="p">(</span><span class="o">*</span><span class="n">grid</span><span class="p">,</span> <span class="n">z</span><span class="p">[:,</span> <span class="p">:,</span> <span class="n">k</span><span class="p">],</span> <span class="n">levels</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">colors</span><span class="o">=</span><span class="n">color</span><span class="p">,</span> <span class="n">linewidths</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
-         <span class="n">ax</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="o">*</span><span class="p">(</span><span class="n">obs_dist</span><span class="o">.</span><span class="n">mean</span><span class="p">()[</span><span class="n">k</span><span class="p">]</span> <span class="o">+</span> <span class="mf">0.13</span><span class="p">),</span> <span class="sa">f</span><span class="s2">"$k$=</span><span class="si">{</span><span class="n">k</span> <span class="o">+</span> <span class="mi">1</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">13</span><span class="p">,</span> <span class="n">horizontalalignment</span><span class="o">=</span><span class="s2">"right"</span><span class="p">)</span>
-     <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="o">*</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">"black"</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
-     <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="o">*</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">color_sample</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="mi">30</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
-     <span class="k">return</span> <span class="n">ax</span><span class="p">,</span> <span class="n">color_sample</span>
- <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
- <span class="n">_</span><span class="p">,</span> <span class="n">color_sample</span> <span class="o">=</span> <span class="n">plot_2dhmm</span><span class="p">(</span><span class="n">hmm</span><span class="p">,</span> <span class="n">samples_obs</span><span class="p">,</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">colors</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <img alt="../_images/ssm_old_15_0.png" src="../_images/ssm_old_15_0.png" />
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let's plot the hidden state sequence</span>
- <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">),</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">"post"</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">"black"</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">),</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">color_sample</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <img alt="../_images/ssm_old_16_1.png" src="../_images/ssm_old_16_1.png" />
- </div>
- </div>
- </div>
- </div>
- <div class="tex2jax_ignore mathjax_ignore section" id="linear-gaussian-ssms">
- <span id="sec-lds-intro"></span><h1>Linear Gaussian SSMs<a class="headerlink" href="#linear-gaussian-ssms" title="Permalink to this headline">¶</a></h1>
- <p>Consider the state space model in
- <a class="reference internal" href="#equation-eq-ssm-ar">()</a>
- where we assume the observations are conditionally iid given the
- hidden states and inputs (i.e. there are no auto-regressive dependencies
- between the observables).
- We can rewrite this model as
- a stochastic nonlinear dynamical system (NLDS)
- by writing the next hidden state
- as a deterministic function of the previous state, the input,
- and random process noise <span class="math notranslate nohighlight">\(\vepsilon_t\)</span>:</p>
- <div class="amsmath math notranslate nohighlight" id="equation-3536df60-5800-49ba-99b0-cfe3995291c6">
- <span class="eqno">()<a class="headerlink" href="#equation-3536df60-5800-49ba-99b0-cfe3995291c6" title="Permalink to this equation">¶</a></span>\[\begin{align}
- \hmmhid_t &= \ssmDynFn(\hmmhid_{t-1}, \inputs_t, \vepsilon_t)
- \end{align}\]</div>
- <p>where <span class="math notranslate nohighlight">\(\vepsilon_t\)</span> is drawn from a distribution chosen so
- that the induced distribution
- on <span class="math notranslate nohighlight">\(\hmmhid_t\)</span> matches <span class="math notranslate nohighlight">\(p(\hmmhid_t|\hmmhid_{t-1}, \inputs_t)\)</span>.
- Similarly, we can rewrite the observation distribution
- as a deterministic function of the hidden state, the input,
- and observation noise <span class="math notranslate nohighlight">\(\veta_t\)</span>:</p>
- <div class="amsmath math notranslate nohighlight" id="equation-a6cb71e2-fb89-4cd0-b11f-43d39e5b68fa">
- <span class="eqno">()<a class="headerlink" href="#equation-a6cb71e2-fb89-4cd0-b11f-43d39e5b68fa" title="Permalink to this equation">¶</a></span>\[\begin{align}
- \hmmobs_t &= \ssmObsFn(\hmmhid_{t}, \inputs_t, \veta_t)
- \end{align}\]</div>
- <p>If we assume additive Gaussian noise,
- the model becomes</p>
- <div class="amsmath math notranslate nohighlight" id="equation-5dba5e0d-f668-43c5-8d30-ccde8cf1d8b5">
- <span class="eqno">()<a class="headerlink" href="#equation-5dba5e0d-f668-43c5-8d30-ccde8cf1d8b5" title="Permalink to this equation">¶</a></span>\[\begin{align}
- \hmmhid_t &= \ssmDynFn(\hmmhid_{t-1}, \inputs_t) + \vepsilon_t \\
- \hmmobs_t &= \ssmObsFn(\hmmhid_{t}, \inputs_t) + \veta_t
- \end{align}\]</div>
- <p>where <span class="math notranslate nohighlight">\(\vepsilon_t \sim \gauss(\vzero,\vQ_t)\)</span>
- and <span class="math notranslate nohighlight">\(\veta_t \sim \gauss(\vzero,\vR_t)\)</span>.
- We will call these Gaussian SSMs.</p>
- <p>If we additionally assume
- the transition function <span class="math notranslate nohighlight">\(\ssmDynFn\)</span>
- and the observation function <span class="math notranslate nohighlight">\(\ssmObsFn\)</span> are both linear,
- then we can rewrite the model as follows:</p>
- <div class="amsmath math notranslate nohighlight" id="equation-ae1eaadf-e90f-49d3-b9f3-0520cc68759c">
- <span class="eqno">()<a class="headerlink" href="#equation-ae1eaadf-e90f-49d3-b9f3-0520cc68759c" title="Permalink to this equation">¶</a></span>\[\begin{align}
- p(\hmmhid_t|\hmmhid_{t-1},\inputs_t) &= \gauss(\hmmhid_t|\ldsDyn_t \hmmhid_{t-1}
- + \ldsDynIn_t \inputs_t, \vQ_t)
- \\
- p(\hmmobs_t|\hmmhid_t,\inputs_t) &= \gauss(\hmmobs_t|\ldsObs_t \hmmhid_{t}
- + \ldsObsIn_t \inputs_t, \vR_t)
- \end{align}\]</div>
- <p>This is called a
- linear-Gaussian state space model
- (LG-SSM),
- or a
- linear dynamical system (LDS).
- We usually assume the parameters are independent of time, in which case
- the model is said to be time-invariant or homogeneous.</p>
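- <p>To make the generative process concrete, here is a minimal NumPy sketch of ancestral sampling from a time-invariant LG-SSM. It assumes there are no inputs (the input terms are dropped) and a Gaussian prior on the first state; the function name <code class="docutils literal notranslate"><span class="pre">sample_lgssm</span></code> and the toy parameter values are illustrative, not part of any library.</p>

```python
import numpy as np


def sample_lgssm(A, C, Q, R, mu0, Sigma0, num_steps, rng):
    """Draw (z_{1:T}, y_{1:T}) from a time-invariant LG-SSM without inputs."""
    state_dim, obs_dim = A.shape[0], C.shape[0]
    zs = np.zeros((num_steps, state_dim))
    ys = np.zeros((num_steps, obs_dim))
    # Initial state: z_1 ~ N(mu0, Sigma0)
    z = rng.multivariate_normal(mu0, Sigma0)
    for t in range(num_steps):
        if t > 0:
            # Dynamics: z_t ~ N(A z_{t-1}, Q)
            z = rng.multivariate_normal(A @ z, Q)
        # Observation: y_t ~ N(C z_t, R)
        ys[t] = rng.multivariate_normal(C @ z, R)
        zs[t] = z
    return zs, ys


# Toy example: a 1d random walk observed with noise (values are illustrative).
rng = np.random.default_rng(0)
A = np.array([[1.0]])
C = np.array([[1.0]])
Q = np.array([[0.01]])
R = np.array([[0.25]])
zs, ys = sample_lgssm(A, C, Q, R, np.zeros(1), np.eye(1), num_steps=20, rng=rng)
print(zs.shape, ys.shape)
```

- <p>Because the parameters do not depend on <code class="docutils literal notranslate"><span class="pre">t</span></code>, the same matrices are reused at every step, which is what makes the model time-invariant.</p>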
- <div class="section" id="example-tracking-a-2d-point">
- <span id="sec-kalman-tracking"></span><span id="sec-tracking-lds"></span><h2>Example: tracking a 2d point<a class="headerlink" href="#example-tracking-a-2d-point" title="Permalink to this headline">¶</a></h2>
- <p>Consider an object moving in <span class="math notranslate nohighlight">\(\real^2\)</span>.
- Let the state be
- the position and velocity of the object,
- <span class="math notranslate nohighlight">\(\vz_t =\begin{pmatrix} u_t & \dot{u}_t & v_t & \dot{v}_t \end{pmatrix}\)</span>.
- (We use <span class="math notranslate nohighlight">\(u\)</span> and <span class="math notranslate nohighlight">\(v\)</span> for the two coordinates,
- to avoid confusion with the state and observation variables.)
- If we use Euler discretization,
- the dynamics become</p>
- <div class="amsmath math notranslate nohighlight" id="equation-567370ff-6316-4f30-a817-15050de99ec9">
- <span class="eqno">()<a class="headerlink" href="#equation-567370ff-6316-4f30-a817-15050de99ec9" title="Permalink to this equation">¶</a></span>\[\begin{align}
- \underbrace{\begin{pmatrix} u_t\\ \dot{u}_t \\ v_t \\ \dot{v}_t \end{pmatrix}}_{\vz_t}
- =
- \underbrace{
- \begin{pmatrix}
- 1 & \Delta & 0 & 0 \\
- 0 & 1 & 0 & 0\\
- 0 & 0 & 1 & \Delta \\
- 0 & 0 & 0 & 1
- \end{pmatrix}
- }_{\ldsDyn}
- \
- \underbrace{\begin{pmatrix} u_{t-1} \\ \dot{u}_{t-1} \\ v_{t-1} \\ \dot{v}_{t-1} \end{pmatrix}}_{\vz_{t-1}}
- + \vepsilon_t
- \end{align}\]</div>
- <p>where <span class="math notranslate nohighlight">\(\vepsilon_t \sim \gauss(\vzero,\vQ)\)</span> is
- the process noise.</p>
- <p>Let us assume
- that the process noise is
- a white noise process added to the velocity components
- of the state, but not to the location.
- (This is known as a random accelerations model.)
- We can approximate the resulting process in discrete time by assuming
- <span class="math notranslate nohighlight">\(\vQ = \diag(0, q, 0, q)\)</span>.
- (See <span id="id3">[<a class="reference internal" href="../bib.html#id18" title="Simo Sarkka. Bayesian Filtering and Smoothing. Cambridge University Press, 2013. URL: https://users.aalto.fi/~ssarkka/pub/cup_book_online_20131111.pdf.">Sar13</a>]</span> p60 for a more accurate way
- to convert the continuous time process to discrete time.)</p>
- <p>Now suppose that at each discrete time point we
- observe the location,
- corrupted by Gaussian noise.
- Thus the observation model becomes</p>
- <div class="amsmath math notranslate nohighlight" id="equation-bb5d3cb2-5e61-472e-a04a-66081f35b5e8">
- <span class="eqno">()<a class="headerlink" href="#equation-bb5d3cb2-5e61-472e-a04a-66081f35b5e8" title="Permalink to this equation">¶</a></span>\[\begin{align}
- \underbrace{\begin{pmatrix} y_{1,t} \\ y_{2,t} \end{pmatrix}}_{\vy_t}
- &=
- \underbrace{
- \begin{pmatrix}
- 1 & 0 & 0 & 0 \\
- 0 & 0 & 1 & 0
- \end{pmatrix}
- }_{\ldsObs}
- \
- \underbrace{\begin{pmatrix} u_t\\ \dot{u}_t \\ v_t \\ \dot{v}_t \end{pmatrix}}_{\vz_t}
- + \veta_t
- \end{align}\]</div>
- <p>where <span class="math notranslate nohighlight">\(\veta_t \sim \gauss(\vzero,\vR)\)</span> is the observation noise.
- We see that the observation matrix <span class="math notranslate nohighlight">\(\ldsObs\)</span> simply “extracts” the
- relevant parts of the state vector.</p>
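- <p>The following NumPy sketch builds the tracking model for the state ordering <code class="docutils literal notranslate"><span class="pre">(u,</span> <span class="pre">u_dot,</span> <span class="pre">v,</span> <span class="pre">v_dot)</span></code>, under which each coordinate gets its own 2x2 constant-velocity block. It then checks one noise-free step and shows how reordering the state to <code class="docutils literal notranslate"><span class="pre">(u,</span> <span class="pre">v,</span> <span class="pre">u_dot,</span> <span class="pre">v_dot)</span></code> permutes the matrix. The values of <code class="docutils literal notranslate"><span class="pre">delta</span></code>, <code class="docutils literal notranslate"><span class="pre">q</span></code> and the starting state are illustrative assumptions.</p>

```python
import numpy as np

delta, q = 1.0, 0.001  # step size and velocity-noise variance (illustrative)

# Constant-velocity block for one coordinate: position += delta * velocity.
block = np.array([[1.0, delta],
                  [0.0, 1.0]])

# State ordering (u, u_dot, v, v_dot): one block per coordinate.
A = np.kron(np.eye(2), block)
Q = np.diag([0.0, q, 0.0, q])          # noise enters the velocities only
C = np.array([[1.0, 0.0, 0.0, 0.0],    # picks out u
              [0.0, 0.0, 1.0, 0.0]])   # picks out v

z = np.array([8.0, 1.0, 10.0, 0.0])    # at (8, 10), moving with velocity (1, 0)
z_next = A @ z                          # noise-free Euler step
print(z_next)                           # position advances by delta * velocity
print(C @ z_next)                       # observation mean: the new position

# Reordering the state to (u, v, u_dot, v_dot) permutes rows and columns of A.
perm = [0, 2, 1, 3]
A_perm = A[np.ix_(perm, perm)]
print(A_perm)
```

- <p>The permuted matrix has the identity in its upper-left block and <code class="docutils literal notranslate"><span class="pre">delta</span></code> times the identity in its upper-right block, so the two orderings describe the same dynamics.</p>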
- <p>Suppose we sample a trajectory and corresponding set
- of noisy observations from this model,
- <span class="math notranslate nohighlight">\((\vz_{1:T}, \vy_{1:T}) \sim p(\vz,\vy|\vtheta)\)</span>.
- (We use diagonal observation noise,
- <span class="math notranslate nohighlight">\(\vR = \diag(\sigma_1^2, \sigma_2^2)\)</span>.)
- The results are shown below.</p>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">314</span><span class="p">)</span>
- <span class="n">timesteps</span> <span class="o">=</span> <span class="mi">15</span>
- <span class="n">delta</span> <span class="o">=</span> <span class="mf">1.0</span>
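- <span class="c1"># State ordering in this code is (u, v, u_dot, v_dot): positions first, then velocities.</span>
- <span class="n">A</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>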
- <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">delta</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
- <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">delta</span><span class="p">],</span>
- <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
- <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>
- <span class="p">])</span>
- <span class="n">C</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
- <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
- <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>
- <span class="p">])</span>
- <span class="n">state_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">A</span><span class="o">.</span><span class="n">shape</span>
- <span class="n">observation_size</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">C</span><span class="o">.</span><span class="n">shape</span>
- <span class="n">Q</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">state_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.001</span>
- <span class="n">R</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">observation_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.0</span>
- <span class="c1"># Prior distribution over the initial state</span>
- <span class="n">mu0</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">8</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">float</span><span class="p">)</span>
- <span class="n">Sigma0</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">state_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.0</span>
- <span class="kn">from</span> <span class="nn">jsl.lds.kalman_filter</span> <span class="kn">import</span> <span class="n">LDS</span><span class="p">,</span> <span class="n">smooth</span><span class="p">,</span> <span class="nb">filter</span>
- <span class="n">lds</span> <span class="o">=</span> <span class="n">LDS</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">Q</span><span class="p">,</span> <span class="n">R</span><span class="p">,</span> <span class="n">mu0</span><span class="p">,</span> <span class="n">Sigma0</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="n">lds</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>LDS(A=DeviceArray([[1., 0., 1., 0.],
- [0., 1., 0., 1.],
- [0., 0., 1., 0.],
- [0., 0., 0., 1.]], dtype=float32), C=DeviceArray([[1, 0, 0, 0],
- [0, 1, 0, 0]], dtype=int32), Q=DeviceArray([[0.001, 0. , 0. , 0. ],
- [0. , 0.001, 0. , 0. ],
- [0. , 0. , 0.001, 0. ],
- [0. , 0. , 0. , 0.001]], dtype=float32), R=DeviceArray([[1., 0.],
- [0., 1.]], dtype=float32), mu=DeviceArray([ 8., 10., 1., 0.], dtype=float32), Sigma=DeviceArray([[1., 0., 0., 0.],
- [0., 1., 0., 0.],
- [0., 0., 1., 0.],
- [0., 0., 0., 1.]], dtype=float32), state_offset=None, obs_offset=None, nstates=4, nobs=2)
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">jsl.demos.plot_utils</span> <span class="kn">import</span> <span class="n">plot_ellipse</span>
- <span class="k">def</span> <span class="nf">plot_tracking_values</span><span class="p">(</span><span class="n">observed</span><span class="p">,</span> <span class="n">filtered</span><span class="p">,</span> <span class="n">cov_hist</span><span class="p">,</span> <span class="n">signal_label</span><span class="p">,</span> <span class="n">ax</span><span class="p">):</span>
-     <span class="n">timesteps</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">observed</span><span class="o">.</span><span class="n">shape</span>
-     <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">observed</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">observed</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">marker</span><span class="o">=</span><span class="s2">"o"</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
-             <span class="n">markerfacecolor</span><span class="o">=</span><span class="s2">"none"</span><span class="p">,</span> <span class="n">markeredgewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">markersize</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s2">"observed"</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">"tab:green"</span><span class="p">)</span>
-     <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="o">*</span><span class="n">filtered</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">signal_label</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">"tab:red"</span><span class="p">,</span> <span class="n">marker</span><span class="o">=</span><span class="s2">"x"</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
-     <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">timesteps</span><span class="p">,</span> <span class="mi">1</span><span class="p">):</span>
-         <span class="n">covn</span> <span class="o">=</span> <span class="n">cov_hist</span><span class="p">[</span><span class="n">t</span><span class="p">][:</span><span class="mi">2</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span>
-         <span class="n">plot_ellipse</span><span class="p">(</span><span class="n">covn</span><span class="p">,</span> <span class="n">filtered</span><span class="p">[</span><span class="n">t</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">ax</span><span class="p">,</span> <span class="n">n_std</span><span class="o">=</span><span class="mf">2.0</span><span class="p">,</span> <span class="n">plot_center</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
-     <span class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"equal"</span><span class="p">)</span>
-     <span class="n">ax</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">z_hist</span><span class="p">,</span> <span class="n">x_hist</span> <span class="o">=</span> <span class="n">lds</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">timesteps</span><span class="p">)</span>
- <span class="n">fig_truth</span><span class="p">,</span> <span class="n">axs</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
- <span class="n">axs</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x_hist</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">x_hist</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span>
- <span class="n">marker</span><span class="o">=</span><span class="s2">"o"</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">markerfacecolor</span><span class="o">=</span><span class="s2">"none"</span><span class="p">,</span>
- <span class="n">markeredgewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">markersize</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
- <span class="n">label</span><span class="o">=</span><span class="s2">"observed"</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">"tab:green"</span><span class="p">)</span>
- <span class="n">axs</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">z_hist</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">z_hist</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span>
- <span class="n">linewidth</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s2">"truth"</span><span class="p">,</span>
- <span class="n">marker</span><span class="o">=</span><span class="s2">"s"</span><span class="p">,</span> <span class="n">markersize</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
- <span class="n">axs</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
- <span class="n">axs</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"equal"</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <img alt="../_images/ssm_old_21_1.png" src="../_images/ssm_old_21_1.png" />
- </div>
- </div>
- <p>The main task is to infer the hidden states given the noisy
- observations, i.e., <span class="math notranslate nohighlight">\(p(\vz|\vy,\vtheta)\)</span>. We discuss the topic of inference in <a class="reference internal" href="#sec-inference"><span class="std std-ref">Inferential goals</span></a>.</p>
- </div>
- </div>
- <div class="tex2jax_ignore mathjax_ignore section" id="nonlinear-gaussian-ssms">
- <span id="sec-nlds-intro"></span><h1>Nonlinear Gaussian SSMs<a class="headerlink" href="#nonlinear-gaussian-ssms" title="Permalink to this headline">¶</a></h1>
- <p>In this section, we consider SSMs in which the dynamics and/or observation models are nonlinear,
- but the process noise and observation noise are Gaussian.
- That is,</p>
- <div class="amsmath math notranslate nohighlight" id="equation-9b2f4563-7156-4411-8ff1-8b8c16adf162">
- <span class="eqno">()<a class="headerlink" href="#equation-9b2f4563-7156-4411-8ff1-8b8c16adf162" title="Permalink to this equation">¶</a></span>\[\begin{align}
- \hmmhid_t &= \ssmDynFn(\hmmhid_{t-1}, \inputs_t) + \vepsilon_t \\
- \hmmobs_t &= \ssmObsFn(\hmmhid_{t}, \inputs_t) + \veta_t
- \end{align}\]</div>
- <p>where <span class="math notranslate nohighlight">\(\vepsilon_t \sim \gauss(\vzero,\vQ_t)\)</span>
- and <span class="math notranslate nohighlight">\(\veta_t \sim \gauss(\vzero,\vR_t)\)</span>.
- This is a very widely used model class. We give some examples below.</p>
- <div class="section" id="example-tracking-a-1d-pendulum">
- <span id="sec-pendulum"></span><h2>Example: tracking a 1d pendulum<a class="headerlink" href="#example-tracking-a-1d-pendulum" title="Permalink to this headline">¶</a></h2>
- <div class="figure align-default" id="fig-pendulum">
- <a class="reference internal image-reference" href="../_images/pendulum.png"><img alt="../_images/pendulum.png" src="../_images/pendulum.png" style="width: 265.0px; height: 295.0px;" /></a>
- <p class="caption"><span class="caption-text">Illustration of a pendulum swinging.
- <span class="math notranslate nohighlight">\(g\)</span> is the gravitational acceleration,
- <span class="math notranslate nohighlight">\(w(t)\)</span> is a random external force,
- and <span class="math notranslate nohighlight">\(\alpha\)</span> is the angle with respect to the vertical.
- Based on <span id="id4">[<a class="reference internal" href="../bib.html#id18" title="Simo Sarkka. Bayesian Filtering and Smoothing. Cambridge University Press, 2013. URL: https://users.aalto.fi/~ssarkka/pub/cup_book_online_20131111.pdf.">Sar13</a>]</span> fig 3.10.</span><a class="headerlink" href="#fig-pendulum" title="Permalink to this image">¶</a></p>
- </div>
- <p>Consider a simple pendulum of unit mass and length swinging from
- a fixed attachment, as in <a class="reference internal" href="#fig-pendulum"><span class="std std-ref">the pendulum figure above</span></a>.
- Such an object is in principle entirely deterministic in its behavior.
- However, in the real world, there are often unknown forces at work
- (e.g., air turbulence, friction).
- We will model these by a continuous time random Gaussian noise process <span class="math notranslate nohighlight">\(w(t)\)</span>.
- This gives rise to the following differential equation:</p>
- <div class="amsmath math notranslate nohighlight" id="equation-dcdda512-ae3f-4ffb-bd90-db678338bb27">
- <span class="eqno">()<a class="headerlink" href="#equation-dcdda512-ae3f-4ffb-bd90-db678338bb27" title="Permalink to this equation">¶</a></span>\[\begin{align}
- \frac{d^2 \alpha}{d t^2}
- = -g \sin(\alpha) + w(t)
- \end{align}\]</div>
- <p>We can write this as a nonlinear SSM by defining the state to be
- <span class="math notranslate nohighlight">\(z_1(t) = \alpha(t)\)</span> and <span class="math notranslate nohighlight">\(z_2(t) = d\alpha(t)/dt\)</span>.
- Thus</p>
- <div class="amsmath math notranslate nohighlight" id="equation-27ccd595-dcb3-4feb-8f28-e9470220fd48">
- <span class="eqno">()<a class="headerlink" href="#equation-27ccd595-dcb3-4feb-8f28-e9470220fd48" title="Permalink to this equation">¶</a></span>\[\begin{align}
- \frac{d \vz}{dt}
- = \begin{pmatrix} z_2 \\ -g \sin(z_1) \end{pmatrix}
- + \begin{pmatrix} 0 \\ 1 \end{pmatrix} w(t)
- \end{align}\]</div>
- <p>If we discretize this with step size <span class="math notranslate nohighlight">\(\Delta\)</span>,
- we get the following
- formulation (<span id="id5">[<a class="reference internal" href="../bib.html#id18" title="Simo Sarkka. Bayesian Filtering and Smoothing. Cambridge University Press, 2013. URL: https://users.aalto.fi/~ssarkka/pub/cup_book_online_20131111.pdf.">Sar13</a>]</span>, p. 74):</p>
- <div class="amsmath math notranslate nohighlight" id="equation-37e2b498-70d7-45cf-a98f-e957e1dd0963">
- <span class="eqno">()<a class="headerlink" href="#equation-37e2b498-70d7-45cf-a98f-e957e1dd0963" title="Permalink to this equation">¶</a></span>\[\begin{align}
- \underbrace{
- \begin{pmatrix} z_{1,t} \\ z_{2,t} \end{pmatrix}
- }_{\hmmhid_t}
- =
- \underbrace{
- \begin{pmatrix} z_{1,t-1} + z_{2,t-1} \Delta \\
- z_{2,t-1} -g \sin(z_{1,t-1}) \Delta \end{pmatrix}
- }_{\vf(\hmmhid_{t-1})}
- +\vq_{t-1}
- \end{align}\]</div>
- <p>where <span class="math notranslate nohighlight">\(\vq_{t-1} \sim \gauss(\vzero,\vQ)\)</span> with</p>
- <div class="amsmath math notranslate nohighlight" id="equation-65c3fab7-9891-4368-865b-935fc7a297fd">
- <span class="eqno">()<a class="headerlink" href="#equation-65c3fab7-9891-4368-865b-935fc7a297fd" title="Permalink to this equation">¶</a></span>\[\begin{align}
- \vQ = q^c \begin{pmatrix}
- \frac{\Delta^3}{3} & \frac{\Delta^2}{2} \\
- \frac{\Delta^2}{2} & \Delta
- \end{pmatrix}
- \end{align}\]</div>
- <p>where <span class="math notranslate nohighlight">\(q^c\)</span> is the spectral density (continuous time variance)
- of the continuous-time noise process.</p>
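- <p>The discretized model above can be simulated directly. The following is a minimal NumPy sketch rather than code from the book: the function names <code class="docutils literal notranslate"><span class="pre">pendulum_step</span></code>, <code class="docutils literal notranslate"><span class="pre">process_noise_cov</span></code> and <code class="docutils literal notranslate"><span class="pre">simulate_pendulum</span></code>, and the default values chosen for <span class="math notranslate nohighlight">\(g\)</span>, <span class="math notranslate nohighlight">\(\Delta\)</span> and <span class="math notranslate nohighlight">\(q^c\)</span>, are assumptions made for illustration.</p>
```python
import numpy as np

def pendulum_step(z, delta, g=9.81):
    """Deterministic part f(z_{t-1}) of the discretized pendulum dynamics."""
    z1, z2 = z
    return np.array([z1 + z2 * delta, z2 - g * np.sin(z1) * delta])

def process_noise_cov(delta, qc):
    """Covariance Q of the discretized process noise q_{t-1}."""
    return qc * np.array([[delta**3 / 3, delta**2 / 2],
                          [delta**2 / 2, delta]])

def simulate_pendulum(z0, num_steps, delta=0.01, qc=1.0, g=9.81, seed=0):
    """Sample z_1, ..., z_T from z_t = f(z_{t-1}) + q_{t-1}, with q_{t-1} ~ N(0, Q)."""
    rng = np.random.default_rng(seed)
    Q = process_noise_cov(delta, qc)
    z = np.asarray(z0, dtype=float)
    states = np.zeros((num_steps, 2))
    for t in range(num_steps):
        z = pendulum_step(z, delta, g) + rng.multivariate_normal(np.zeros(2), Q)
        states[t] = z
    return states

# Start at an angle of 1.5 rad with zero angular velocity (illustrative values)
states = simulate_pendulum(z0=[1.5, 0.0], num_steps=500)
```
- <p>Each row of <code class="docutils literal notranslate"><span class="pre">states</span></code> holds one sampled pair <span class="math notranslate nohighlight">\((z_{1,t}, z_{2,t})\)</span>, i.e., the angle and the angular velocity at step <span class="math notranslate nohighlight">\(t\)</span>.</p>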
- <p>If we observe the angular position, we
- get the linear observation model</p>
- <div class="amsmath math notranslate nohighlight" id="equation-714de64b-7aa6-4c3f-a17c-0bf21bbc0ec5">
- <span class="eqno">()<a class="headerlink" href="#equation-714de64b-7aa6-4c3f-a17c-0bf21bbc0ec5" title="Permalink to this equation">¶</a></span>\[\begin{align}
- y_t = \alpha_t + r_t = h(\hmmhid_t) + r_t
- \end{align}\]</div>
- <p>where <span class="math notranslate nohighlight">\(h(\hmmhid_t) = z_{1,t}\)</span>
- and <span class="math notranslate nohighlight">\(r_t\)</span> is the observation noise.
- If we only observe the horizontal position,
- we get the nonlinear observation model</p>
- <div class="amsmath math notranslate nohighlight" id="equation-ab4fa705-661d-4ef6-8f1d-a394e212e500">
- <span class="eqno">()<a class="headerlink" href="#equation-ab4fa705-661d-4ef6-8f1d-a394e212e500" title="Permalink to this equation">¶</a></span>\[\begin{align}
- y_t = \sin(\alpha_t) + r_t = h(\hmmhid_t) + r_t
- \end{align}\]</div>
- <p>where <span class="math notranslate nohighlight">\(h(\hmmhid_t) = \sin(z_{1,t})\)</span>.</p>
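- <p>The two observation models can be written as small functions of the state. This is only a sketch: the helper names are hypothetical, and the observation noise <span class="math notranslate nohighlight">\(r_t\)</span> is assumed to be zero-mean Gaussian with standard deviation <code class="docutils literal notranslate"><span class="pre">r_std</span></code>, which the text above does not specify.</p>
```python
import numpy as np

def observe_angle(z):
    """Linear observation model: h(z) = z_1."""
    return z[..., 0]

def observe_horizontal(z):
    """Nonlinear observation model: h(z) = sin(z_1)."""
    return np.sin(z[..., 0])

def sample_observations(states, h, r_std, seed=0):
    """Draw y_t = h(z_t) + r_t, assuming r_t ~ N(0, r_std^2)."""
    rng = np.random.default_rng(seed)
    clean = h(states)
    return clean + r_std * rng.standard_normal(clean.shape)

# Toy state trajectory (angle, angular velocity), used only for illustration
states = np.stack([np.linspace(0.0, np.pi / 2, 5), np.zeros(5)], axis=1)
y_lin = sample_observations(states, observe_angle, r_std=0.1)
y_sin = sample_observations(states, observe_horizontal, r_std=0.1)
```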
- </div>
- </div>
- <div class="tex2jax_ignore mathjax_ignore section" id="inferential-goals">
- <span id="sec-inference"></span><h1>Inferential goals<a class="headerlink" href="#inferential-goals" title="Permalink to this headline">¶</a></h1>
- <div class="figure align-default" id="fig-dbn-inference">
- <a class="reference internal image-reference" href="../_images/inference-problems-tikz.png"><img alt="../_images/inference-problems-tikz.png" src="../_images/inference-problems-tikz.png" style="width: 1174.0px; height: 898.0px;" /></a>
- <p class="caption"><span class="caption-text">Illustration of the main kinds of inference for state-space models.
- The shaded region is the interval for which we have data.
- The arrow represents the time step at which we want to perform inference.
- <span class="math notranslate nohighlight">\(t\)</span> is the current time, <span class="math notranslate nohighlight">\(T\)</span> is the sequence length,
- <span class="math notranslate nohighlight">\(\ell\)</span> is the lag and <span class="math notranslate nohighlight">\(h\)</span> is the prediction horizon.</span><a class="headerlink" href="#fig-dbn-inference" title="Permalink to this image">¶</a></p>
- </div>
- <p>Given a sequence of observations and a known model,
- one of the main tasks with SSMs
- is to perform posterior inference
- about the hidden states; this is also called
- state estimation.
- At each time step <span class="math notranslate nohighlight">\(t\)</span>,
- there are multiple forms of posterior we may be interested in computing,
- including the following:</p>
- <ul class="simple">
- <li><p>the filtering distribution
- <span class="math notranslate nohighlight">\(p(\hmmhid_t|\hmmobs_{1:t})\)</span></p></li>
- <li><p>the smoothing distribution
- <span class="math notranslate nohighlight">\(p(\hmmhid_t|\hmmobs_{1:T})\)</span> (note that this conditions on future data <span class="math notranslate nohighlight">\(T>t\)</span>)</p></li>
- <li><p>the fixed-lag smoothing distribution
- <span class="math notranslate nohighlight">\(p(\hmmhid_{t-\ell}|\hmmobs_{1:t})\)</span> (note that this
- infers <span class="math notranslate nohighlight">\(\ell\)</span> steps in the past given data up to the present).</p></li>
- </ul>
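- <p>For a discrete-state model (an HMM), the three posteriors listed above can be computed with short recursions. The following NumPy sketch is one possible implementation under stated assumptions, not the inference engine used elsewhere in this book: the function names are hypothetical, <code class="docutils literal notranslate"><span class="pre">trans[i,</span> <span class="pre">j]</span></code> is assumed to hold the probability of moving from state <span class="math notranslate nohighlight">\(i\)</span> to state <span class="math notranslate nohighlight">\(j\)</span>, <code class="docutils literal notranslate"><span class="pre">lik[t,</span> <span class="pre">k]</span></code> is assumed to hold the likelihood of the observation at step <span class="math notranslate nohighlight">\(t\)</span> under state <span class="math notranslate nohighlight">\(k\)</span>, and all numbers are made up.</p>
```python
import numpy as np

def hmm_filter(init, trans, lik):
    """Forward pass: row t is the filtering distribution p(z_t | y_{1:t})."""
    T, K = lik.shape
    filtered = np.zeros((T, K))
    pred = init  # p(z_1) before seeing any data
    for t in range(T):
        unnorm = pred * lik[t]
        filtered[t] = unnorm / unnorm.sum()
        pred = filtered[t] @ trans  # one-step-ahead state prediction
    return filtered

def hmm_smoother(init, trans, lik):
    """Forward-backward pass: row t is the smoothing distribution p(z_t | y_{1:T})."""
    T, K = lik.shape
    filtered = hmm_filter(init, trans, lik)
    smoothed = np.zeros((T, K))
    smoothed[-1] = filtered[-1]
    for t in range(T - 2, -1, -1):
        pred = filtered[t] @ trans  # p(z_{t+1} | y_{1:t})
        smoothed[t] = filtered[t] * (trans @ (smoothed[t + 1] / pred))
    return smoothed

def hmm_fixed_lag(init, trans, lik, t, lag):
    """Fixed-lag smoothing p(z_{t-lag} | y_{1:t}), with t counted from 1 as in the text."""
    return hmm_smoother(init, trans, lik[:t])[t - lag - 1]

# Made-up two-state example
init = np.array([0.5, 0.5])
trans = np.array([[0.95, 0.05], [0.10, 0.90]])
lik = np.array([[0.30, 0.10], [0.20, 0.45], [0.15, 0.50], [0.25, 0.40], [0.35, 0.05]])

filtered = hmm_filter(init, trans, lik)               # rows: p(z_t | y_{1:t})
smoothed = hmm_smoother(init, trans, lik)             # rows: p(z_t | y_{1:T})
lagged = hmm_fixed_lag(init, trans, lik, t=4, lag=2)  # p(z_2 | y_{1:4})
```
- <p>At the final time step the filtering and smoothing distributions coincide, since no future data remain to condition on.</p>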
- <p>We may also want to compute the
- predictive distribution <span class="math notranslate nohighlight">\(h\)</span> steps into the future:</p>
- <div class="amsmath math notranslate nohighlight" id="equation-47f5e405-e42b-41c8-9bea-867b94bfd570">
- <span class="eqno">()<a class="headerlink" href="#equation-47f5e405-e42b-41c8-9bea-867b94bfd570" title="Permalink to this equation">¶</a></span>\[\begin{align}
- p(\hmmobs_{t+h}|\hmmobs_{1:t})
- &= \sum_{\hmmhid_{t+h}} p(\hmmobs_{t+h}|\hmmhid_{t+h}) p(\hmmhid_{t+h}|\hmmobs_{1:t})
- \end{align}\]</div>
- <p>where the hidden state predictive distribution is</p>
- <div class="amsmath math notranslate nohighlight" id="equation-9fb3fbe3-19d8-4734-96f6-00ee598c2510">
- <span class="eqno">()<a class="headerlink" href="#equation-9fb3fbe3-19d8-4734-96f6-00ee598c2510" title="Permalink to this equation">¶</a></span>\[\begin{align}
- p(\hmmhid_{t+h}|\hmmobs_{1:t})
- &= \sum_{\hmmhid_{t:t+h-1}}
- p(\hmmhid_t|\hmmobs_{1:t})
- p(\hmmhid_{t+1}|\hmmhid_{t})
- p(\hmmhid_{t+2}|\hmmhid_{t+1})
- \cdots
- p(\hmmhid_{t+h}|\hmmhid_{t+h-1})
- \end{align}\]</div>
- <p>See <a class="reference internal" href="#fig-dbn-inference"><span class="std std-ref">the inference figure above</span></a> for a summary of these distributions.</p>
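- <p>For a discrete HMM, the two predictive sums above reduce to matrix-vector products. The sketch below assumes a row-stochastic transition matrix <code class="docutils literal notranslate"><span class="pre">trans</span></code> and an emission matrix <code class="docutils literal notranslate"><span class="pre">emis</span></code> whose row <span class="math notranslate nohighlight">\(k\)</span> is the observation distribution of state <span class="math notranslate nohighlight">\(k\)</span>; the function names and the numbers (a fair die versus a loaded die) are illustrative assumptions.</p>
```python
import numpy as np

def predict_states(filtered_t, trans, horizon):
    """p(z_{t+h} | y_{1:t}): push the filtered belief through the transitions h times."""
    belief = np.asarray(filtered_t, dtype=float)
    for _ in range(horizon):
        belief = belief @ trans
    return belief

def predict_observations(filtered_t, trans, emis, horizon):
    """p(y_{t+h} | y_{1:t}): average the emission rows under the predicted state belief."""
    return predict_states(filtered_t, trans, horizon) @ emis

# Made-up filtered belief over (fair, loaded) at the current time t
filtered_t = np.array([0.7, 0.3])
trans = np.array([[0.95, 0.05], [0.10, 0.90]])
emis = np.array([[1 / 6] * 6,             # fair die
                 [1 / 10] * 5 + [1 / 2]])  # loaded die favours a six

pred_state = predict_states(filtered_t, trans, horizon=3)
pred_obs = predict_observations(filtered_t, trans, emis, horizon=3)
```
- <p>As the horizon grows, the predicted state belief forgets the current data and approaches the stationary distribution of the chain, provided one exists.</p>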
- <p>In addition to computing posterior marginals,
- we may want to compute the most probable hidden sequence,
- i.e., the joint MAP estimate</p>
- <div class="math notranslate nohighlight">
- \[\arg \max_{\hmmhid_{1:T}} p(\hmmhid_{1:T}|\hmmobs_{1:T})\]</div>
- <p>or sample sequences from the posterior</p>
- <div class="math notranslate nohighlight">
- \[\hmmhid_{1:T} \sim p(\hmmhid_{1:T}|\hmmobs_{1:T})\]</div>
- <p>Algorithms for all of these tasks are discussed in the following chapters,
- since the details depend on the form of the SSM.</p>
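- <p>To make the joint MAP estimate concrete for the discrete case, here is a minimal log-space Viterbi sketch in NumPy. It is an illustration under assumptions, not the implementation used in later chapters: the function name and all numbers are made up, and every probability is assumed to be strictly positive so that the logarithms are finite.</p>
```python
import numpy as np

def viterbi(init, trans, lik):
    """Most probable state sequence argmax_z p(z_{1:T} | y_{1:T}) for a discrete HMM."""
    T, K = lik.shape
    log_trans = np.log(trans)
    delta = np.log(init) + np.log(lik[0])  # best log-score of any path ending in each state
    backptr = np.zeros((T, K), dtype=int)
    for t in range(1, T):
        scores = delta[:, None] + log_trans  # scores[i, j]: best path into i, then step i -> j
        backptr[t] = scores.argmax(axis=0)
        delta = scores.max(axis=0) + np.log(lik[t])
    # Trace the best path backwards from the best final state
    path = np.zeros(T, dtype=int)
    path[-1] = delta.argmax()
    for t in range(T - 1, 0, -1):
        path[t - 1] = backptr[t, path[t]]
    return path

# Made-up two-state example; lik[t, k] stands for p(y_t | z_t = k)
init = np.array([0.5, 0.5])
trans = np.array([[0.95, 0.05], [0.10, 0.90]])
lik = np.array([[0.30, 0.10], [0.20, 0.45], [0.15, 0.50], [0.25, 0.40], [0.35, 0.05]])
map_path = viterbi(init, trans, lik)
```
- <p>Note that the joint MAP sequence need not coincide with the sequence obtained by taking the most probable state of each posterior marginal separately.</p>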
- <div class="section" id="example-inference-in-the-casino-hmm">
- <h2>Example: inference in the casino HMM<a class="headerlink" href="#example-inference-in-the-casino-hmm" title="Permalink to this headline">¶</a></h2>
- <p>We now illustrate filtering, smoothing and MAP decoding applied
- to the casino HMM from <span class="xref std std-ref">sec:casino</span>.</p>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Call inference engine</span>
- <span class="n">filtered_dist</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">smoothed_dist</span><span class="p">,</span> <span class="n">loglik</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">forward_backward</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span>
- <span class="n">map_path</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">viterbi</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
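- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Find the intervals of time steps during which the simulated system is in state 1</span>
- <span class="k">def</span> <span class="nf">find_dishonest_intervals</span><span class="p">(</span><span class="n">z_hist</span><span class="p">):</span>
-     <span class="c1"># Assumes one integer state label (0 or 1) per time step; flatten so each comparison below is a scalar</span>
-     <span class="n">z_hist</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">z_hist</span><span class="p">)</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span>
-     <span class="n">spans</span> <span class="o">=</span> <span class="p">[]</span>
-     <span class="n">x_init</span> <span class="o">=</span> <span class="mi">0</span>
-     <span class="k">for</span> <span class="n">t</span><span class="p">,</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">z_hist</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]):</span>
-         <span class="k">if</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
-             <span class="n">x_end</span> <span class="o">=</span> <span class="n">t</span>
-             <span class="n">spans</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">x_init</span><span class="p">,</span> <span class="n">x_end</span><span class="p">))</span>
-         <span class="k">elif</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">z_hist</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
-             <span class="n">x_init</span> <span class="o">=</span> <span class="n">t</span> <span class="o">+</span> <span class="mi">1</span>
-     <span class="c1"># Close an interval that is still open at the end of the sequence</span>
-     <span class="k">if</span> <span class="n">z_hist</span><span class="o">.</span><span class="n">size</span> <span class="ow">and</span> <span class="n">z_hist</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
-         <span class="n">spans</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">x_init</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">z_hist</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
-     <span class="k">return</span> <span class="n">spans</span>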
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Plot posterior</span>
- <span class="k">def</span> <span class="nf">plot_inference</span><span class="p">(</span><span class="n">inference_values</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">state</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">map_estimate</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
-     <span class="n">n_samples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">inference_values</span><span class="p">)</span>
-     <span class="n">xspan</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_samples</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
-     <span class="n">spans</span> <span class="o">=</span> <span class="n">find_dishonest_intervals</span><span class="p">(</span><span class="n">z_hist</span><span class="p">)</span>
-     <span class="k">if</span> <span class="n">map_estimate</span><span class="p">:</span>
-         <span class="n">ax</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">xspan</span><span class="p">,</span> <span class="n">inference_values</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">"post"</span><span class="p">)</span>
-     <span class="k">else</span><span class="p">:</span>
-         <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">xspan</span><span class="p">,</span> <span class="n">inference_values</span><span class="p">[:,</span> <span class="n">state</span><span class="p">])</span>
-     <span class="k">for</span> <span class="n">span</span> <span class="ow">in</span> <span class="n">spans</span><span class="p">:</span>
-         <span class="n">ax</span><span class="o">.</span><span class="n">axvspan</span><span class="p">(</span><span class="o">*</span><span class="n">span</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">facecolor</span><span class="o">=</span><span class="s2">"tab:gray"</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s2">"none"</span><span class="p">)</span>
-     <span class="n">ax</span><span class="o">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_samples</span><span class="p">)</span>
-     <span class="c1"># ax.set_ylim(0, 1)</span>
-     <span class="n">ax</span><span class="o">.</span><span class="n">set_ylim</span><span class="p">(</span><span class="o">-</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">1.1</span><span class="p">)</span>
-     <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s2">"Observation number"</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span> <span class="c1"># Filtering</span>
- <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
- <span class="n">plot_inference</span><span class="p">(</span><span class="n">filtered_dist</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">)</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">"p(loaded)"</span><span class="p">)</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Filtered"</span><span class="p">)</span>
-
-
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <img alt="../_images/ssm_old_30_1.png" src="../_images/ssm_old_30_1.png" />
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Smoothing</span>
- <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
- <span class="n">plot_inference</span><span class="p">(</span><span class="n">smoothed_dist</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">)</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">"p(loaded)"</span><span class="p">)</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Smoothed"</span><span class="p">)</span>
-
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <img alt="../_images/ssm_old_31_1.png" src="../_images/ssm_old_31_1.png" />
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># MAP estimation</span>
- <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
- <span class="n">plot_inference</span><span class="p">(</span><span class="n">map_path</span><span class="p">,</span> <span class="n">z_hist</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">map_estimate</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">"MAP state"</span><span class="p">)</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Viterbi"</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># TODO: posterior samples</span>
- </pre></div>
- </div>
- </div>
- </div>
- </div>
- <div class="section" id="example-inference-in-the-tracking-ssm">
- <h2>Example: inference in the tracking SSM<a class="headerlink" href="#example-inference-in-the-tracking-ssm" title="Permalink to this headline">¶</a></h2>
- <p>We now illustrate filtering, smoothing and MAP decoding applied
- to the 2d tracking SSM from <a class="reference internal" href="#sec-tracking-lds"><span class="std std-ref">Example: tracking a 2d point</span></a>.</p>
- </div>
- </div>
- <script type="text/x-thebe-config">
- {
- requestKernel: true,
- binderOptions: {
- repo: "binder-examples/jupyter-stacks-datascience",
- ref: "master",
- },
- codeMirrorConfig: {
- theme: "abcdef",
- mode: "python"
- },
- kernelOptions: {
- kernelName: "python3",
- path: "./old"
- },
- predefinedOutput: true
- }
- </script>
- <script>kernelName = 'python3'</script>
- </div>
-
-
- <!-- Previous / next buttons -->
- <div class='prev-next-area'>
- </div>
-
- </div>
- </div>
- <footer class="footer">
- <p>
-
- By Kevin Murphy, Scott Linderman, et al.<br/>
-
- © Copyright 2021.<br/>
- </p>
- </footer>
- </main>
- </div>
- </div>
-
- <script src="../_static/js/index.be7d3bbb2ef33a8344ce.js"></script>
- </body>
- </html>
|