# test_stan_functions.R

  library(prophet)
  context("Prophet stan model tests")

  # Compile prophet.stan and expose its functions (get_changepoint_matrix,
  # linear_trend, logistic_trend are used by the tests below). The relative
  # path is tried first -- presumably the package source tree when tests run
  # from tests/testthat (confirm); on any error the copy installed with the
  # prophet package is compiled instead.
  fn <- tryCatch(
    expr = rstan::expose_stan_functions(
      rstan::stanc(file = "../../inst/stan/prophet.stan")
    ),
    error = function(e) {
      installed_stan <- system.file("stan/prophet.stan", package = "prophet")
      rstan::expose_stan_functions(rstan::stanc(file = installed_stan))
    }
  )
  # Fixture data read from the test directory.
  DATA <- read.csv("data.csv")
  DATA2 <- read.csv("data2.csv")
  N <- nrow(DATA)

  # train = first floor(N / 2) rows, future = last floor(N / 2) rows; for odd
  # N the middle row belongs to neither. Both are sliced before DATA$ds is
  # converted below, so they keep ds exactly as read from the CSV.
  train <- DATA[1:floor(N / 2), ]
  future <- DATA[(ceiling(N / 2) + 1):N, ]

  # Convert the ds columns with prophet's internal set_date().
  DATA$ds <- prophet:::set_date(DATA$ds)
  DATA2$ds <- prophet:::set_date(DATA2$ds)
  test_that("get_changepoint_matrix", {
    # Prepare the model and its history without fitting, then place the
    # default changepoints.
    history <- train
    m <- prophet(history, fit = FALSE)
    out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    m$history <- history
    m <- prophet:::set_changepoints(m)
    cp <- m$changepoints.t

    mat <- get_changepoint_matrix(history$t, cp, nrow(history), length(cp))
    # One row per history point, one column per changepoint.
    expect_equal(nrow(mat), floor(N / 2))
    expect_equal(ncol(mat), m$n.changepoints)

    # Compare to the R implementation: column i is 1 for every time at or
    # after changepoint i and 0 before it. seq_along() rather than
    # 1:length() so the loop is skipped (instead of iterating over c(1, 0))
    # if cp is ever empty.
    A <- matrix(0, nrow(history), length(cp))
    for (i in seq_along(cp)) {
      A[history$t >= cp[i], i] <- 1
    }
    expect_true(all(A == mat))
  })
  test_that("get_zero_changepoints", {
    # Even with n.changepoints = 0 the matrix is expected to have a single
    # column whose entries are all 1.
    m <- prophet(train, n.changepoints = 0, fit = FALSE)
    prepared <- prophet:::setup_dataframe(m, train, initialize_scales = TRUE)
    m <- prepared$m
    hist_df <- prepared$df
    m$history <- hist_df
    m <- prophet:::set_changepoints(m)
    changepoints <- m$changepoints.t

    mat <- get_changepoint_matrix(
      hist_df$t, changepoints, nrow(hist_df), length(changepoints)
    )
    expect_equal(nrow(mat), floor(N / 2))
    expect_equal(ncol(mat), 1)
    expect_true(all(mat == 1))
  })
  test_that("linear_trend", {
    # Slope k = 1 up to the changepoint at t = 5, then 1 + 0.5 = 1.5 after it.
    k <- 1.0
    m <- 0
    delta <- c(0.5)
    t_change <- c(5)

    t <- seq(0, 10)
    expected <- c(0, 1, 2, 3, 4, 5, 6.5, 8, 9.5, 11, 12.5)
    A <- get_changepoint_matrix(t, t_change, length(t), 1)
    expect_equal(linear_trend(k, m, delta, t, A, t_change), expected)

    # Repeat on the suffix t = 7..10, which lies entirely after the
    # changepoint.
    keep <- 8:length(t)
    t <- t[keep]
    expected <- expected[keep]
    A <- get_changepoint_matrix(t, t_change, length(t), 1)
    expect_equal(linear_trend(k, m, delta, t, A, t_change), expected)
  })
  test_that("piecewise_logistic", {
    # Constant capacity of 10 with a single changepoint at t = 5
    # (delta = 0.5); expected values are compared to 1e-6.
    k <- 1.0
    m <- 0
    delta <- c(0.5)
    t_change <- c(5)

    t <- seq(0, 10)
    cap <- rep(10, length(t))
    expected <- c(5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071,
                  9.984988, 9.996646, 9.999252, 9.999833, 9.999963)
    A <- get_changepoint_matrix(t, t_change, length(t), 1)
    expect_equal(logistic_trend(k, m, delta, t, cap, A, t_change, 1),
                 expected, tolerance = 1e-6)

    # Repeat on the suffix t = 7..10, subsetting t, cap and the expected
    # values with the same indices.
    keep <- 8:length(t)
    t <- t[keep]
    cap <- cap[keep]
    expected <- expected[keep]
    A <- get_changepoint_matrix(t, t_change, length(t), 1)
    expect_equal(logistic_trend(k, m, delta, t, cap, A, t_change, 1),
                 expected, tolerance = 1e-6)
  })