// prophet.stan (3.6 KB)

  functions {
    /**
     * Build the T x S changepoint indicator array.
     *
     * Row n switches on column j once t[n] has reached t_change[j]; a single
     * pointer walks forward through t_change as the rows advance. The result
     * is therefore only correct when both t and t_change are sorted in
     * non-decreasing order.
     *
     * @param t Observation times (length T, assumed sorted).
     * @param t_change Changepoint times (length S, assumed sorted).
     * @param T Number of observations.
     * @param S Number of changepoints.
     * @return [T, S] array of 0s and 1s.
     */
    real[ , ] get_changepoint_matrix(real[] t, real[] t_change, int T, int S) {
      real indicators[T, S];
      real active[S];  // indicator row carried forward from one time to the next
      int next_cp;     // index of the first changepoint not yet reached

      active = rep_array(0, S);
      next_cp = 1;
      // Every row is assigned below, so no separate zero-fill of indicators is needed.
      for (n in 1:T) {
        // Switch on each changepoint that time t[n] has now passed.
        while ((next_cp <= S) && (t[n] >= t_change[next_cp])) {
          active[next_cp] = 1;
          next_cp = next_cp + 1;
        }
        indicators[n] = active;
      }
      return indicators;
    }

    /**
     * Offset adjustments that keep the piecewise logistic trend continuous.
     *
     * The growth rate starts at k and, after changepoint j, becomes
     * k + delta[1] + ... + delta[j]. gamma[j] shifts the offset at
     * changepoint j so that the curve does not jump there.
     *
     * @param k Base growth rate.
     * @param m Base offset.
     * @param delta Rate adjustments, one per changepoint (length S).
     * @param t_change Changepoint times (length S).
     * @param S Number of changepoints.
     * @return Offset adjustments (length S).
     */
    real[] logistic_gamma(real k, real m, real[] delta, real[] t_change, int S) {
      real gamma[S];
      real k_prev;  // growth rate of the segment ending at changepoint j
      real k_next;  // growth rate of the segment starting at changepoint j
      real m_prev;  // offset of the segment ending at changepoint j

      k_prev = k;
      m_prev = m;
      for (j in 1:S) {
        k_next = k_prev + delta[j];
        gamma[j] = (t_change[j] - m_prev) * (1 - k_prev / k_next);
        m_prev = m_prev + gamma[j];
        k_prev = k_next;
      }
      return gamma;
    }

    /**
     * Piecewise logistic trend evaluated at every time point.
     *
     * @param k Base growth rate.
     * @param m Base offset.
     * @param delta Rate adjustments at the changepoints (length S).
     * @param t Observation times (length T).
     * @param cap Capacity at each time (length T).
     * @param A Changepoint indicators, e.g. from get_changepoint_matrix ([T, S]).
     * @param t_change Changepoint times (length S).
     * @param S Number of changepoints.
     * @param T Number of observations.
     * @return cap[n] / (1 + exp(-growth_n * (t[n] - shift_n))) for each n.
     */
    real[] logistic_trend(
      real k,
      real m,
      real[] delta,
      real[] t,
      real[] cap,
      real[ , ] A,
      real[] t_change,
      int S,
      int T
    ) {
      real adjustments[S];
      real Y[T];

      adjustments = logistic_gamma(k, m, delta, t_change, S);
      for (n in 1:T) {
        real growth;  // growth rate in force at t[n]
        real shift;   // offset in force at t[n]
        growth = k + dot_product(A[n], delta);
        shift = m + dot_product(A[n], adjustments);
        Y[n] = cap[n] / (1 + exp(-growth * (t[n] - shift)));
      }
      return Y;
    }

    /**
     * Piecewise linear trend evaluated at every time point.
     *
     * Passing changepoint j adds delta[j] to the slope and
     * -t_change[j] * delta[j] to the intercept, which keeps the line
     * continuous at t_change[j].
     *
     * @param k Base growth rate (slope).
     * @param m Base offset (intercept).
     * @param delta Rate adjustments at the changepoints (length S).
     * @param t Observation times (length T).
     * @param A Changepoint indicators, e.g. from get_changepoint_matrix ([T, S]).
     * @param t_change Changepoint times (length S).
     * @param S Number of changepoints.
     * @param T Number of observations.
     * @return growth_n * t[n] + intercept_n for each n.
     */
    real[] linear_trend(
      real k,
      real m,
      real[] delta,
      real[] t,
      real[ , ] A,
      real[] t_change,
      int S,
      int T
    ) {
      real adjustments[S];
      real Y[T];

      for (j in 1:S) {
        adjustments[j] = -t_change[j] * delta[j];
      }
      for (n in 1:T) {
        real growth;     // slope in force at t[n]
        real intercept;  // intercept in force at t[n]
        growth = k + dot_product(A[n], delta);
        intercept = m + dot_product(A[n], adjustments);
        Y[n] = growth * t[n] + intercept;
      }
      return Y;
    }
  }
  data {
    int<lower=0> T;                         // Number of time periods
    int<lower=1> K;                         // Number of regressors
    real t[T];                              // Time (must be sorted, see transformed data)
    real cap[T];                            // Capacities for logistic trend
    real y[T];                              // Time series
    int<lower=0> S;                         // Number of changepoints
    real t_change[S];                       // Times of trend changepoints (must be sorted)
    real X[T,K];                            // Regressors
    vector<lower=0>[K] sigmas;              // Prior scale for each regressor coefficient beta
    real<lower=0> tau;                      // Scale on changepoints prior
    // Constrained so any other value is rejected when the data are read,
    // instead of leaving the trend unassigned later on.
    int<lower=0, upper=1> trend_indicator;  // 0 for linear, 1 for logistic
    real s_a[K];                            // Indicator of additive features
    real s_m[K];                            // Indicator of multiplicative features
  }
  transformed data {
    // Changepoint indicators: A[i, j] = 1 once t[i] >= t_change[j], else 0.
    real A[T, S];

    // get_changepoint_matrix walks t and t_change with a single forward
    // pointer and silently returns a wrong matrix for unsorted input, so
    // verify its sortedness assumption once, up front.
    for (i in 2:T) {
      if (t[i] < t[i - 1]) {
        reject("t must be sorted in non-decreasing order; found t[", i, "] < t[", i - 1, "]");
      }
    }
    for (j in 2:S) {
      if (t_change[j] < t_change[j - 1]) {
        reject("t_change must be sorted in non-decreasing order; found t_change[", j, "] < t_change[", j - 1, "]");
      }
    }

    A = get_changepoint_matrix(t, t_change, T, S);
  }
  // Sampled model parameters.
  parameters {
    real k;                   // Base trend growth rate (slope before the first changepoint)
    real m;                   // Trend offset: intercept for linear trend, time shift for logistic trend
    real delta[S];            // Trend rate adjustments, one per changepoint; added to k via A
    real<lower=0> sigma_obs;  // Observation noise (standard deviation of the normal likelihood)
    real beta[K];             // Regressor coefficients, one per column of X
  }
  transformed parameters {
    real trend[T];   // Trend component at each time point
    real Y[T];       // Mean of the likelihood for y at each time point
    real beta_m[K];  // beta masked by s_m (multiplicative features)
    real beta_a[K];  // beta masked by s_a (additive features)

    if (trend_indicator == 0) {
      trend = linear_trend(k, m, delta, t, A, t_change, S, T);
    } else if (trend_indicator == 1) {
      trend = logistic_trend(k, m, delta, t, cap, A, t_change, S, T);
    } else {
      // Without this branch trend would be left unassigned and the failure
      // would only surface as an opaque "trend is nan" validation error.
      reject("trend_indicator must be 0 (linear) or 1 (logistic); found ", trend_indicator);
    }

    // Split the coefficients into multiplicative and additive parts.
    for (i in 1:K) {
      beta_m[i] = beta[i] * s_m[i];
      beta_a[i] = beta[i] * s_a[i];
    }

    // Multiplicative regressors scale the trend; additive ones are added on top.
    for (i in 1:T) {
      Y[i] = (
        trend[i] * (1 + dot_product(X[i], beta_m)) + dot_product(X[i], beta_a)
      );
    }
  }
  // Priors on all parameters plus a normal likelihood for the observed series.
  model {
    // Priors
    k ~ normal(0, 5);
    m ~ normal(0, 5);
    delta ~ double_exponential(0, tau);  // Laplace prior centered at 0 with scale tau
    sigma_obs ~ normal(0, 0.5);          // Half-normal, because sigma_obs is declared <lower=0>
    beta ~ normal(0, sigmas);            // Per-coefficient prior scales supplied as data
    // Likelihood
    y ~ normal(Y, sigma_obs);
  }