#!/usr/bin/env python
# coding: utf-8

# In[1]:

### Import standard libraries
import abc
from dataclasses import dataclass
import functools
from functools import partial
import itertools
import matplotlib.pyplot as plt
import numpy as np
from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
#from jax.scipy.special import logit
#from jax.nn import softmax
import jax.random as jr

import distrax
import optax
import jsl
import ssm_jax
# (sec:learning)=
# # Parameter estimation (learning)
#
#
# So far, we have assumed that the parameters $\params$ of the SSM are known.
# For example, in the case of an HMM with categorical observations
# we have $\params = (\hmmInit, \hmmTrans, \hmmObs)$,
# and in the case of an LDS, we have $\params =
# (\ldsTrans, \ldsObs, \ldsTransIn, \ldsObsIn, \transCov, \obsCov, \initMean, \initCov)$.
# If we adopt a Bayesian perspective, we can view these parameters as random variables that are
# shared across all time steps, and across all sequences.
# This is shown in {numref}`fig:hmm-plates`, where we adopt $\keyword{plate notation}$
# to represent repetitive structure.
#
# ```{figure} /figures/hmmDgmPlatesY.png
# :scale: 100%
# :name: fig:hmm-plates
#
# Illustration of an HMM using plate notation, where we show the parameter
# nodes which are shared across all the sequences.
# ```
#
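# To make $\params$ concrete in code, the cell below sketches a minimal container for the
# parameters of an HMM with categorical observations. This is purely illustrative for this
# section: the class and field names (`CategoricalHMMParams`, `initial_probs`,
# `transition_matrix`, `emission_probs`) are placeholders, not the jsl or ssm_jax conventions.

# In[ ]:


class CategoricalHMMParams(NamedTuple):
    initial_probs: jnp.ndarray       # pi: shape (K,), initial state distribution
    transition_matrix: jnp.ndarray   # A:  shape (K, K), A[i, j] = p(z_t = j | z_{t-1} = i)
    emission_probs: jnp.ndarray      # B:  shape (K, V), B[k, v] = p(y_t = v | z_t = k)


# Example: a 2-state HMM over a 3-symbol vocabulary.
params = CategoricalHMMParams(
    initial_probs=jnp.array([0.6, 0.4]),
    transition_matrix=jnp.array([[0.9, 0.1],
                                 [0.2, 0.8]]),
    emission_probs=jnp.array([[0.5, 0.4, 0.1],
                              [0.1, 0.2, 0.7]]))
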
# Suppose we observe $N$ sequences $\data = \{\obs_{n,1:T_n}: n=1:N\}$.
# Then the goal of $\keyword{parameter estimation}$, also called $\keyword{model learning}$
# or $\keyword{model fitting}$, is to approximate the posterior
# \begin{align}
# p(\params|\data) \propto p(\params) \prod_{n=1}^N p(\obs_{n,1:T_n} | \params)
# \end{align}
# where $p(\obs_{n,1:T_n} | \params)$ is the marginal likelihood of sequence $n$:
# \begin{align}
# p(\obs_{n,1:T_n} | \params) = \int p(\hidden_{n,1:T_n}, \obs_{n,1:T_n} | \params) d\hidden_{n,1:T_n}
# \end{align}
#
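# For an HMM, the integral above becomes a sum over all hidden state paths, which the
# forward algorithm evaluates in $O(T K^2)$ time. The cell below is a minimal sketch for a
# single categorical sequence (the function and argument names are illustrative, not the
# jsl or ssm_jax API), implementing the recursion in log space with `lax.scan`.

# In[ ]:


from jax.scipy.special import logsumexp

def hmm_log_marginal_lik(init_probs, trans_mat, obs_mat, obs_seq):
    """Forward algorithm: log p(y_{1:T} | params) for a categorical HMM.

    init_probs: (K,)   initial state distribution pi
    trans_mat:  (K, K) transition matrix A, with A[i, j] = p(z_t = j | z_{t-1} = i)
    obs_mat:    (K, V) emission matrix B, with B[k, v] = p(y_t = v | z_t = k)
    obs_seq:    (T,)   integer observations in {0, ..., V-1}
    """
    def step(log_alpha, y):
        # Sum over the previous state, then weight by the likelihood of the new observation.
        log_pred = logsumexp(log_alpha[:, None] + jnp.log(trans_mat), axis=0)
        return log_pred + jnp.log(obs_mat[:, y]), None

    log_alpha0 = jnp.log(init_probs) + jnp.log(obs_mat[:, obs_seq[0]])
    log_alphaT, _ = lax.scan(step, log_alpha0, obs_seq[1:])
    return logsumexp(log_alphaT)


# Using the illustrative parameters defined above:
print(hmm_log_marginal_lik(params.initial_probs, params.transition_matrix,
                           params.emission_probs, jnp.array([0, 2, 1, 0])))
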
# Since computing the full posterior is computationally difficult, we often settle for computing
# a point estimate such as the MAP (maximum a posteriori) estimate
# \begin{align}
# \params_{\map} = \arg \max_{\params} \left[ \log p(\params) + \sum_{n=1}^N \log p(\obs_{n,1:T_n} | \params) \right]
# \end{align}
# If we ignore the prior term, we get the maximum likelihood estimate or MLE:
# \begin{align}
# \params_{\mle} = \arg \max_{\params} \sum_{n=1}^N \log p(\obs_{n,1:T_n} | \params)
# \end{align}
# In practice, the MAP estimate often works better than the MLE, since the prior regularizes
# the estimate, which helps keep the model numerically stable and prevents overfitting the training set.
#
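# One generic way to compute these point estimates is gradient ascent on the objective,
# since the log marginal likelihood is differentiable with respect to (a suitable
# unconstrained parameterization of) $\params$. The cell below is a rough sketch of such a
# loop using optax; `log_prior_fn`, `log_marginal_lik_fn`, and the other names are
# hypothetical placeholders, and constrained quantities such as probabilities would in
# practice be represented by unconstrained logits so that gradient steps remain valid.

# In[ ]:


def fit_map(log_prior_fn, log_marginal_lik_fn, init_params, sequences,
            num_steps=200, learning_rate=1e-2):
    """Gradient-based MAP estimation; use log_prior_fn = lambda p: 0.0 to get the MLE.

    log_prior_fn:        params -> log p(params)
    log_marginal_lik_fn: (params, obs_seq) -> log p(obs_seq | params)
    sequences:           list of observation sequences (lengths T_n may differ)
    """
    def neg_log_post(p):
        # Negative of: log p(params) + sum_n log p(y_{n,1:T_n} | params)
        lls = [log_marginal_lik_fn(p, obs_seq) for obs_seq in sequences]
        return -(log_prior_fn(p) + sum(lls))

    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(init_params)
    cur_params = init_params
    for _ in range(num_steps):
        grads = jax.grad(neg_log_post)(cur_params)
        updates, opt_state = optimizer.update(grads, opt_state)
        cur_params = optax.apply_updates(cur_params, updates)
    return cur_params
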
# We will discuss a variety of algorithms for parameter estimation in later chapters.
#
#