test_stan_functions.R 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  # Shared setup for the Stan model tests in this file.
  library(prophet)
  context("Prophet stan model tests")

  # Compile the logistic-growth Stan program and expose the functions it
  # declares to the R session (presumably including get_changepoint_matrix,
  # which the tests below call unqualified -- confirm against the .stan file).
  rstan::expose_stan_functions(
    rstan::stanc(file = "../..//inst/stan/prophet_logistic_growth.stan")
  )

  # Primary fixture. The train/future split is taken BEFORE `ds` is converted,
  # so both halves keep the `ds` column exactly as read from the CSV.
  DATA <- read.csv("data.csv")
  N <- nrow(DATA)
  # First floor(N / 2) rows go to train; rows after ceiling(N / 2) go to
  # future. When N is odd the middle row lands in neither half.
  train <- DATA[1:floor(N / 2), ]
  future <- DATA[(ceiling(N / 2) + 1):N, ]
  DATA$ds <- prophet:::set_date(DATA$ds)

  # Secondary fixture, converted the same way.
  DATA2 <- read.csv("data2.csv")
  DATA2$ds <- prophet:::set_date(DATA2$ds)
  # Checks that the exposed get_changepoint_matrix() agrees with a reference
  # indicator matrix built in plain R from the same inputs.
  test_that("get_changepoint_matrix", {
    # Build an unfitted model on the training half, then run the preprocessing
    # steps by hand so that the scaled history and changepoints are available.
    history <- train
    m <- prophet(history, fit = FALSE)
    out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    m$history <- history
    m <- prophet:::set_changepoints(m)
    cp <- m$changepoints.t

    mat <- get_changepoint_matrix(history$t, cp, nrow(history), length(cp))

    # One row per training observation, one column per changepoint.
    expect_equal(nrow(mat), floor(N / 2))
    expect_equal(ncol(mat), m$n.changepoints)

    # Compare to the R implementation: column i is 1 wherever t >= cp[i].
    A <- matrix(0, nrow(history), length(cp))
    # seq_along() rather than 1:length(): the latter yields c(1, 0) for an
    # empty `cp` and would index a zero-column matrix out of bounds.
    for (i in seq_along(cp)) {
      A[history$t >= cp[i], i] <- 1
    }
    expect_true(all(A == mat))
  })
  # Checks the changepoint matrix produced for a model configured with
  # n.changepoints = 0.
  test_that("get_zero_changepoints", {
    # Unfitted model on the training half, explicitly without changepoints.
    model <- prophet(train, n.changepoints = 0, fit = FALSE)
    prepared <- prophet:::setup_dataframe(
      model, train,
      initialize_scales = TRUE
    )
    model <- prepared$m
    scaled <- prepared$df
    model$history <- scaled
    model <- prophet:::set_changepoints(model)

    cp_times <- model$changepoints.t
    mat <- get_changepoint_matrix(
      scaled$t, cp_times, nrow(scaled), length(cp_times)
    )

    # Expected shape: one row per training observation and a single column
    # filled with ones (presumably a placeholder changepoint supplied by
    # set_changepoints -- confirm against its implementation).
    expect_equal(nrow(mat), floor(N / 2))
    expect_equal(ncol(mat), 1)
    expect_true(all(mat == 1))
  })