{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: distrax in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (0.0.1)\n", "Collecting distrax\n", " Downloading distrax-0.1.2-py3-none-any.whl (272 kB)\n", "\u001b[K |████████████████████████████████| 272 kB 6.9 MB/s eta 0:00:01\n", "\u001b[?25hRequirement already satisfied: jax>=0.1.55 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.2.11)\n", "Requirement already satisfied: absl-py>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.12.0)\n", "Requirement already satisfied: chex>=0.0.7 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.0.8)\n", "Requirement already satisfied: jaxlib>=0.1.67 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (0.1.70)\n", "Requirement already satisfied: numpy>=1.18.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from distrax) (1.19.5)\n", "Collecting tensorflow-probability>=0.15.0\n", " Using cached tensorflow_probability-0.16.0-py2.py3-none-any.whl (6.3 MB)\n", "Requirement already satisfied: six in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from absl-py>=0.9.0->distrax) (1.15.0)\n", "Requirement already satisfied: dm-tree>=0.1.5 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.1.6)\n", "Requirement already satisfied: toolz>=0.9.0 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from chex>=0.0.7->distrax) (0.11.1)\n", "Requirement already satisfied: opt-einsum in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jax>=0.1.55->distrax) (3.3.0)\n", "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.12)\n", "Requirement already 
satisfied: scipy in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from jaxlib>=0.1.67->distrax) (1.6.3)\n", "Requirement already satisfied: cloudpickle>=1.3 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (1.6.0)\n", "Requirement already satisfied: decorator in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (4.4.2)\n", "Requirement already satisfied: gast>=0.3.2 in /opt/anaconda3/envs/spyder-dev/lib/python3.9/site-packages (from tensorflow-probability>=0.15.0->distrax) (0.4.0)\n", "Installing collected packages: tensorflow-probability, distrax\n", " Attempting uninstall: tensorflow-probability\n", " Found existing installation: tensorflow-probability 0.13.0\n", " Uninstalling tensorflow-probability-0.13.0:\n", " Successfully uninstalled tensorflow-probability-0.13.0\n", " Attempting uninstall: distrax\n", " Found existing installation: distrax 0.0.1\n", " Uninstalling distrax-0.0.1:\n", " Successfully uninstalled distrax-0.0.1\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "jsl 0.0.0 requires dataclasses, which is not installed.\u001b[0m\n", "Successfully installed distrax-0.1.2 tensorflow-probability-0.16.0\n", "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 22.0.4 is available.\n", "You should consider upgrading via the '/opt/anaconda3/envs/spyder-dev/bin/python -m pip install --upgrade pip' command.\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "# meta-data does not work yet in VScode\n", "# https://github.com/microsoft/vscode-jupyter/issues/1121\n", "\n", "{\n", " \"tags\": [\n", " \"hide-cell\"\n", " ]\n", "}\n", "\n", "\n", "### Install necessary libraries\n", "\n", "try:\n", " import jax\n", "except:\n", " # For cuda version, see https://github.com/google/jax#installation\n", " %pip install --upgrade \"jax[cpu]\" \n", " import jax\n", "\n", "try:\n", " import distrax\n", "except:\n", " %pip install --upgrade distrax\n", " import distrax\n", "\n", "try:\n", " import jsl\n", "except:\n", " %pip install git+https://github.com/probml/jsl\n", " import jsl\n", "\n", "try:\n", " import rich\n", "except:\n", " %pip install rich\n", " import rich\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "{\n", " \"tags\": [\n", " \"hide-cell\"\n", " ]\n", "}\n", "\n", "\n", "### Import standard libraries\n", "\n", "import abc\n", "from dataclasses import dataclass\n", "import functools\n", "import itertools\n", "\n", "from typing import Any, Callable, NamedTuple, Optional, Union, Tuple\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import lax, vmap, jit, grad\n", "from jax.scipy.special import logit\n", "from jax.nn import softmax\n", "from functools import partial\n", "from jax.random import PRNGKey, split\n", "\n", "import inspect\n", "import 
inspect as py_inspect\n", "import rich\n", "from rich import inspect as r_inspect\n", "from rich import print as r_print\n", "\n", "def print_source(fname):\n", " r_print(py_inspect.getsource(fname))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:nlds-intro)=\n", "# Nonlinear Gaussian SSMs\n", "\n", "In this section, we consider SSMs in which the dynamics and/or observation models are nonlinear,\n", "but the process noise and observation noise are Gaussian.\n", "That is, \n", "\\begin{align}\n", "\\hidden_t &= \\dynamicsFn(\\hidden_{t-1}, \\inputs_t) + \\transNoise_t \\\\\n", "\\obs_t &= \\obsFn(\\hidden_{t}, \\inputs_t) + \\obsNoise_t\n", "\\end{align}\n", "where $\\transNoise_t \\sim \\gauss(\\vzero,\\transCov)$\n", "and $\\obsNoise_t \\sim \\gauss(\\vzero,\\obsCov)$.\n", "This is a very widely used model class. We give some examples below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(sec:pendulum)=\n", "## Example: tracking a 1d pendulum\n", "\n", "```{figure} /figures/pendulum.png\n", ":scale: 50%\n", ":name: fig:pendulum\n", "\n", "Illustration of a pendulum swinging.\n", "$g$ is the gravitational acceleration,\n", "$w(t)$ is a random external force,\n", "and $\\alpha$ is the angle with respect to the vertical.\n", "Based on {cite}`Sarkka13` fig 3.10.\n", "```\n", "\n", "\n", "% Sarkka p45, p74\n", "Consider a simple pendulum of unit mass and length swinging from\n", "a fixed attachment, as in\n", "{numref}`fig:pendulum`.\n", "Such an object is in principle entirely deterministic in its behavior.\n", "However, in the real world, there are often unknown forces at work\n", "(e.g., air turbulence, friction).\n", "We will model these by a continuous-time random Gaussian noise process $w(t)$.\n", "This gives rise to the following differential equation:\n", "\\begin{align}\n", "\\frac{d^2 \\alpha}{d t^2}\n", "= -g \\sin(\\alpha) + w(t)\n", "\\end{align}\n", "We can write this as a nonlinear SSM by defining the state to be\n", "$\\hidden_1(t) = 
\\alpha(t)$ and $\\hidden_2(t) = d\\alpha(t)/dt$.\n", "Thus\n", "\\begin{align}\n", "\\frac{d \\hidden}{dt}\n", "= \\begin{pmatrix} \\hiddenScalar_2 \\\\ -g \\sin(\\hiddenScalar_1) \\end{pmatrix}\n", "+ \\begin{pmatrix} 0 \\\\ 1 \\end{pmatrix} w(t)\n", "\\end{align}\n", "If we discretize this with step size $\\Delta$,\n", "we get the following\n", "formulation {cite}`Sarkka13` p74:\n", "\\begin{align}\n", "\\underbrace{\n", " \\begin{pmatrix} \\hiddenScalar_{1,t} \\\\ \\hiddenScalar_{2,t} \\end{pmatrix}\n", " }_{\\hidden_t}\n", "=\n", "\\underbrace{\n", " \\begin{pmatrix} \\hiddenScalar_{1,t-1} + \\hiddenScalar_{2,t-1} \\Delta \\\\\n", " \\hiddenScalar_{2,t-1} -g \\sin(\\hiddenScalar_{1,t-1}) \\Delta \\end{pmatrix}\n", " }_{\\dynamicsFn(\\hidden_{t-1})}\n", "+\\transNoise_{t-1}\n", "\\end{align}\n", "where $\\transNoise_{t-1} \\sim \\gauss(\\vzero,\\transCov)$ with\n", "\\begin{align}\n", "\\transCov = q^c \\begin{pmatrix}\n", " \\frac{\\Delta^3}{3} & \\frac{\\Delta^2}{2} \\\\\n", " \\frac{\\Delta^2}{2} & \\Delta\n", " \\end{pmatrix}\n", " \\end{align}\n", "where $q^c$ is the spectral density (continuous-time variance)\n", "of the continuous-time noise process.\n", "\n", "\n", "If we observe the angular position, we\n", "get the linear observation model\n", "$\\obsFn(\\hidden_t) = \\alpha_t = \\hiddenScalar_{1,t}$.\n", "If we only observe the horizontal position,\n", "we get the nonlinear observation model\n", "$\\obsFn(\\hidden_t) = \\sin(\\alpha_t) = \\sin(\\hiddenScalar_{1,t})$.\n", "\n", "\n", "\n", "\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 4 }