{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "### Import standard libraries\n", "\n", "import abc\n", "from dataclasses import dataclass\n", "import functools\n", "from functools import partial\n", "import itertools\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import lax, vmap, jit, grad\n", "#from jax.scipy.special import logit\n", "#from jax.nn import softmax\n", "import jax.random as jr\n", "\n", "\n", "\n", "import distrax\n", "import optax\n", "\n", "import jsl\n", "import ssm_jax\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:learning)=\n", "# Parameter estimation (learning)\n", "\n", "\n", "So far, we have assumed that the parameters $\\params$ of the SSM are known.\n", "For example, in the case of an HMM with categorical observations\n", "we have $\\params = (\\hmmInit, \\hmmTrans, \\hmmObs)$,\n", "and in the case of an LDS, we have $\\params = \n", "(\\ldsTrans, \\ldsObs, \\ldsTransIn, \\ldsObsIn, \\transCov, \\obsCov, \\initMean, \\initCov)$.\n", "If we adopt a Bayesian perspective, we can view these parameters as random variables that are\n", "shared across all time steps, and across all sequences.\n", "This is shown in {numref}`fig:hmm-plates`, where we adopt $\\keyword{plate notation}$\n", "to represent repetitive structure.\n", "\n", "```{figure} /figures/hmmDgmPlatesY.png\n", ":scale: 100%\n", ":name: fig:hmm-plates\n", "\n", "Illustration of an HMM using plate notation, where we show the parameter\n", "nodes which are shared across all the sequences.\n", "```\n", "\n", "Suppose we observe $N$ sequences $\\data = \\{\\obs_{n,1:T_n}: n=1:N\\}$.\n", "Then the goal of $\\keyword{parameter estimation}$, also called $\\keyword{model learning}$\n", "or $\\keyword{model fitting}$, is to approximate the posterior\n", "\\begin{align}\n", "p(\\params|\\data) \\propto p(\\params) \\prod_{n=1}^N p(\\obs_{n,1:T_n} | \\params)\n", "\\end{align}\n", "where $p(\\obs_{n,1:T_n} | \\params)$ is the marginal likelihood of sequence $n$:\n", "\\begin{align}\n", "p(\\obs_{1:T} | \\params) = \\int p(\\hidden_{1:T}, \\obs_{1:T} | \\params) d\\hidden_{1:T}\n", "\\end{align}\n", "\n", "Since computing the full posterior is computationally difficult, we often settle for computing\n", "a point estimate such as the MAP (maximum a posterior) estimate\n", "\\begin{align}\n", "\\params_{\\map} = \\arg \\max_{\\params} \\log p(\\params) + \\sum_{n=1}^N \\log p(\\obs_{n,1:T_n} | \\params)\n", "\\end{align}\n", "If we ignore the prior term, we get the maximum likelihood estimate or MLE:\n", "\\begin{align}\n", "\\params_{\\mle} = \\arg \\max_{\\params} \\sum_{n=1}^N \\log p(\\obs_{n,1:T_n} | \\params)\n", "\\end{align}\n", "In practice, the MAP estimate often works better than the MLE, since the prior can regularize\n", "the estimate to ensure the model is numerically stable and does not overfit the training set.\n", "\n", "We will discuss a variety of algorithms for parameter estimation in later chapters.\n", "\n" ] } ], "metadata": { "interpreter": { "hash": "6407c60499271029b671b4ff687c4ed4626355c45fd34c44476827f4be42c4d7" }, "kernelspec": { "display_name": "Python 3.9.2 ('spyder-dev')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 4 }