| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985 |
- <!DOCTYPE html>
- <html lang="en">
- <head>
- <meta charset="utf-8" />
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
- <title>Hidden Markov Models — State Space Models: A Modern Approach</title>
-
- <link href="../../_static/css/theme.css" rel="stylesheet">
- <link href="../../_static/css/index.ff1ffe594081f20da1ef19478df9384b.css" rel="stylesheet">
-
- <link rel="stylesheet"
- href="../../_static/vendor/fontawesome/5.13.0/css/all.min.css">
- <link rel="preload" as="font" type="font/woff2" crossorigin
- href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
- <link rel="preload" as="font" type="font/woff2" crossorigin
- href="../../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
-
-
-
- <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" />
- <link rel="stylesheet" type="text/css" href="../../_static/sphinx-book-theme.css?digest=c3fdc42140077d1ad13ad2f1588a4309" />
- <link rel="stylesheet" type="text/css" href="../../_static/togglebutton.css" />
- <link rel="stylesheet" type="text/css" href="../../_static/copybutton.css" />
- <link rel="stylesheet" type="text/css" href="../../_static/mystnb.css" />
- <link rel="stylesheet" type="text/css" href="../../_static/sphinx-thebe.css" />
- <link rel="stylesheet" type="text/css" href="../../_static/panels-main.c949a650a448cc0ae9fd3441c0e17fb0.css" />
- <link rel="stylesheet" type="text/css" href="../../_static/panels-variables.06eb56fa6e07937060861dad626602ad.css" />
-
- <link rel="preload" as="script" href="../../_static/js/index.be7d3bbb2ef33a8344ce.js">
- <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
- <script src="../../_static/jquery.js"></script>
- <script src="../../_static/underscore.js"></script>
- <script src="../../_static/doctools.js"></script>
- <script src="../../_static/clipboard.min.js"></script>
- <script src="../../_static/copybutton.js"></script>
- <script>let toggleHintShow = 'Click to show';</script>
- <script>let toggleHintHide = 'Click to hide';</script>
- <script>let toggleOpenOnPrint = 'true';</script>
- <script src="../../_static/togglebutton.js"></script>
- <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
- <script src="../../_static/sphinx-book-theme.d59cb220de22ca1c485ebbdc042f0030.js"></script>
- <script>const THEBE_JS_URL = "https://unpkg.com/thebe@0.8.2/lib/index.js"
- const thebe_selector = ".thebe,.cell"
- const thebe_selector_input = "pre"
- const thebe_selector_output = ".output, .cell_output"
- </script>
- <script async="async" src="../../_static/sphinx-thebe.js"></script>
- <script>window.MathJax = {"tex": {"macros": {"covMat": "\\boldsymbol{\\Sigma}", "data": "\\mathcal{D}", "defeq": "\\triangleq", "diag": "\\mathrm{diag}", "discreteState": "s", "dotstar": "\\odot", "dynamicsFn": "\\mathbf{f}", "floor": ["\\lfloor#1\\rfloor", 1], "gainMatrix": "\\mathbf{K}", "gainMatrixReverse": "\\mathbf{G}", "gauss": "\\mathcal{N}", "gaussInfo": "\\mathcal{N}_{\\text{info}}", "hidden": "\\mathbf{x}", "hiddenScalar": "x", "hmmInit": "\\boldsymbol{\\pi}", "hmmInitScalar": "\\pi", "hmmObs": "\\mathbf{B}", "hmmObsScalar": "B", "hmmTrans": "\\mathbf{A}", "hmmTransScalar": "A", "infoMat": "\\precMat", "input": "\\mathbf{u}", "inputs": "\\input", "inv": ["{#1}^{-1}", 1], "keyword": ["\\textbf{#1}", 1], "ldsDyn": "\\mathbf{F}", "ldsDynIn": "\\mathbf{B}", "initMean": "\\boldsymbol{\\mean}_0", "initCov": "\\boldsymbol{\\covMat}_0", "ldsObs": "\\mathbf{H}", "ldsObsIn": "\\mathbf{D}", "ldsTrans": "\\ldsDyn", "ldsTransIn": "\\ldsDynIn", "obsCov": "\\mathbf{R}", "obsNoise": "\\boldsymbol{r}", "map": "\\mathrm{map}", "measurementFn": "\\mathbf{h}", "mean": "\\boldsymbol{\\mu}", "mle": "\\mathrm{mle}", "nlatents": "n_x", "nhidden": "\\nlatents", "ninputs": "n_u", "nobs": "n_y", "nsymbols": "n_y", "nstates": "n_s", "obs": "\\mathbf{y}", "obsScalar": "y", "observed": "\\obs", "obsFn": "\\measurementFn", "params": "\\boldsymbol{\\theta}", "precMean": "\\boldsymbol{\\eta}", "precMat": "\\boldsymbol{\\Lambda}", "real": "\\mathbb{R}", "sigmoid": "\\sigma", "softmax": "\\boldsymbol{\\sigma}", "trans": "\\mathsf{T}", "transpose": ["{#1}^{\\trans}", 1], "transCov": "\\mathbf{Q}", "transNoise": "\\mathbf{q}", "valpha": "\\boldsymbol{\\alpha}", "vbeta": "\\boldsymbol{\\beta}", "vdelta": "\\boldsymbol{\\delta}", "vepsilon": "\\boldsymbol{\\epsilon}", "vlambda": "\\boldsymbol{\\lambda}", "vLambda": "\\boldsymbol{\\Lambda}", "vmu": "\\boldsymbol{\\mu}", "vpi": "\\boldsymbol{\\pi}", "vsigma": "\\boldsymbol{\\sigma}", "vSigma": "\\boldsymbol{\\Sigma}", "vone": 
"\\boldsymbol{1}", "vzero": "\\boldsymbol{0}"}}, "options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
- <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
- <link rel="index" title="Index" href="../../genindex.html" />
- <link rel="search" title="Search" href="../../search.html" />
- <link rel="next" title="Linear Gaussian SSMs" href="lds.html" />
- <link rel="prev" title="What are State Space Models?" href="ssm_intro.html" />
- <meta name="viewport" content="width=device-width, initial-scale=1" />
- <meta name="docsearch:language" content="None">
-
- <!-- Google Analytics -->
-
- </head>
- <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80">
-
- <div class="container-fluid" id="banner"></div>
-
- <div class="container-xl">
- <div class="row">
-
- <div class="col-12 col-md-3 bd-sidebar site-navigation show" id="site-navigation">
-
- <div class="navbar-brand-box">
- <a class="navbar-brand text-wrap" href="../../index.html">
-
-
-
- <h1 class="site-logo" id="site-title">State Space Models: A Modern Approach</h1>
-
- </a>
- </div><form class="bd-search d-flex align-items-center" action="../../search.html" method="get">
- <i class="icon fas fa-search"></i>
- <input type="search" class="form-control" name="q" id="search-input" placeholder="Search this book..." aria-label="Search this book..." autocomplete="off" >
- </form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
- <div class="bd-toc-item active">
- <ul class="nav bd-sidenav">
- <li class="toctree-l1">
- <a class="reference internal" href="../../root.html">
- State Space Models: A Modern Approach
- </a>
- </li>
- </ul>
- <ul class="current nav bd-sidenav">
- <li class="toctree-l1 current active has-children">
- <a class="reference internal" href="ssm_index.html">
- State Space Models
- </a>
- <input checked="" class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
- <label for="toctree-checkbox-1">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul class="current">
- <li class="toctree-l2">
- <a class="reference internal" href="ssm_intro.html">
- What are State Space Models?
- </a>
- </li>
- <li class="toctree-l2 current active">
- <a class="current reference internal" href="#">
- Hidden Markov Models
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="lds.html">
- Linear Gaussian SSMs
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="nlds.html">
- Nonlinear Gaussian SSMs
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="inference.html">
- States estimation (inference)
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="learning.html">
- Parameter estimation (learning)
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../hmm/hmm_index.html">
- Hidden Markov Models
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
- <label for="toctree-checkbox-2">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../hmm/hmm_filter.html">
- HMM filtering (forwards algorithm)
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../hmm/hmm_smoother.html">
- HMM smoothing (forwards-backwards algorithm)
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../hmm/hmm_viterbi.html">
- Viterbi algorithm
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../hmm/hmm_parallel.html">
- Parallel HMM smoothing
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../hmm/hmm_sampling.html">
- Forwards-filtering backwards-sampling algorithm
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../lgssm/lgssm_index.html">
- Linear-Gaussian SSMs
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
- <label for="toctree-checkbox-3">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../lgssm/kalman_filter.html">
- Kalman filtering
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../lgssm/kalman_smoother.html">
- Kalman (RTS) smoother
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../lgssm/kalman_parallel.html">
- Parallel Kalman Smoother
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../lgssm/kalman_sampling.html">
- Forwards-filtering backwards sampling
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../extended/extended_index.html">
- Extended (linearized) methods
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
- <label for="toctree-checkbox-4">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../extended/extended_filter.html">
- Extended Kalman filtering
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../extended/extended_smoother.html">
- Extended Kalman smoother
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../extended/extended_parallel.html">
- Parallel extended Kalman smoothing
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1 has-children">
- <a class="reference internal" href="../unscented/unscented_index.html">
- Unscented methods
- </a>
- <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
- <label for="toctree-checkbox-5">
- <i class="fas fa-chevron-down">
- </i>
- </label>
- <ul>
- <li class="toctree-l2">
- <a class="reference internal" href="../unscented/unscented_filter.html">
- Unscented filtering
- </a>
- </li>
- <li class="toctree-l2">
- <a class="reference internal" href="../unscented/unscented_smoother.html">
- Unscented smoothing
- </a>
- </li>
- </ul>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../quadrature/quadrature_index.html">
- Quadrature and cubature methods
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../postlin/postlin_index.html">
- Posterior linearization
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../adf/adf_index.html">
- Assumed Density Filtering
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../vi/vi_index.html">
- Variational inference
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../pf/pf_index.html">
- Particle filtering
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../smc/smc_index.html">
- Sequential Monte Carlo
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../learning/learning_index.html">
- Offline parameter estimation (learning)
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../tracking/tracking_index.html">
- Multi-target tracking
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../ensemble/ensemble_index.html">
- Data assimilation using Ensemble Kalman filter
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../bnp/bnp_index.html">
- Bayesian non-parametric SSMs
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../changepoint/changepoint_index.html">
- Changepoint detection
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../timeseries/timeseries_index.html">
- Timeseries forecasting
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../gp/gp_index.html">
- Markovian Gaussian processes
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../ode/ode_index.html">
- Differential equations and SSMs
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../control/control_index.html">
- Optimal control
- </a>
- </li>
- <li class="toctree-l1">
- <a class="reference internal" href="../../bib.html">
- Bibliography
- </a>
- </li>
- </ul>
- </div>
- </nav> <!-- To handle the deprecated key -->
- <div class="navbar_extra_footer">
- Powered by <a href="https://jupyterbook.org">Jupyter Book</a>
- </div>
- </div>
-
-
- <main class="col py-md-3 pl-md-4 bd-content overflow-auto" role="main">
-
- <div class="topbar container-xl fixed-top">
- <div class="topbar-contents row">
- <div class="col-12 col-md-3 bd-topbar-whitespace site-navigation show"></div>
- <div class="col pl-md-4 topbar-main">
-
- <button id="navbar-toggler" class="navbar-toggler ml-0" type="button" data-toggle="collapse"
- data-placement="bottom" data-target=".site-navigation" aria-controls="site-navigation"
- aria-expanded="true" aria-label="Toggle navigation"
- title="Toggle navigation">
- <i class="fas fa-bars"></i>
- <i class="fas fa-arrow-left"></i>
- <i class="fas fa-arrow-up"></i>
- </button>
-
-
- <div class="dropdown-buttons-trigger">
- <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn" aria-label="Download this page"><i
- class="fas fa-download"></i></button>
- <div class="dropdown-buttons">
- <!-- ipynb file if we had a myst markdown file -->
-
- <!-- Download raw file -->
- <a class="dropdown-buttons" href="../../_sources/chapters/ssm/hmm.ipynb"><button type="button"
- class="btn btn-secondary topbarbtn" title="Download source file" data-toggle="tooltip"
- data-placement="left">.ipynb</button></a>
- <!-- Download PDF via print -->
- <button type="button" id="download-print" class="btn btn-secondary topbarbtn" title="Print to PDF"
- onclick="printPdf(this)" data-toggle="tooltip" data-placement="left">.pdf</button>
- </div>
- </div>
- <!-- Source interaction buttons -->
- <div class="dropdown-buttons-trigger">
- <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
- aria-label="Connect with source repository"><i class="fab fa-github"></i></button>
- <div class="dropdown-buttons sourcebuttons">
- <a class="repository-button"
- href="https://github.com/probml/ssm-book"><button type="button" class="btn btn-secondary topbarbtn"
- data-toggle="tooltip" data-placement="left" title="Source repository"><i
- class="fab fa-github"></i>repository</button></a>
- <a class="issues-button"
- href="https://github.com/probml/ssm-book/issues/new?title=Issue%20on%20page%20%2Fchapters/ssm/hmm.html&body=Your%20issue%20content%20here."><button
- type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip" data-placement="left"
- title="Open an issue"><i class="fas fa-lightbulb"></i>open issue</button></a>
-
- </div>
- </div>
- <!-- Full screen (wrap in <a> to have style consistency -->
- <a class="full-screen-button"><button type="button" class="btn btn-secondary topbarbtn" data-toggle="tooltip"
- data-placement="bottom" onclick="toggleFullScreen()" aria-label="Fullscreen mode"
- title="Fullscreen mode"><i
- class="fas fa-expand"></i></button></a>
- <!-- Launch buttons -->
- <div class="dropdown-buttons-trigger">
- <button id="dropdown-buttons-trigger" class="btn btn-secondary topbarbtn"
- aria-label="Launch interactive content"><i class="fas fa-rocket"></i></button>
- <div class="dropdown-buttons">
-
- <a class="binder-button" href="https://mybinder.org/v2/gh/probml/ssm-book/main?urlpath=tree/chapters/ssm/hmm.ipynb"><button type="button"
- class="btn btn-secondary topbarbtn" title="Launch Binder" data-toggle="tooltip"
- data-placement="left"><img class="binder-button-logo"
- src="../../_static/images/logo_binder.svg"
- alt="Interact on binder">Binder</button></a>
-
-
-
- <a class="colab-button" href="https://colab.research.google.com/github/probml/ssm-book/blob/main/chapters/ssm/hmm.ipynb"><button type="button" class="btn btn-secondary topbarbtn"
- title="Launch Colab" data-toggle="tooltip" data-placement="left"><img class="colab-button-logo"
- src="../../_static/images/logo_colab.png"
- alt="Interact on Colab">Colab</button></a>
-
-
- </div>
- </div>
- </div>
- <!-- Table of contents -->
- <div class="d-none d-md-block col-md-2 bd-toc show noprint">
-
- <div class="tocsection onthispage pt-5 pb-3">
- <i class="fas fa-list"></i> Contents
- </div>
- <nav id="bd-toc-nav" aria-label="Page">
- <ul class="visible nav section-nav flex-column">
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#example-casino-hmm">
- Example: Casino HMM
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#example-lillypad-hmm">
- Example: Lillypad HMM
- </a>
- </li>
- </ul>
- </nav>
- </div>
- </div>
- </div>
- <div id="main-content" class="row">
- <div class="col-12 col-md-9 pl-md-3 pr-md-0">
- <!-- Table of contents that is only displayed when printing the page -->
- <div id="jb-print-docs-body" class="onlyprint">
- <h1>Hidden Markov Models</h1>
- <!-- Table of contents -->
- <div id="print-main-content">
- <div id="jb-print-toc">
-
- <div>
- <h2> Contents </h2>
- </div>
- <nav aria-label="Page">
- <ul class="visible nav section-nav flex-column">
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#example-casino-hmm">
- Example: Casino HMM
- </a>
- </li>
- <li class="toc-h2 nav-item toc-entry">
- <a class="reference internal nav-link" href="#example-lillypad-hmm">
- Example: Lillypad HMM
- </a>
- </li>
- </ul>
- </nav>
- </div>
- </div>
- </div>
-
- <div>
-
- <div class="tex2jax_ignore mathjax_ignore section" id="hidden-markov-models">
- <span id="sec-hmm-intro"></span><h1>Hidden Markov Models<a class="headerlink" href="#hidden-markov-models" title="Permalink to this headline">¶</a></h1>
- <p>In this section, we discuss the
- hidden Markov model or HMM,
- which is a state space model in which the hidden states
- are discrete, so <span class="math notranslate nohighlight">\(\hidden_t \in \{1,\ldots, \nstates\}\)</span>.
- The observations may be discrete,
- <span class="math notranslate nohighlight">\(\obs_t \in \{1,\ldots, \nsymbols\}\)</span>,
- or continuous,
- <span class="math notranslate nohighlight">\(\obs_t \in \real^{\nobs}\)</span>,
- or some combination,
- as we illustrate below.
- More details can be found in, e.g.,
- <span id="id1">[<a class="reference internal" href="../../bib.html#id34" title="O. Cappe, E. Moulines, and T. Ryden. Inference in Hidden Markov Models. Springer, 2005.">CMR05</a>, <a class="reference internal" href="../../bib.html#id33" title="A. Fraser. Hidden Markov Models and Dynamical Systems. SIAM Press, 2008.">Fra08</a>, <a class="reference internal" href="../../bib.html#id32" title="L. R. Rabiner. A tutorial on Hidden Markov Models and selected applications in speech recognition. Proc. of the IEEE, 77(2):257–286, 1989.">Rab89</a>]</span>.
- For an interactive introduction,
- see <a class="reference external" href="https://nipunbatra.github.io/hmm/">https://nipunbatra.github.io/hmm/</a>.</p>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="p">{</span>
- <span class="s2">"tags"</span><span class="p">:</span> <span class="p">[</span>
- <span class="s2">"hide-cell"</span>
- <span class="p">]</span>
- <span class="p">}</span>
- <span class="c1">### Import standard libraries</span>
- <span class="kn">import</span> <span class="nn">abc</span>
- <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
- <span class="kn">import</span> <span class="nn">functools</span>
- <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
- <span class="kn">import</span> <span class="nn">itertools</span>
- <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
- <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
- <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
- <span class="kn">import</span> <span class="nn">jax</span>
- <span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
- <span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">lax</span><span class="p">,</span> <span class="n">vmap</span><span class="p">,</span> <span class="n">jit</span><span class="p">,</span> <span class="n">grad</span>
- <span class="c1">#from jax.scipy.special import logit</span>
- <span class="c1">#from jax.nn import softmax</span>
- <span class="kn">import</span> <span class="nn">jax.random</span> <span class="k">as</span> <span class="nn">jr</span>
- <span class="kn">import</span> <span class="nn">distrax</span>
- <span class="kn">import</span> <span class="nn">optax</span>
- <span class="kn">import</span> <span class="nn">jsl</span>
- <span class="kn">import</span> <span class="nn">ssm_jax</span>
- <span class="kn">import</span> <span class="nn">inspect</span>
- <span class="kn">import</span> <span class="nn">inspect</span> <span class="k">as</span> <span class="nn">py_inspect</span>
- <span class="kn">import</span> <span class="nn">rich</span>
- <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="n">inspect</span> <span class="k">as</span> <span class="n">r_inspect</span>
- <span class="kn">from</span> <span class="nn">rich</span> <span class="kn">import</span> <span class="nb">print</span> <span class="k">as</span> <span class="n">r_print</span>
- <span class="k">def</span> <span class="nf">print_source</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span>
- <span class="n">r_print</span><span class="p">(</span><span class="n">py_inspect</span><span class="o">.</span><span class="n">getsource</span><span class="p">(</span><span class="n">fname</span><span class="p">))</span>
- </pre></div>
- </div>
- </div>
- </div>
- <div class="section" id="example-casino-hmm">
- <h2>Example: Casino HMM<a class="headerlink" href="#example-casino-hmm" title="Permalink to this headline">¶</a></h2>
- <p>To illustrate HMMs with a categorical observation model,
- we consider the “Occasionally dishonest casino” model from <span id="id2">[<a class="reference internal" href="../../bib.html#id3" title="R. Durbin, S. Eddy, A. Krogh, and G. Mitchison. Biological Sequence Analysis: Probabilistic Models of Proteins and Nucleic Acids. Cambridge University Press, 1998.">DEKM98</a>]</span>.
- There are 2 hidden states, representing whether the die being used in the casino is fair or loaded.
- Each state defines a distribution over the 6 possible observations.</p>
- <p>The transition model is denoted by</p>
- <div class="math notranslate nohighlight">
- \[p(\hidden_t=j|\hidden_{t-1}=i) = \hmmTransScalar_{ij}\]</div>
- <p>Here the <span class="math notranslate nohighlight">\(i\)</span>’th row of <span class="math notranslate nohighlight">\(\hmmTrans\)</span> corresponds to the outgoing distribution from state <span class="math notranslate nohighlight">\(i\)</span>.
- This is a row stochastic matrix,
- meaning each row sums to one.
- We can visualize
- the non-zero entries in the transition matrix by creating a state transition diagram,
- as shown in
- <a class="reference internal" href="#fig-casino"><span class="std std-numref">Fig. 3</span></a>.</p>
- <div class="figure align-default" id="fig-casino">
- <a class="reference internal image-reference" href="../../_images/casino.png"><img alt="../../_images/casino.png" src="../../_images/casino.png" style="width: 208.5px; height: 142.5px;" /></a>
- <p class="caption"><span class="caption-number">Fig. 3 </span><span class="caption-text">Illustration of the casino HMM.</span><a class="headerlink" href="#fig-casino" title="Permalink to this image">¶</a></p>
- </div>
- <p>The observation model
- <span class="math notranslate nohighlight">\(p(\obs_t|\hidden_t=j)\)</span> has the form</p>
- <div class="math notranslate nohighlight">
- \[p(\obs_t=k|\hidden_t=j) = \hmmObsScalar_{jk} \]</div>
- <p>This is represented by the histograms associated with each
- state in <a class="reference internal" href="#fig-casino"><span class="std std-numref">Fig. 3</span></a>.</p>
- <p>Finally,
- the initial state distribution is denoted by</p>
- <div class="math notranslate nohighlight">
- \[p(\hidden_1=j) = \hmmInitScalar_j\]</div>
- <p>Collectively we denote all the parameters by <span class="math notranslate nohighlight">\(\params=(\hmmTrans, \hmmObs, \hmmInit)\)</span>.</p>
- <p>Now let us implement this model in code.</p>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># state transition matrix</span>
- <span class="n">A</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
- <span class="p">[</span><span class="mf">0.95</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">],</span>
- <span class="p">[</span><span class="mf">0.10</span><span class="p">,</span> <span class="mf">0.90</span><span class="p">]</span>
- <span class="p">])</span>
- <span class="c1"># observation matrix</span>
- <span class="n">B</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
- <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">6</span><span class="p">],</span> <span class="c1"># fair die</span>
- <span class="p">[</span><span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="o">/</span><span class="mi">10</span><span class="p">]</span> <span class="c1"># loaded die</span>
- <span class="p">])</span>
- <span class="n">pi</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">])</span>
- <span class="p">(</span><span class="n">nstates</span><span class="p">,</span> <span class="n">nobs</span><span class="p">)</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <!-- Builds the casino HMM with distrax: trans_dist from A (defined in an earlier cell, not visible here), init_dist from pi, obs_dist from B, each wrapped in distrax.Categorical, then prints the HMM object. The stderr output below is an absl warning that no GPU/TPU was found (CPU fallback), not an error. -->
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">distrax</span>
- <span class="kn">from</span> <span class="nn">distrax</span> <span class="kn">import</span> <span class="n">HMM</span>
- <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
- <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">pi</span><span class="p">),</span>
- <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">B</span><span class="p">))</span>
- <span class="nb">print</span><span class="p">(</span><span class="n">hmm</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output stderr highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
- </pre></div>
- </div>
- <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span><distrax._src.utils.hmm.HMM object at 0x7fbca906ad60>
- </pre></div>
- </div>
- </div>
- </div>
- <p>Let’s sample from the model. We will generate a sequence of latent states, <span class="math notranslate nohighlight">\(\hidden_{1:T}\)</span>,
- which we then convert to a sequence of observations, <span class="math notranslate nohighlight">\(\obs_{1:T}\)</span>.</p>
- <div class="cell docutils container">
- <!-- Draws a 300-step sample (seed 314) of latent states z_hist and observations x_hist from the casino HMM, then prints the first 60 entries of each with 1 added so that states and die faces display 1-based. -->
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">seed</span> <span class="o">=</span> <span class="mi">314</span>
- <span class="n">n_samples</span> <span class="o">=</span> <span class="mi">300</span>
- <span class="n">z_hist</span><span class="p">,</span> <span class="n">x_hist</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">jr</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">n_samples</span><span class="p">)</span>
- <span class="n">z_hist_str</span> <span class="o">=</span> <span class="s2">""</span><span class="o">.</span><span class="n">join</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">z_hist</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">str</span><span class="p">))[:</span><span class="mi">60</span><span class="p">]</span>
- <span class="n">x_hist_str</span> <span class="o">=</span> <span class="s2">""</span><span class="o">.</span><span class="n">join</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x_hist</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">str</span><span class="p">))[:</span><span class="mi">60</span><span class="p">]</span>
- <span class="nb">print</span><span class="p">(</span><span class="s2">"Printing sample observed/latent..."</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"x: </span><span class="si">{</span><span class="n">x_hist_str</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"z: </span><span class="si">{</span><span class="n">z_hist_str</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Printing sample observed/latent...
- x: 633665342652353616444236412331351246651613325161656366246242
- z: 222222211111111111111111111111111111111222111111112222211111
- </pre></div>
- </div>
- </div>
- </div>
- <p>Below is the source code for the sampling algorithm.</p>
- <div class="cell docutils container">
- <!-- Prints the source of distrax HMM.sample. NOTE(review): in the printed code, draw_state(prev_state, key) never reads prev_state; each step samples self._trans_dist without selecting the row for the previous state. Confirm whether this sampler actually conditions on the previous state before trusting sampled state sequences. -->
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">print_source</span><span class="p">(</span><span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"> def sample<span style="font-weight: bold">(</span>self,
- *,
- seed: chex.PRNGKey,
- seq_len: chex.Array<span style="font-weight: bold">)</span> -> Tuple:
- <span style="color: #008000; text-decoration-color: #008000">""</span>"Sample from this HMM.
- Samples an observation of given length according to this
- Hidden Markov Model and gives the sequence of the hidden states
- as well as the observation.
- Args:
- seed: Random key of shape <span style="font-weight: bold">(</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">2</span>,<span style="font-weight: bold">)</span> and dtype uint32.
- seq_len: The length of the observation sequence.
- Returns:
- Tuple of hidden state sequence, and observation sequence.
- <span style="color: #008000; text-decoration-color: #008000">""</span>"
- rng_key, rng_init = jax.random.split<span style="font-weight: bold">(</span>seed<span style="font-weight: bold">)</span>
- initial_state = self._init_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">rng_init</span><span style="font-weight: bold">)</span>
- def draw_state<span style="font-weight: bold">(</span>prev_state, key<span style="font-weight: bold">)</span>:
- state = self._trans_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key</span><span style="font-weight: bold">)</span>
- return state, state
- rng_state, rng_obs = jax.random.split<span style="font-weight: bold">(</span>rng_key<span style="font-weight: bold">)</span>
- keys = jax.random.split<span style="font-weight: bold">(</span>rng_state, seq_len - <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">1</span><span style="font-weight: bold">)</span>
- _, states = jax.lax.scan<span style="font-weight: bold">(</span>draw_state, initial_state, keys<span style="font-weight: bold">)</span>
- states = jnp.append<span style="font-weight: bold">(</span>initial_state, states<span style="font-weight: bold">)</span>
- def draw_obs<span style="font-weight: bold">(</span>state, key<span style="font-weight: bold">)</span>:
- return self._obs_dist.sample<span style="font-weight: bold">(</span><span style="color: #808000; text-decoration-color: #808000">seed</span>=<span style="color: #800080; text-decoration-color: #800080">key</span><span style="font-weight: bold">)</span>
- keys = jax.random.split<span style="font-weight: bold">(</span>rng_obs, seq_len<span style="font-weight: bold">)</span>
- obs_seq = jax.vmap<span style="font-weight: bold">(</span>draw_obs, <span style="color: #808000; text-decoration-color: #808000">in_axes</span>=<span style="font-weight: bold">(</span><span style="color: #000080; text-decoration-color: #000080; font-weight: bold">0</span>, <span style="color: #000080; text-decoration-color: #000080; font-weight: bold">0</span><span style="font-weight: bold">))(</span>states, keys<span style="font-weight: bold">)</span>
- return states, obs_seq
- </pre>
- </div></div>
- </div>
- <p>Let us check correctness by computing empirical pairwise statistics.</p>
- <p>We will count the number of i->j latent state transitions, normalize each row of the counts, and check that the result is close to the true
- A[i,j] transition probabilities.</p>
- <div class="cell docutils container">
- <!-- Empirical check of the sampler. The cell samples a 500-step latent state sequence from a 2-state HMM, counts i to j transitions (compute_counts), and row-normalizes the counts (normalize_counts). It then asserts the result is within 0.1 of trans_mat. The sample call uses the rng_key defined in this cell; the bare name PRNGKey is not defined in the notebook. Output omitted: re-execute the notebook to regenerate it. NOTE(review): if the assertion fails, check draw_state in the printed HMM.sample source, which does not read prev_state. -->
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">collections</span>
- <span class="k">def</span> <span class="nf">compute_counts</span><span class="p">(</span><span class="n">state_seq</span><span class="p">,</span> <span class="n">nstates</span><span class="p">):</span>
- <span class="n">wseq</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">state_seq</span><span class="p">)</span>
- <span class="n">word_pairs</span> <span class="o">=</span> <span class="p">[</span><span class="n">pair</span> <span class="k">for</span> <span class="n">pair</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">wseq</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">wseq</span><span class="p">[</span><span class="mi">1</span><span class="p">:])]</span>
- <span class="n">counter_pairs</span> <span class="o">=</span> <span class="n">collections</span><span class="o">.</span><span class="n">Counter</span><span class="p">(</span><span class="n">word_pairs</span><span class="p">)</span>
- <span class="n">counts</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">nstates</span><span class="p">,</span> <span class="n">nstates</span><span class="p">))</span>
- <span class="k">for</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span><span class="n">v</span><span class="p">)</span> <span class="ow">in</span> <span class="n">counter_pairs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
- <span class="n">counts</span><span class="p">[</span><span class="n">k</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">k</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span> <span class="o">=</span> <span class="n">v</span>
- <span class="k">return</span> <span class="n">counts</span>
- <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-15</span><span class="p">):</span>
- <span class="n">u</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">u</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">u</span> <span class="o"><</span> <span class="n">eps</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">u</span><span class="p">))</span>
- <span class="n">c</span> <span class="o">=</span> <span class="n">u</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="n">axis</span><span class="p">)</span>
- <span class="n">c</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">c</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
- <span class="k">return</span> <span class="n">u</span> <span class="o">/</span> <span class="n">c</span><span class="p">,</span> <span class="n">c</span>
- <span class="k">def</span> <span class="nf">normalize_counts</span><span class="p">(</span><span class="n">counts</span><span class="p">):</span>
- <span class="n">ncounts</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">v</span><span class="p">:</span> <span class="n">normalize</span><span class="p">(</span><span class="n">v</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> <span class="n">in_axes</span><span class="o">=</span><span class="mi">0</span><span class="p">)(</span><span class="n">counts</span><span class="p">)</span>
- <span class="k">return</span> <span class="n">ncounts</span>
- <span class="n">init_dist</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">])</span>
- <span class="n">trans_mat</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]])</span>
- <span class="n">obs_mat</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
- <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">trans_mat</span><span class="p">),</span>
- <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">init_dist</span><span class="p">),</span>
- <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">obs_mat</span><span class="p">))</span>
- <span class="n">rng_key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
- <span class="n">seq_len</span> <span class="o">=</span> <span class="mi">500</span>
- <span class="n">state_seq</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">rng_key</span><span class="p">,</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">seq_len</span><span class="p">)</span>
- <span class="n">counts</span> <span class="o">=</span> <span class="n">compute_counts</span><span class="p">(</span><span class="n">state_seq</span><span class="p">,</span> <span class="n">nstates</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="n">counts</span><span class="p">)</span>
- <span class="n">trans_mat_empirical</span> <span class="o">=</span> <span class="n">normalize_counts</span><span class="p">(</span><span class="n">counts</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="n">trans_mat_empirical</span><span class="p">)</span>
- <span class="k">assert</span> <span class="n">jnp</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">trans_mat</span><span class="p">,</span> <span class="n">trans_mat_empirical</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-1</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- </div>
- <p>Our primary goal will be to infer the latent state from the observations,
- so we can detect if the casino is being dishonest or not. This will
- affect how we choose to gamble our money.
- We discuss various ways to perform this inference below.</p>
- </div>
- <div class="section" id="example-lillypad-hmm">
- <span id="sec-lillypad"></span><h2>Example: Lillypad HMM<a class="headerlink" href="#example-lillypad-hmm" title="Permalink to this headline">¶</a></h2>
- <p>If <span class="math notranslate nohighlight">\(\obs_t\)</span> is continuous, it is common to use a Gaussian
- observation model:</p>
- <div class="math notranslate nohighlight">
- \[p(\obs_t|\hidden_t=j) = \gauss(\obs_t|\vmu_j,\vSigma_j)\]</div>
- <p>This is sometimes called a Gaussian HMM.</p>
- <p>As a simple example, suppose we have an HMM with 3 hidden states,
- each of which emits observations from its own 2d Gaussian.
- We can represent each of these Gaussian distributions as a 2d ellipse,
- as shown below.
- We call these “lily pads”, because of their shape.
- We can imagine a frog hopping from one lily pad to another.
- (This analogy is due to the late Sam Roweis.)
- The frog will stay on a pad for a while (corresponding to remaining in the same
- discrete state <span class="math notranslate nohighlight">\(\hidden_t\)</span>), and then jump to a new pad
- (corresponding to a transition to a new state).
- The data we see are just the 2d points (e.g., water droplets)
- coming from near the pad that the frog is currently on.
- Thus this model is like a Gaussian mixture model,
- in that it generates clusters of observations,
- except now there is temporal correlation between the data points.</p>
- <p>Let us now illustrate this model in code.</p>
- <div class="cell docutils container">
- <!-- Builds the 3-state Gaussian (lily pad) HMM. Its parameters are initial_probs, the transition matrix A, per-state means mu_collection, and covariances cov_collection (S1, S2, S3 scaled by 1/60). NOTE(review): the `if False:` branch is dead code. Only the else branch runs; it wraps TFP's JAX-substrate MultivariateNormalFullCovariance with distrax.as_distribution. The pure-distrax branch was presumably disabled as a workaround; confirm and then remove it. -->
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let us create the model</span>
- <span class="n">initial_probs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">])</span>
- <span class="c1"># transition matrix</span>
- <span class="n">A</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
- <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
- <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
- <span class="p">[</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]</span>
- <span class="p">])</span>
- <span class="c1"># Observation model</span>
- <span class="n">mu_collection</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span>
- <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">],</span>
- <span class="p">[</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">],</span>
- <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">]</span>
- <span class="p">])</span>
- <span class="n">S1</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">1.1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">]])</span>
- <span class="n">S2</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">1.3</span><span class="p">]])</span>
- <span class="n">S3</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">]])</span>
- <span class="n">cov_collection</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">S1</span><span class="p">,</span> <span class="n">S2</span><span class="p">,</span> <span class="n">S3</span><span class="p">])</span> <span class="o">/</span> <span class="mi">60</span>
- <span class="kn">import</span> <span class="nn">tensorflow_probability</span> <span class="k">as</span> <span class="nn">tfp</span>
- <span class="k">if</span> <span class="kc">False</span><span class="p">:</span>
- <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
- <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">initial_probs</span><span class="p">),</span>
- <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">MultivariateNormalFullCovariance</span><span class="p">(</span>
- <span class="n">loc</span><span class="o">=</span><span class="n">mu_collection</span><span class="p">,</span> <span class="n">covariance_matrix</span><span class="o">=</span><span class="n">cov_collection</span><span class="p">))</span>
- <span class="k">else</span><span class="p">:</span>
- <span class="n">hmm</span> <span class="o">=</span> <span class="n">HMM</span><span class="p">(</span><span class="n">trans_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">A</span><span class="p">),</span>
- <span class="n">init_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">Categorical</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="n">initial_probs</span><span class="p">),</span>
- <span class="n">obs_dist</span><span class="o">=</span><span class="n">distrax</span><span class="o">.</span><span class="n">as_distribution</span><span class="p">(</span>
- <span class="n">tfp</span><span class="o">.</span><span class="n">substrates</span><span class="o">.</span><span class="n">jax</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">MultivariateNormalFullCovariance</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">mu_collection</span><span class="p">,</span>
- <span class="n">covariance_matrix</span><span class="o">=</span><span class="n">cov_collection</span><span class="p">)))</span>
- <span class="nb">print</span><span class="p">(</span><span class="n">hmm</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span><distrax._src.utils.hmm.HMM object at 0x7fd658f7b7f0>
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <!-- Draws a 50-step sample (seed 10) from the lily pad HMM and prints the shapes of the latent state sequence and the 2d observation sequence. The key is built with jr.PRNGKey, as in the casino sampling cell; a bare PRNGKey is not defined in the notebook. -->
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">n_samples</span><span class="p">,</span> <span class="n">seed</span> <span class="o">=</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">10</span>
- <span class="n">samples_state</span><span class="p">,</span> <span class="n">samples_obs</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">jr</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="n">seed</span><span class="p">),</span> <span class="n">seq_len</span><span class="o">=</span><span class="n">n_samples</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="n">samples_state</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
- <span class="nb">print</span><span class="p">(</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>(50,)
- (50, 2)
- </pre></div>
- </div>
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let's plot the observed data in 2d</span>
- <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span>
- <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mf">1.2</span>
- <span class="n">colors</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"tab:green"</span><span class="p">,</span> <span class="s2">"tab:blue"</span><span class="p">,</span> <span class="s2">"tab:red"</span><span class="p">]</span>
- <span class="k">def</span> <span class="nf">plot_2dhmm</span><span class="p">(</span><span class="n">hmm</span><span class="p">,</span> <span class="n">samples_obs</span><span class="p">,</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">colors</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">):</span>
- <span class="n">obs_dist</span> <span class="o">=</span> <span class="n">hmm</span><span class="o">.</span><span class="n">obs_dist</span>
- <span class="n">color_sample</span> <span class="o">=</span> <span class="p">[</span><span class="n">colors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">samples_state</span><span class="p">]</span>
- <span class="n">xs</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
- <span class="n">ys</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
- <span class="n">v_prob</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">obs_dist</span><span class="o">.</span><span class="n">prob</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">])),</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
- <span class="n">z</span> <span class="o">=</span> <span class="n">vmap</span><span class="p">(</span><span class="n">v_prob</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="kc">None</span><span class="p">))(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">)</span>
- <span class="n">grid</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mgrid</span><span class="p">[</span><span class="n">xmin</span><span class="p">:</span><span class="n">xmax</span><span class="p">:</span><span class="n">step</span><span class="p">,</span> <span class="n">ymin</span><span class="p">:</span><span class="n">ymax</span><span class="p">:</span><span class="n">step</span><span class="p">]</span>
- <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">color</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">colors</span><span class="p">):</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">contour</span><span class="p">(</span><span class="o">*</span><span class="n">grid</span><span class="p">,</span> <span class="n">z</span><span class="p">[:,</span> <span class="p">:,</span> <span class="n">k</span><span class="p">],</span> <span class="n">levels</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">colors</span><span class="o">=</span><span class="n">color</span><span class="p">,</span> <span class="n">linewidths</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="o">*</span><span class="p">(</span><span class="n">obs_dist</span><span class="o">.</span><span class="n">mean</span><span class="p">()[</span><span class="n">k</span><span class="p">]</span> <span class="o">+</span> <span class="mf">0.13</span><span class="p">),</span> <span class="sa">f</span><span class="s2">"$k$=</span><span class="si">{</span><span class="n">k</span> <span class="o">+</span> <span class="mi">1</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">13</span><span class="p">,</span> <span class="n">horizontalalignment</span><span class="o">=</span><span class="s2">"right"</span><span class="p">)</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="o">*</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">"black"</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="o">*</span><span class="n">samples_obs</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">color_sample</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="mi">30</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
- <span class="k">return</span> <span class="n">ax</span><span class="p">,</span> <span class="n">color_sample</span>
- <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
- <span class="n">_</span><span class="p">,</span> <span class="n">color_sample</span> <span class="o">=</span> <span class="n">plot_2dhmm</span><span class="p">(</span><span class="n">hmm</span><span class="p">,</span> <span class="n">samples_obs</span><span class="p">,</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">colors</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
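- <img alt="Contour lines of the observation distribution for each hidden state (labelled k=1, 2, 3), overlaid with the sampled 2-D observations colored by their hidden state and connected by a faint line in sampling order" src="../../_images/hmm_15_0.png" />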
- </div>
- </div>
- <div class="cell docutils container">
- <div class="cell_input docutils container">
- <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Let's plot the hidden state sequence</span>
- <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">),</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">where</span><span class="o">=</span><span class="s2">"post"</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s2">"black"</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span>
- <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">n_samples</span><span class="p">),</span> <span class="n">samples_state</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">color_sample</span><span class="p">,</span> <span class="n">zorder</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
- </pre></div>
- </div>
- </div>
- <div class="cell_output docutils container">
- <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>&lt;matplotlib.collections.PathCollection at 0x7fd5e174b460&gt;
- </pre></div>
- </div>
- <img alt="Step plot of the sampled hidden state at each time step, with each point colored by its state" src="../../_images/hmm_16_1.png" />
- </div>
- </div>
- </div>
- </div>
- <script type="text/x-thebe-config">
- {
- requestKernel: true,
- binderOptions: {
- repo: "binder-examples/jupyter-stacks-datascience",
- ref: "master",
- },
- codeMirrorConfig: {
- theme: "abcdef",
- mode: "python"
- },
- kernelOptions: {
- kernelName: "python3",
- path: "./chapters/ssm"
- },
- predefinedOutput: true
- }
- </script>
- <script>kernelName = 'python3'</script>
- </div>
-
-
- <!-- Previous / next buttons -->
- <div class='prev-next-area'>
- <a class='left-prev' id="prev-link" href="ssm_intro.html" title="previous page">
- <i class="fas fa-angle-left"></i>
- <div class="prev-next-info">
- <p class="prev-next-subtitle">previous</p>
- <p class="prev-next-title">What are State Space Models?</p>
- </div>
- </a>
- <a class='right-next' id="next-link" href="lds.html" title="next page">
- <div class="prev-next-info">
- <p class="prev-next-subtitle">next</p>
- <p class="prev-next-title">Linear Gaussian SSMs</p>
- </div>
- <i class="fas fa-angle-right"></i>
- </a>
- </div>
-
- </div>
- </div>
- <footer class="footer">
- <p>
-
- By Kevin Murphy, Scott Linderman, et al.<br/>
-
- © Copyright 2021.<br/>
- </p>
- </footer>
- </main>
- </div>
- </div>
-
- <script src="../../_static/js/index.be7d3bbb2ef33a8344ce.js"></script>
- </body>
- </html>
|