Pārlūkot izejas kodu

Speed up stan fitting by removing unecessary parameter definitions

Ben Letham 7 gadi atpakaļ
vecāks
revīzija
3bd372bc15
2 mainītis faili ar 30 papildinājumiem un 32 dzēšanām
  1. 15 16
      R/inst/stan/prophet.stan
  2. 15 16
      python/stan/unix/prophet.stan

+ 15 - 16
R/inst/stan/prophet.stan

@@ -102,21 +102,6 @@ parameters {
   vector[K] beta;           // Regressor coefficients
 }
 
-transformed parameters {
-  vector[T] trend;
-  vector[T] Xb_a;
-  vector[T] Xb_m;
-
-  if (trend_indicator == 0) {
-    trend = linear_trend(k, m, delta, t, A, t_change);
-  } else if (trend_indicator == 1) {
-    trend = logistic_trend(k, m, delta, t, cap, A, t_change, S);
-  }
-
-  Xb_a = X * (beta .* s_a);
-  Xb_m = X * (beta .* s_m);
-}
-
 model {
   //priors
   k ~ normal(0, 5);
@@ -126,5 +111,19 @@ model {
   beta ~ normal(0, sigmas);
 
   // Likelihood
-  y ~ normal(trend .* (1 + Xb_m) + Xb_a, sigma_obs);
+  if (trend_indicator == 0) {
+    y ~ normal(
+      linear_trend(k, m, delta, t, A, t_change)
+      .* (1 + X * (beta .* s_m))
+      + X * (beta .* s_a),
+      sigma_obs
+    );
+  } else if (trend_indicator == 1) {
+    y ~ normal(
+      logistic_trend(k, m, delta, t, cap, A, t_change, S)
+      .* (1 + X * (beta .* s_m))
+      + X * (beta .* s_a),
+      sigma_obs
+    );
+  }
 }

+ 15 - 16
python/stan/unix/prophet.stan

@@ -102,21 +102,6 @@ parameters {
   vector[K] beta;           // Regressor coefficients
 }
 
-transformed parameters {
-  vector[T] trend;
-  vector[T] Xb_a;
-  vector[T] Xb_m;
-
-  if (trend_indicator == 0) {
-    trend = linear_trend(k, m, delta, t, A, t_change);
-  } else if (trend_indicator == 1) {
-    trend = logistic_trend(k, m, delta, t, cap, A, t_change, S);
-  }
-
-  Xb_a = X * (beta .* s_a);
-  Xb_m = X * (beta .* s_m);
-}
-
 model {
   //priors
   k ~ normal(0, 5);
@@ -126,5 +111,19 @@ model {
   beta ~ normal(0, sigmas);
 
   // Likelihood
-  y ~ normal(trend .* (1 + Xb_m) + Xb_a, sigma_obs);
+  if (trend_indicator == 0) {
+    y ~ normal(
+      linear_trend(k, m, delta, t, A, t_change)
+      .* (1 + X * (beta .* s_m))
+      + X * (beta .* s_a),
+      sigma_obs
+    );
+  } else if (trend_indicator == 1) {
+    y ~ normal(
+      logistic_trend(k, m, delta, t, cap, A, t_change, S)
+      .* (1 + X * (beta .* s_m))
+      + X * (beta .* s_a),
+      sigma_obs
+    );
+  }
 }