{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: distrax in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (0.0.1)\n", "Collecting distrax\n", " Downloading distrax-0.1.2-py3-none-any.whl (272 kB)\n", "\u001b[K |████████████████████████████████| 272 kB 6.9 MB/s eta 0:00:01\n", "\u001b[?25hRequirement already satisfied: jax>=0.1.55 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.2.11)\n", "Requirement already satisfied: absl-py>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.12.0)\n", "Requirement already satisfied: chex>=0.0.7 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.0.8)\n", "Requirement already satisfied: jaxlib>=0.1.67 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.1.70)\n", "Requirement already satisfied: numpy>=1.18.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (1.19.5)\n", "Collecting tensorflow-probability>=0.15.0\n", " Using cached tensorflow_probability-0.16.0-py2.py3-none-any.whl (6.3 MB)\n", "Requirement already satisfied: six in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from absl-py>=0.9.0->distrax) (1.15.0)\n", "Requirement already satisfied: dm-tree>=0.1.5 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.1.6)\n", "Requirement already satisfied: toolz>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.11.1)\n", "Requirement already satisfied: opt-einsum in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jax>=0.1.55->distrax) (3.3.0)\n", "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.12)\n", "Requirement already 
satisfied: scipy in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.6.3)\n", "Requirement already satisfied: cloudpickle>=1.3 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (1.6.0)\n", "Requirement already satisfied: decorator in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (4.4.2)\n", "Requirement already satisfied: gast>=0.3.2 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (0.4.0)\n", "Installing collected packages: tensorflow-probability, distrax\n", " Attempting uninstall: tensorflow-probability\n", " Found existing installation: tensorflow-probability 0.13.0\n", " Uninstalling tensorflow-probability-0.13.0:\n", " Successfully uninstalled tensorflow-probability-0.13.0\n", " Attempting uninstall: distrax\n", " Found existing installation: distrax 0.0.1\n", " Uninstalling distrax-0.0.1:\n", " Successfully uninstalled distrax-0.0.1\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "jsl 0.0.0 requires dataclasses, which is not installed.\u001b[0m\n", "Successfully installed distrax-0.1.2 tensorflow-probability-0.16.0\n", "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 22.0.4 is available.\n", "You should consider upgrading via the '/opt/anaconda3/envs/spyder-dev/bin/python -m pip install --upgrade pip' command.\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "# Cell metadata tags are not yet honored by VS Code\n", "# (https://github.com/microsoft/vscode-jupyter/issues/1121),\n", "# so the intended tags are recorded in the no-op expression below.\n", "{\n", "    \"tags\": [\n", "        \"hide-cell\"\n", "    ]\n", "}\n", "\n", "\n", "### Install necessary libraries\n", "# Each package is imported if available; otherwise it is installed with pip\n", "# into the kernel's environment and the import is retried.\n", "# `except Exception` (not a bare `except:`) lets Ctrl-C / KeyboardInterrupt\n", "# stop the cell instead of triggering an install.\n", "\n", "try:\n", "    import jax\n", "except Exception:\n", "    # For cuda version, see https://github.com/google/jax#installation\n", "    %pip install --upgrade \"jax[cpu]\"\n", "    import jax\n", "\n", "try:\n", "    import distrax\n", "except Exception:\n", "    %pip install --upgrade distrax\n", "    import distrax\n", "\n", "try:\n", "    import jsl\n", "except Exception:\n", "    %pip install git+https://github.com/probml/jsl\n", "    import jsl\n", "\n", "try:\n", "    import rich\n", "except Exception:\n", "    %pip install rich\n", "    import rich\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Intended cell metadata tags, recorded as a no-op expression.\n", "{\n", "    \"tags\": [\n", "        \"hide-cell\"\n", "    ]\n", "}\n", "\n", "\n", "### Import libraries\n", "\n", "import abc\n", "from dataclasses import dataclass\n", "import functools\n", "import itertools\n", "\n", "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import lax, vmap, jit, grad\n", "from jax.scipy.special import logit\n", "from jax.nn import softmax\n", "from functools import partial\n", "from jax.random import PRNGKey, split\n", "\n", "import 
inspect as py_inspect\n", "import rich\n", "from rich import inspect as r_inspect\n", "from rich import print as r_print\n", "from rich.markup import escape as r_escape\n", "\n", "def print_source(fname):\n", "    \"\"\"Pretty-print the source code of `fname` with rich.\n", "\n", "    `fname` is any object accepted by `inspect.getsource` (e.g. a function\n", "    or class). The source is escaped first: `rich.print` parses `[...]` as\n", "    console markup, which can silently swallow indexing expressions such as\n", "    `x[t]` from the printed code.\n", "    \"\"\"\n", "    r_print(r_escape(py_inspect.getsource(fname)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{math}\n", "\n", "\\newcommand\\floor[1]{\\lfloor#1\\rfloor}\n", "\n", "\\newcommand{\\real}{\\mathbb{R}}\n", "\n", "% Numbers\n", "\\newcommand{\\vzero}{\\boldsymbol{0}}\n", "\\newcommand{\\vone}{\\boldsymbol{1}}\n", "\n", "% Greek https://www.latex-tutorial.com/symbols/greek-alphabet/\n", "\\newcommand{\\valpha}{\\boldsymbol{\\alpha}}\n", "\\newcommand{\\vbeta}{\\boldsymbol{\\beta}}\n", "\\newcommand{\\vchi}{\\boldsymbol{\\chi}}\n", "\\newcommand{\\vdelta}{\\boldsymbol{\\delta}}\n", "\\newcommand{\\vDelta}{\\boldsymbol{\\Delta}}\n", "\\newcommand{\\vepsilon}{\\boldsymbol{\\epsilon}}\n", "\\newcommand{\\vzeta}{\\boldsymbol{\\zeta}}\n", "\\newcommand{\\vXi}{\\boldsymbol{\\Xi}}\n", "\\newcommand{\\vell}{\\boldsymbol{\\ell}}\n", "\\newcommand{\\veta}{\\boldsymbol{\\eta}}\n", "%\\newcommand{\\vEta}{\\boldsymbol{\\Eta}}\n", "\\newcommand{\\vgamma}{\\boldsymbol{\\gamma}}\n", "\\newcommand{\\vGamma}{\\boldsymbol{\\Gamma}}\n", "\\newcommand{\\vmu}{\\boldsymbol{\\mu}}\n", "\\newcommand{\\vmut}{\\boldsymbol{\\tilde{\\mu}}}\n", "\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n", "\\newcommand{\\vkappa}{\\boldsymbol{\\kappa}}\n", "\\newcommand{\\vlambda}{\\boldsymbol{\\lambda}}\n", "\\newcommand{\\vLambda}{\\boldsymbol{\\Lambda}}\n", "\\newcommand{\\vLambdaBar}{\\overline{\\vLambda}}\n", "%\\newcommand{\\vnu}{\\boldsymbol{\\nu}}\n", "\\newcommand{\\vomega}{\\boldsymbol{\\omega}}\n", "\\newcommand{\\vOmega}{\\boldsymbol{\\Omega}}\n", "\\newcommand{\\vphi}{\\boldsymbol{\\phi}}\n", "\\newcommand{\\vvarphi}{\\boldsymbol{\\varphi}}\n", "\\newcommand{\\vPhi}{\\boldsymbol{\\Phi}}\n", "\\newcommand{\\vpi}{\\boldsymbol{\\pi}}\n", "\\newcommand{\\vPi}{\\boldsymbol{\\Pi}}\n", "\\newcommand{\\vpsi}{\\boldsymbol{\\psi}}\n", 
"\\newcommand{\\vPsi}{\\boldsymbol{\\Psi}}\n", "\\newcommand{\\vrho}{\\boldsymbol{\\rho}}\n", "\\newcommand{\\vtheta}{\\boldsymbol{\\theta}}\n", "\\newcommand{\\vthetat}{\\boldsymbol{\\tilde{\\theta}}}\n", "\\newcommand{\\vTheta}{\\boldsymbol{\\Theta}}\n", "\\newcommand{\\vsigma}{\\boldsymbol{\\sigma}}\n", "\\newcommand{\\vSigma}{\\boldsymbol{\\Sigma}}\n", "\\newcommand{\\vSigmat}{\\boldsymbol{\\tilde{\\Sigma}}}\n", "\\newcommand{\\vsigmoid}{\\vsigma}\n", "\\newcommand{\\vtau}{\\boldsymbol{\\tau}}\n", "\\newcommand{\\vxi}{\\boldsymbol{\\xi}}\n", "\n", "\n", "% Lower Roman (Vectors)\n", "\\newcommand{\\va}{\\mathbf{a}}\n", "\\newcommand{\\vb}{\\mathbf{b}}\n", "\\newcommand{\\vBt}{\\mathbf{\\tilde{B}}}\n", "\\newcommand{\\vc}{\\mathbf{c}}\n", "\\newcommand{\\vct}{\\mathbf{\\tilde{c}}}\n", "\\newcommand{\\vd}{\\mathbf{d}}\n", "\\newcommand{\\ve}{\\mathbf{e}}\n", "\\newcommand{\\vf}{\\mathbf{f}}\n", "\\newcommand{\\vg}{\\mathbf{g}}\n", "\\newcommand{\\vh}{\\mathbf{h}}\n", "%\\newcommand{\\myvh}{\\mathbf{h}}\n", "\\newcommand{\\vi}{\\mathbf{i}}\n", "\\newcommand{\\vj}{\\mathbf{j}}\n", "\\newcommand{\\vk}{\\mathbf{k}}\n", "\\newcommand{\\vl}{\\mathbf{l}}\n", "\\newcommand{\\vm}{\\mathbf{m}}\n", "\\newcommand{\\vn}{\\mathbf{n}}\n", "\\newcommand{\\vo}{\\mathbf{o}}\n", "\\newcommand{\\vp}{\\mathbf{p}}\n", "\\newcommand{\\vq}{\\mathbf{q}}\n", "\\newcommand{\\vr}{\\mathbf{r}}\n", "\\newcommand{\\vs}{\\mathbf{s}}\n", "\\newcommand{\\vt}{\\mathbf{t}}\n", "\\newcommand{\\vu}{\\mathbf{u}}\n", "\\newcommand{\\vv}{\\mathbf{v}}\n", "\\newcommand{\\vw}{\\mathbf{w}}\n", "\\newcommand{\\vws}{\\vw_s}\n", "\\newcommand{\\vwt}{\\mathbf{\\tilde{w}}}\n", "\\newcommand{\\vWt}{\\mathbf{\\tilde{W}}}\n", "\\newcommand{\\vwh}{\\hat{\\vw}}\n", "\\newcommand{\\vx}{\\mathbf{x}}\n", "%\\newcommand{\\vx}{\\mathbf{x}}\n", "\\newcommand{\\vxt}{\\mathbf{\\tilde{x}}}\n", "\\newcommand{\\vy}{\\mathbf{y}}\n", "\\newcommand{\\vyt}{\\mathbf{\\tilde{y}}}\n", "\\newcommand{\\vz}{\\mathbf{z}}\n", 
"%\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n", "\n", "\n", "% Upper Roman (Matrices)\n", "\\newcommand{\\vA}{\\mathbf{A}}\n", "\\newcommand{\\vB}{\\mathbf{B}}\n", "\\newcommand{\\vC}{\\mathbf{C}}\n", "\\newcommand{\\vD}{\\mathbf{D}}\n", "\\newcommand{\\vE}{\\mathbf{E}}\n", "\\newcommand{\\vF}{\\mathbf{F}}\n", "\\newcommand{\\vG}{\\mathbf{G}}\n", "\\newcommand{\\vH}{\\mathbf{H}}\n", "\\newcommand{\\vI}{\\mathbf{I}}\n", "\\newcommand{\\vJ}{\\mathbf{J}}\n", "\\newcommand{\\vK}{\\mathbf{K}}\n", "\\newcommand{\\vL}{\\mathbf{L}}\n", "\\newcommand{\\vM}{\\mathbf{M}}\n", "\\newcommand{\\vMt}{\\mathbf{\\tilde{M}}}\n", "\\newcommand{\\vN}{\\mathbf{N}}\n", "\\newcommand{\\vO}{\\mathbf{O}}\n", "\\newcommand{\\vP}{\\mathbf{P}}\n", "\\newcommand{\\vQ}{\\mathbf{Q}}\n", "\\newcommand{\\vR}{\\mathbf{R}}\n", "\\newcommand{\\vS}{\\mathbf{S}}\n", "\\newcommand{\\vT}{\\mathbf{T}}\n", "\\newcommand{\\vU}{\\mathbf{U}}\n", "\\newcommand{\\vV}{\\mathbf{V}}\n", "\\newcommand{\\vW}{\\mathbf{W}}\n", "\\newcommand{\\vX}{\\mathbf{X}}\n", "%\\newcommand{\\vXs}{\\vX_{\\vs}}\n", "\\newcommand{\\vXs}{\\vX_{s}}\n", "\\newcommand{\\vXt}{\\mathbf{\\tilde{X}}}\n", "\\newcommand{\\vY}{\\mathbf{Y}}\n", "\\newcommand{\\vZ}{\\mathbf{Z}}\n", "\\newcommand{\\vZt}{\\mathbf{\\tilde{Z}}}\n", "\\newcommand{\\vzt}{\\mathbf{\\tilde{z}}}\n", "\n", "\n", "%%%%\n", "\\newcommand{\\hidden}{\\vz}\n", "\\newcommand{\\hid}{\\hidden}\n", "\\newcommand{\\observed}{\\vy}\n", "\\newcommand{\\obs}{\\observed}\n", "\\newcommand{\\inputs}{\\vu}\n", "\\newcommand{\\input}{\\inputs}\n", "\n", "\\newcommand{\\hmmTrans}{\\vA}\n", "\\newcommand{\\hmmObs}{\\vB}\n", "\\newcommand{\\hmmInit}{\\vpi}\n", "\\newcommand{\\hmmhid}{\\hidden}\n", "\\newcommand{\\hmmobs}{\\obs}\n", "\n", "\\newcommand{\\ldsDyn}{\\vA}\n", "\\newcommand{\\ldsObs}{\\vC}\n", "\\newcommand{\\ldsDynIn}{\\vB}\n", "\\newcommand{\\ldsObsIn}{\\vD}\n", "\\newcommand{\\ldsDynNoise}{\\vQ}\n", "\\newcommand{\\ldsObsNoise}{\\vR}\n", "\n", "\\newcommand{\\ssmDynFn}{f}\n", 
"\\newcommand{\\ssmObsFn}{h}\n", "\n", "\n", "%%%\n", "\\newcommand{\\gauss}{\\mathcal{N}}\n", "\n", "\\newcommand{\\diag}{\\mathrm{diag}}\n", "```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:lds-intro)=\n", "# Linear Gaussian SSMs\n", "\n", "\n", "Consider the state space model in \n", "{eq}`eq:SSM-ar`\n", "where we assume the observations are conditionally iid given the\n", "hidden states and inputs (i.e., there are no auto-regressive dependencies\n", "between the observables).\n", "We can rewrite this model as \n", "a stochastic nonlinear dynamical system (NLDS)\n", "by writing the next hidden state\n", "as a deterministic function of the past state, the current input,\n", "and random process noise $\\vepsilon_t$:\n", "\\begin{align}\n", "\\hmmhid_t &= \\ssmDynFn(\\hmmhid_{t-1}, \\inputs_t, \\vepsilon_t) \n", "\\end{align}\n", "where $\\vepsilon_t$ is drawn from a distribution chosen such\n", "that the induced distribution\n", "on $\\hmmhid_t$ matches $p(\\hmmhid_t|\\hmmhid_{t-1}, \\inputs_t)$.\n", "Similarly, we can rewrite the observation distribution\n", "as a deterministic function of the hidden state, the current input,\n", "and random observation noise $\\veta_t$:\n", "\\begin{align}\n", "\\hmmobs_t &= \\ssmObsFn(\\hmmhid_{t}, \\inputs_t, \\veta_t)\n", "\\end{align}\n", "\n", "\n", "If we assume additive Gaussian noise,\n", "the model becomes\n", "\\begin{align}\n", "\\hmmhid_t &= \\ssmDynFn(\\hmmhid_{t-1}, \\inputs_t) + \\vepsilon_t \\\\\n", "\\hmmobs_t &= \\ssmObsFn(\\hmmhid_{t}, \\inputs_t) + \\veta_t\n", "\\end{align}\n", "where $\\vepsilon_t \\sim \\gauss(\\vzero,\\vQ_t)$\n", "and $\\veta_t \\sim \\gauss(\\vzero,\\vR_t)$.\n", "We will call these Gaussian SSMs.\n", "\n", "If we additionally assume\n", "the transition function $\\ssmDynFn$\n", "and the observation function $\\ssmObsFn$ are both linear,\n", "then we can rewrite the model as follows:\n", "\\begin{align}\n", "p(\\hmmhid_t|\\hmmhid_{t-1},\\inputs_t) &= \\gauss(\\hmmhid_t|\\ldsDyn_t 
\\hmmhid_{t-1}\n", "+ \\ldsDynIn_t \\inputs_t, \\vQ_t)\n", "\\\\\n", "p(\\hmmobs_t|\\hmmhid_t,\\inputs_t) &= \\gauss(\\hmmobs_t|\\ldsObs_t \\hmmhid_{t}\n", "+ \\ldsObsIn_t \\inputs_t, \\vR_t)\n", "\\end{align}\n", "This is called a \n", "linear-Gaussian state space model\n", "(LG-SSM),\n", "or a\n", "linear dynamical system (LDS).\n", "We usually assume the parameters are independent of time, in which case\n", "the model is said to be time-invariant or homogeneous.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:tracking-lds)=\n", "(sec:kalman-tracking)=\n", "## Example: tracking a 2d point\n", "\n", "\n", "\n", "% Sarkka p43\n", "Consider an object moving in $\\real^2$.\n", "Let the state be\n", "the position and velocity of the object,\n", "$\\vz_t =\\begin{pmatrix} u_t & v_t & \\dot{u}_t & \\dot{v}_t \\end{pmatrix}^\\top$.\n", "(We use $u$ and $v$ for the two coordinates,\n", "to avoid confusion with the state and observation variables.)\n", "If we use Euler discretization with step size $\\Delta$,\n", "the dynamics become\n", "\\begin{align}\n", "\\underbrace{\\begin{pmatrix} u_t\\\\ v_t \\\\ \\dot{u}_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\vz_t}\n", " = \n", "\\underbrace{\n", "\\begin{pmatrix}\n", "1 & 0 & \\Delta & 0 \\\\\n", "0 & 1 & 0 & \\Delta\\\\\n", "0 & 0 & 1 & 0 \\\\\n", "0 & 0 & 0 & 1\n", "\\end{pmatrix}\n", "}_{\\ldsDyn}\n", "\\\n", "\\underbrace{\\begin{pmatrix} u_{t-1} \\\\ v_{t-1} \\\\ \\dot{u}_{t-1} \\\\ \\dot{v}_{t-1} \\end{pmatrix}}_{\\vz_{t-1}}\n", "+ \\vepsilon_t\n", "\\end{align}\n", "where $\\vepsilon_t \\sim \\gauss(\\vzero,\\vQ)$ is\n", "the process noise.\n", "\n", "Let us assume\n", "that the process noise is \n", "a white noise process added to the velocity components\n", "of the state, but not to the location.\n", "(This is known as a random accelerations model.)\n", "We can approximate the resulting process in discrete time by assuming\n", "$\\vQ = \\diag(0, 0, q, q)$.\n", "(See {cite}`Sarkka13` p60 for a more accurate way\n", "to convert the 
continuous time process to discrete time.)\n", "\n", "\n", "Now suppose that at each discrete time point we\n", "observe the location,\n", "corrupted by Gaussian noise.\n", "Thus the observation model becomes\n", "\\begin{align}\n", "\\underbrace{\\begin{pmatrix} y_{1,t} \\\\ y_{2,t} \\end{pmatrix}}_{\\vy_t}\n", " &=\n", " \\underbrace{\n", " \\begin{pmatrix}\n", "1 & 0 & 0 & 0 \\\\\n", "0 & 1 & 0 & 0\n", " \\end{pmatrix}\n", " }_{\\ldsObs}\n", " \\\n", "\\underbrace{\\begin{pmatrix} u_t\\\\ v_t \\\\ \\dot{u}_t \\\\ \\dot{v}_t \\end{pmatrix}}_{\\vz_t} \n", " + \\veta_t\n", "\\end{align}\n", "where $\\veta_t \\sim \\gauss(\\vzero,\\vR)$ is the **observation noise**.\n", "We see that the observation matrix $\\ldsObs$ simply *extracts* the\n", "relevant parts of the state vector.\n", "\n", "Suppose we sample a trajectory and corresponding set\n", "of noisy observations from this model,\n", "$(\\vz_{1:T}, \\vy_{1:T}) \\sim p(\\vz,\\vy|\\vtheta)$.\n", "(We use diagonal observation noise,\n", "$\\vR = \\diag(\\sigma_1^2, \\sigma_2^2)$;\n", "the code below uses $\\sigma_1 = \\sigma_2 = 1$ and $\\Delta = 1$, and for simplicity\n", "adds a small amount of noise to every state component, $\\vQ = 0.001\\,\\vI$.)\n", "The results are shown below. \n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LDS(A=DeviceArray([[1., 0., 1., 0.],\n", " [0., 1., 0., 1.],\n", " [0., 0., 1., 0.],\n", " [0., 0., 0., 1.]], dtype=float32), C=DeviceArray([[1, 0, 0, 0],\n", " [0, 1, 0, 0]], dtype=int32), Q=DeviceArray([[0.001, 0. , 0. , 0. ],\n", " [0. , 0.001, 0. , 0. ],\n", " [0. , 0. , 0.001, 0. ],\n", " [0. , 0. , 0. 
, 0.001]], dtype=float32), R=DeviceArray([[1., 0.],\n", " [0., 1.]], dtype=float32), mu=DeviceArray([ 8., 10., 1., 0.], dtype=float32), Sigma=DeviceArray([[1., 0., 0., 0.],\n", " [0., 1., 0., 0.],\n", " [0., 0., 1., 0.],\n", " [0., 0., 0., 1.]], dtype=float32), state_offset=None, obs_offset=None, nstates=4, nobs=2)\n" ] } ], "source": [ "# Only LDS is needed here; importing `filter` would shadow the Python builtin.\n", "from jsl.lds.kalman_filter import LDS\n", "\n", "key = jax.random.PRNGKey(314)\n", "timesteps = 15\n", "delta = 1.0  # Euler step size (Delta in the text)\n", "\n", "# State layout: z = (u, v, u_dot, v_dot), i.e. positions first, then velocities.\n", "A = jnp.array([\n", "    [1, 0, delta, 0],\n", "    [0, 1, 0, delta],\n", "    [0, 0, 1, 0],\n", "    [0, 0, 0, 1]\n", "])\n", "\n", "# Observe the two position components (u, v); cast to float to match A.\n", "C = jnp.array([\n", "    [1, 0, 0, 0],\n", "    [0, 1, 0, 0]\n", "]).astype(float)\n", "\n", "state_size, _ = A.shape\n", "observation_size, _ = C.shape\n", "\n", "# For simplicity, a small amount of process noise is added to every state\n", "# component (not only to the velocities).\n", "Q = jnp.eye(state_size) * 0.001\n", "R = jnp.eye(observation_size) * 1.0\n", "# Prior over the initial hidden state (mean mu0, covariance Sigma0).\n", "mu0 = jnp.array([8, 10, 1, 0]).astype(float)\n", "Sigma0 = jnp.eye(state_size) * 1.0\n", "\n", "lds = LDS(A, C, Q, R, mu0, Sigma0)\n", "print(lds)\n" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "from jsl.demos.plot_utils import plot_ellipse\n", "\n", "\n", "def plot_tracking_values(observed, filtered, cov_hist, signal_label, ax):\n", "    \"\"\"Plot 2d observations, an estimated trajectory, and uncertainty ellipses.\n", "\n", "    Args:\n", "        observed: 2d array; columns 0 and 1 are plotted as observed points.\n", "        filtered: array of estimated state means; its first two columns are\n", "            plotted as the estimated position.\n", "        cov_hist: sequence of state covariance matrices; the top-left 2x2\n", "            block of each is drawn with `plot_ellipse` (n_std=2.0) around\n", "            the matching row of `filtered`.\n", "        signal_label: legend label for the estimated trajectory.\n", "        ax: matplotlib Axes to draw on.\n", "\n", "    One ellipse is drawn per row of `observed`, so `filtered` and `cov_hist`\n", "    are assumed to have at least that many entries.\n", "    \"\"\"\n", "    ax.plot(observed[:, 0], observed[:, 1], marker=\"o\", linewidth=0,\n", "            markerfacecolor=\"none\", markeredgewidth=2, markersize=8,\n", "            label=\"observed\", c=\"tab:green\")\n", "    ax.plot(*filtered[:, :2].T, label=signal_label, c=\"tab:red\", marker=\"x\", linewidth=2)\n", "\n", "    num_steps, _ = observed.shape\n", "    for t in range(num_steps):\n", "        position_cov = cov_hist[t][:2, :2]\n", "        plot_ellipse(position_cov, filtered[t, :2], ax, n_std=2.0, plot_center=False)\n", "\n", "    ax.axis(\"equal\")\n", "    ax.legend()" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(7.24486608505249, 23.857812213897706, 8.0420747756958, 11.636079216003418)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" 
}, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Sample hidden states and the corresponding noisy observations from the model.\n", "z_hist, y_hist = lds.sample(key, timesteps)\n", "\n", "fig_truth, axs = plt.subplots()\n", "axs.plot(y_hist[:, 0], y_hist[:, 1],\n", "         marker=\"o\", linewidth=0, markerfacecolor=\"none\",\n", "         markeredgewidth=2, markersize=8,\n", "         label=\"observed\", c=\"tab:green\")\n", "\n", "# Columns 0 and 1 of the state are the (u, v) position.\n", "axs.plot(z_hist[:, 0], z_hist[:, 1],\n", "         linewidth=2, label=\"truth\",\n", "         marker=\"s\", markersize=8)\n", "axs.set_xlabel(\"$u$\")\n", "axs.set_ylabel(\"$v$\")\n", "axs.set_title(\"Sampled trajectory and noisy observations\")\n", "axs.legend()\n", "# The trailing semicolon suppresses the axis-limits tuple returned by axis().\n", "axs.axis(\"equal\");\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The main task is to infer the hidden states given the noisy\n", "observations, i.e., $p(\\vz|\\vy,\\vtheta)$. We discuss the topic of inference in {ref}`sec:inference`." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 4 }