prophet.stan 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  functions {
    // Changepoint indicator matrix for the piecewise trend.
    // A single forward sweep over the time points switches on (sets to 1)
    // every changepoint whose time is at or before the current t[n]; row n
    // is a copy of that running indicator vector. All other entries are 0.
    // Both t and t_change are expected to be sorted in ascending order.
    matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
      matrix[T, S] indicators;
      row_vector[S] passed;  // running set of changepoints reached so far
      int next_cp;           // index of the first changepoint not yet reached
      indicators = rep_matrix(0, T, S);
      passed = rep_row_vector(0, S);
      next_cp = 1;
      for (n in 1:T) {
        // Switch on each remaining changepoint whose time has been reached.
        while ((next_cp <= S) && (t[n] >= t_change[next_cp])) {
          passed[next_cp] = 1;
          next_cp = next_cp + 1;
        }
        indicators[n] = passed;
      }
      return indicators;
    }

    // Offset adjustments for the logistic trend, one per changepoint, chosen
    // so the piecewise curve stays continuous where the growth rate changes.
    vector logistic_gamma(real k, real m, vector delta, vector t_change, int S) {
      vector[S] gamma;      // offset adjustment at each changepoint
      vector[S + 1] rates;  // growth rate in force in each segment
      real m_prev;          // offset of the segment preceding changepoint s
      rates = append_row(k, k + cumulative_sum(delta));
      m_prev = m;
      for (s in 1:S) {
        gamma[s] = (t_change[s] - m_prev) * (1 - rates[s] / rates[s + 1]);
        // Carry the adjusted offset forward into the next segment.
        m_prev = m_prev + gamma[s];
      }
      return gamma;
    }

    // Piecewise-logistic trend: cap .* inv_logit(rate .* (t - offset)), where
    // the rate and offset at each time point include every changepoint
    // adjustment flagged in the indicator matrix A.
    vector logistic_trend(
      real k,
      real m,
      vector delta,
      vector t,
      vector cap,
      matrix A,
      vector t_change,
      int S
    ) {
      vector[S] gamma_adj;       // continuity adjustments to the offset
      vector[rows(A)] rate_t;    // growth rate in effect at each time point
      vector[rows(A)] offset_t;  // offset in effect at each time point
      gamma_adj = logistic_gamma(k, m, delta, t_change, S);
      rate_t = k + A * delta;
      offset_t = m + A * gamma_adj;
      return cap .* inv_logit(rate_t .* (t - offset_t));
    }

    // Piecewise-linear trend. Each rate change delta[j] is paired with an
    // offset shift of -t_change[j] * delta[j], which keeps the line
    // continuous at the changepoint.
    vector linear_trend(
      real k,
      real m,
      vector delta,
      vector t,
      matrix A,
      vector t_change
    ) {
      vector[rows(A)] rate_t;    // growth rate in effect at each time point
      vector[rows(A)] offset_t;  // intercept in effect at each time point
      rate_t = k + A * delta;
      offset_t = m + A * (-t_change .* delta);
      return rate_t .* t + offset_t;
    }
  }
  data {
    // Declared bounds make malformed input fail at data load rather than
    // producing a silently wrong fit. In particular, the model block only
    // implements trend_indicator values 0 and 1.
    int<lower=0> T;                         // Number of time periods
    int<lower=1> K;                         // Number of regressors
    vector[T] t;                            // Time
    vector[T] cap;                          // Capacities for logistic trend
    vector[T] y;                            // Time series
    int<lower=0> S;                         // Number of changepoints
    vector[S] t_change;                     // Times of trend changepoints (sorted ascending)
    matrix[T,K] X;                          // Regressors
    vector<lower=0>[K] sigmas;              // Scale on seasonality prior
    real<lower=0> tau;                      // Scale on changepoints prior
    int<lower=0, upper=1> trend_indicator;  // 0 for linear, 1 for logistic
    vector[K] s_a;                          // Indicator of additive features
    vector[K] s_m;                          // Indicator of multiplicative features
  }
  transformed data {
    // The changepoint indicator matrix depends only on data, so it is built
    // once here rather than on every log-density evaluation.
    matrix[T, S] A = get_changepoint_matrix(t, t_change, T, S);
  }
  parameters {
    // Declaration order is kept as-is: it fixes the layout of the
    // unconstrained parameter vector (inits, output columns).

    // Trend: base growth rate, offset, and per-changepoint rate adjustments.
    real k;
    real m;
    vector[S] delta;

    // Observation noise scale, constrained to be non-negative.
    real<lower=0> sigma_obs;

    // One coefficient per column of the regressor matrix X.
    vector[K] beta;
  }
  model {
    vector[T] trend;  // trend component before regressor effects are applied

    // Priors
    k ~ normal(0, 5);
    m ~ normal(0, 5);
    delta ~ double_exponential(0, tau);
    sigma_obs ~ normal(0, 0.5);
    beta ~ normal(0, sigmas);

    // Select the trend. An unrecognised indicator must fail loudly: without
    // this check no likelihood would be added and the sampler would quietly
    // fit the priors only.
    if (trend_indicator == 0) {
      trend = linear_trend(k, m, delta, t, A, t_change);
    } else if (trend_indicator == 1) {
      trend = logistic_trend(k, m, delta, t, cap, A, t_change, S);
    } else {
      reject("trend_indicator must be 0 (linear) or 1 (logistic); found ",
             trend_indicator);
    }

    // Likelihood: features flagged in s_m scale the trend multiplicatively,
    // features flagged in s_a are added on top.
    y ~ normal(
      trend .* (1 + X * (beta .* s_m)) + X * (beta .* s_a),
      sigma_obs
    );
  }