소스 검색

Move changepoint matrix calculation into stan (R)

Ben Letham 7 년 전
부모
커밋
1f84fa960f
5개의 변경된 파일117개의 추가작업 그리고 30개의 파일을 삭제
  1. 0 17
      R/R/prophet.R
  2. 30 2
      R/inst/stan/prophet_linear_growth.stan
  3. 30 2
      R/inst/stan/prophet_logistic_growth.stan
  4. 1 9
      R/tests/testthat/test_prophet.R
  5. 56 0
      R/tests/testthat/test_stan_functions.R

+ 0 - 17
R/R/prophet.R

@@ -464,21 +464,6 @@ set_changepoints <- function(m) {
   return(m)
   return(m)
 }
 }
 
 
-#' Gets changepoint matrix for history dataframe.
-#'
-#' @param m Prophet object.
-#'
-#' @return array of indexes.
-#'
-#' @keywords internal
-get_changepoint_matrix <- function(m) {
-  A <- matrix(0, nrow(m$history), length(m$changepoints.t))
-  for (i in seq_along(m$changepoints.t)) {
-    A[m$history$t >= m$changepoints.t[i], i] <- 1
-  }
-  return(A)
-}
-
 #' Provides Fourier series components with the specified frequency and order.
 #' Provides Fourier series components with the specified frequency and order.
 #'
 #'
 #' @param dates Vector of dates.
 #' @param dates Vector of dates.
@@ -905,7 +890,6 @@ fit.prophet <- function(m, df, ...) {
   prior.scales <- out2$prior.scales
   prior.scales <- out2$prior.scales
 
 
   m <- set_changepoints(m)
   m <- set_changepoints(m)
-  A <- get_changepoint_matrix(m)
 
 
   # Construct input to stan
   # Construct input to stan
   dat <- list(
   dat <- list(
@@ -914,7 +898,6 @@ fit.prophet <- function(m, df, ...) {
     S = length(m$changepoints.t),
     S = length(m$changepoints.t),
     y = history$y_scaled,
     y = history$y_scaled,
     t = history$t,
     t = history$t,
-    A = A,
     t_change = array(m$changepoints.t),
     t_change = array(m$changepoints.t),
     X = as.matrix(seasonal.features),
     X = as.matrix(seasonal.features),
     sigmas = array(prior.scales),
     sigmas = array(prior.scales),

+ 30 - 2
R/inst/stan/prophet_linear_growth.stan

@@ -1,16 +1,44 @@
+functions {
+  matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
+    // Assumes t and t_change are sorted.
+    matrix[T, S] A;
+    row_vector[S] a_row;
+    int cp_idx;
+
+    // Start with an empty matrix.
+    A = rep_matrix(0, T, S);
+    a_row = rep_row_vector(0, S);
+    cp_idx = 1;
+
+    // Fill in each row of A.
+    for (i in 1:T) {
+      while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) {
+        a_row[cp_idx] = 1;
+        cp_idx += 1;
+      }
+      A[i] = a_row;
+    }
+    return A;
+  }
+}
+
 data {
 data {
   int T;                                // Sample size
   int T;                                // Sample size
   int<lower=1> K;                       // Number of seasonal vectors
   int<lower=1> K;                       // Number of seasonal vectors
   vector[T] t;                            // Day
   vector[T] t;                            // Day
   vector[T] y;                            // Time-series
   vector[T] y;                            // Time-series
   int S;                                // Number of changepoints
   int S;                                // Number of changepoints
-  matrix[T, S] A;                   // Split indicators
-  real t_change[S];                 // Index of changepoints
+  vector[S] t_change;                 // Index of changepoints
   matrix[T,K] X;                // season vectors
   matrix[T,K] X;                // season vectors
   vector[K] sigmas;              // scale on seasonality prior
   vector[K] sigmas;              // scale on seasonality prior
   real<lower=0> tau;                  // scale on changepoints prior
   real<lower=0> tau;                  // scale on changepoints prior
 }
 }
 
 
+transformed data {
+  matrix[T, S] A;
+  A = get_changepoint_matrix(t, t_change, T, S);
+}
+
 parameters {
 parameters {
   real k;                            // Base growth rate
   real k;                            // Base growth rate
   real m;                            // offset
   real m;                            // offset

+ 30 - 2
R/inst/stan/prophet_logistic_growth.stan

@@ -1,3 +1,27 @@
+functions {
+  matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
+    // Assumes t and t_change are sorted.
+    matrix[T, S] A;
+    row_vector[S] a_row;
+    int cp_idx;
+
+    // Start with an empty matrix.
+    A = rep_matrix(0, T, S);
+    a_row = rep_row_vector(0, S);
+    cp_idx = 1;
+
+    // Fill in each row of A.
+    for (i in 1:T) {
+      while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) {
+        a_row[cp_idx] = 1;
+        cp_idx += 1;
+      }
+      A[i] = a_row;
+    }
+    return A;
+  }
+}
+
 data {
 data {
   int T;                                // Sample size
   int T;                                // Sample size
   int<lower=1> K;                       // Number of seasonal vectors
   int<lower=1> K;                       // Number of seasonal vectors
@@ -5,13 +29,17 @@ data {
   vector[T] cap;                          // Capacities
   vector[T] cap;                          // Capacities
   vector[T] y;                            // Time-series
   vector[T] y;                            // Time-series
   int S;                                // Number of changepoints
   int S;                                // Number of changepoints
-  matrix[T, S] A;                   // Split indicators
-  real t_change[S];                 // Index of changepoints
+  vector[S] t_change;                 // Index of changepoints
   matrix[T,K] X;                    // season vectors
   matrix[T,K] X;                    // season vectors
   vector[K] sigmas;               // scale on seasonality prior
   vector[K] sigmas;               // scale on seasonality prior
   real<lower=0> tau;                  // scale on changepoints prior
   real<lower=0> tau;                  // scale on changepoints prior
 }
 }
 
 
+transformed data {
+  matrix[T, S] A;
+  A = get_changepoint_matrix(t, t_change, T, S);
+}
+
 parameters {
 parameters {
   real k;                            // Base growth rate
   real k;                            // Base growth rate
   real m;                            // offset
   real m;                            // offset

+ 1 - 9
R/tests/testthat/test_prophet.R

@@ -121,11 +121,7 @@ test_that("get_changepoints", {
   cp <- m$changepoints.t
   cp <- m$changepoints.t
   expect_equal(length(cp), m$n.changepoints)
   expect_equal(length(cp), m$n.changepoints)
   expect_true(min(cp) > 0)
   expect_true(min(cp) > 0)
-  expect_true(max(cp) < N)
-
-  mat <- prophet:::get_changepoint_matrix(m)
-  expect_equal(nrow(mat), floor(N / 2))
-  expect_equal(ncol(mat), m$n.changepoints)
+  expect_true(max(cp) < 1)
 })
 })
 
 
 test_that("get_zero_changepoints", {
 test_that("get_zero_changepoints", {
@@ -141,10 +137,6 @@ test_that("get_zero_changepoints", {
   cp <- m$changepoints.t
   cp <- m$changepoints.t
   expect_equal(length(cp), 1)
   expect_equal(length(cp), 1)
   expect_equal(cp[1], 0)
   expect_equal(cp[1], 0)
-
-  mat <- prophet:::get_changepoint_matrix(m)
-  expect_equal(nrow(mat), floor(N / 2))
-  expect_equal(ncol(mat), 1)
 })
 })
 
 
 test_that("override_n_changepoints", {
 test_that("override_n_changepoints", {

+ 56 - 0
R/tests/testthat/test_stan_functions.R

@@ -0,0 +1,56 @@
+library(prophet)
+context("Prophet stan model tests")
+
+rstan::expose_stan_functions(rstan::stanc(file="../..//inst/stan/prophet_logistic_growth.stan"))
+
+DATA <- read.csv('data.csv')
+N <- nrow(DATA)
+train <- DATA[1:floor(N / 2), ]
+future <- DATA[(ceiling(N/2) + 1):N, ]
+
+DATA2 <- read.csv('data2.csv')
+
+DATA$ds <- prophet:::set_date(DATA$ds)
+DATA2$ds <- prophet:::set_date(DATA2$ds)
+
+test_that("get_changepoint_matrix", {
+  history <- train
+  m <- prophet(history, fit = FALSE)
+
+  out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
+  history <- out$df
+  m <- out$m
+  m$history <- history
+
+  m <- prophet:::set_changepoints(m)
+
+  cp <- m$changepoints.t
+
+  mat <- get_changepoint_matrix(history$t, cp, nrow(history), length(cp))
+  expect_equal(nrow(mat), floor(N / 2))
+  expect_equal(ncol(mat), m$n.changepoints)
+  # Compare to the R implementation
+  A <- matrix(0, nrow(history), length(cp))
+  for (i in 1:length(cp)) {
+    A[history$t >= cp[i], i] <- 1
+  }
+  expect_true(all(A == mat))
+})
+
+test_that("get_zero_changepoints", {
+  history <- train
+  m <- prophet(history, n.changepoints = 0, fit = FALSE)
+  
+  out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
+  m <- out$m
+  history <- out$df
+  m$history <- history
+
+  m <- prophet:::set_changepoints(m)
+  cp <- m$changepoints.t
+  
+  mat <- get_changepoint_matrix(history$t, cp, nrow(history), length(cp))
+  expect_equal(nrow(mat), floor(N / 2))
+  expect_equal(ncol(mat), 1)
+  expect_true(all(mat == 1))
+})