test_stan_functions.R 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  library(prophet)
  context("Prophet stan model tests")

  # Compile the package's Stan model and expose its user-defined functions
  # (used below: get_changepoint_matrix, linear_trend, logistic_trend) so they
  # can be called directly from R. The path is relative to the test working
  # directory (presumably tests/testthat -- confirm if the tests move).
  rstan::expose_stan_functions(
    rstan::stanc(file = "../../inst/stan/prophet.stan")
  )

  DATA <- read.csv("data.csv")
  N <- nrow(DATA)
  # Split DATA into two contiguous halves: `future` begins on the row right
  # after the last `train` row. Starting `future` at ceiling(N / 2) + 1 would
  # silently drop the middle row whenever N is odd, and seq_len() avoids the
  # 1:0 == c(1, 0) trap when N < 2.
  train <- DATA[seq_len(floor(N / 2)), ]
  future <- DATA[(floor(N / 2) + 1):N, ]
  DATA2 <- read.csv("data2.csv")
  # NOTE: train/future are sliced before ds is parsed, so they keep the raw ds
  # column read from the CSV; only DATA and DATA2 get parsed dates here.
  DATA$ds <- prophet:::set_date(DATA$ds)
  DATA2$ds <- prophet:::set_date(DATA2$ds)
  test_that("get_changepoint_matrix", {
    # The Stan changepoint indicator matrix must have one row per history
    # point and one column per changepoint, and agree with a pure-R version.
    history <- train
    m <- prophet(history, fit = FALSE)
    out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    m$history <- history
    m <- prophet:::set_changepoints(m)
    cp <- m$changepoints.t
    mat <- get_changepoint_matrix(history$t, cp, nrow(history), length(cp))
    expect_equal(nrow(mat), floor(N / 2))
    expect_equal(ncol(mat), m$n.changepoints)
    # Reference R implementation: entry (i, j) is 1 when t[i] >= cp[j].
    # outer() builds the whole indicator matrix at once and, unlike a
    # `for (i in 1:length(cp))` loop, stays well-defined when cp is empty.
    A <- outer(history$t, cp, ">=") * 1
    # Check shapes first so a mismatch reports clearly instead of erroring
    # inside the elementwise comparison.
    expect_equal(dim(mat), dim(A))
    expect_true(all(A == mat))
  })
  test_that("get_zero_changepoints", {
    # With n.changepoints = 0 the Stan helper is expected to return a single
    # column of ones, one row per history point.
    m <- prophet(train, n.changepoints = 0, fit = FALSE)
    prepared <- prophet:::setup_dataframe(m, train, initialize_scales = TRUE)
    hist_df <- prepared$df
    m <- prepared$m
    m$history <- hist_df
    m <- prophet:::set_changepoints(m)
    cp <- m$changepoints.t
    mat <- get_changepoint_matrix(hist_df$t, cp, nrow(hist_df), length(cp))
    expect_equal(nrow(mat), floor(N / 2))
    expect_equal(ncol(mat), 1)
    expect_true(all(mat == 1))
  })
  test_that("linear_trend", {
    # Slope 1 from offset 0, with one changepoint at t = 5 that adds 0.5 to
    # the slope; the expected values stay continuous across the changepoint.
    k <- 1.0
    m <- 0
    deltas <- c(0.5)
    changepoint.ts <- c(5)
    expected <- c(0, 1, 2, 3, 4, 5, 6.5, 8, 9.5, 11, 12.5)

    # Build the changepoint matrix for `ts` and evaluate the Stan trend on it.
    trend_at <- function(ts) {
      A <- get_changepoint_matrix(ts, changepoint.ts, length(ts), 1)
      linear_trend(k, m, deltas, ts, A, changepoint.ts)
    }

    t <- seq(0, 10)
    expect_equal(trend_at(t), expected)

    # Evaluating only the tail (t = 7..10) must reproduce the same values.
    keep <- 8:length(t)
    expect_equal(trend_at(t[keep]), expected[keep])
  })
  test_that("piecewise_logistic", {
    # Logistic trend with k = 1, m = 0, cap = 10 and one changepoint at t = 5
    # carrying delta = 0.5; expected values are given to 6 decimal places.
    k <- 1.0
    m <- 0
    deltas <- c(0.5)
    changepoint.ts <- c(5)
    expected <- c(5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071,
                  9.984988, 9.996646, 9.999252, 9.999833, 9.999963)

    # Build the changepoint matrix for `ts` and evaluate the Stan trend on it.
    trend_at <- function(ts, capacity) {
      A <- get_changepoint_matrix(ts, changepoint.ts, length(ts), 1)
      logistic_trend(k, m, deltas, ts, capacity, A, changepoint.ts, 1)
    }

    t <- seq(0, 10)
    cap <- rep(10, 11)
    expect_equal(trend_at(t, cap), expected, tolerance = 1e-6)

    # Evaluating only the tail (t = 7..10) must reproduce the same values.
    keep <- 8:length(t)
    expect_equal(
      trend_at(t[keep], cap[keep]),
      expected[keep],
      tolerance = 1e-6
    )
  })