# test_stan_functions.R

  library(prophet)
  context("Prophet stan model tests")

  # Compile prophet.stan and expose its functions (get_changepoint_matrix,
  # linear_trend, logistic_trend used below) to R. The relative path targets a
  # source checkout; if it fails, fall back to the copy in the installed package.
  fn <- tryCatch({
    rstan::expose_stan_functions(
      rstan::stanc(file = "../../inst/stan/prophet.stan")
    )
  }, error = function(e) {
    rstan::expose_stan_functions(
      rstan::stanc(file = system.file("stan/prophet.stan", package = "prophet"))
    )
  })

  # Split data.csv into a first-half training set and the remaining rows.
  DATA <- read.csv('data.csv')
  N <- nrow(DATA)
  half <- floor(N / 2)
  train <- DATA[seq_len(half), ]
  # Start right after the last training row so that train and future together
  # cover every row (no middle row is dropped when N is odd).
  future <- DATA[half + seq_len(N - half), ]

  DATA2 <- read.csv('data2.csv')

  # NOTE: train/future were sliced above, before this conversion, so their
  # `ds` columns keep the type produced by read.csv.
  DATA$ds <- prophet:::set_date(DATA$ds)
  DATA2$ds <- prophet:::set_date(DATA2$ds)
  # The Stan get_changepoint_matrix must agree with a plain-R indicator matrix
  # built from the model's scaled changepoint times.
  test_that("get_changepoint_matrix", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    history <- train
    m <- prophet(history, fit = FALSE)
    out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    m$history <- history
    m <- prophet:::set_changepoints(m)
    cp <- m$changepoints.t
    mat <- get_changepoint_matrix(history$t, cp, nrow(history), length(cp))
    expect_equal(nrow(mat), floor(N / 2))
    expect_equal(ncol(mat), m$n.changepoints)
    # Reference R implementation: entry (i, j) is 1 when t[i] >= cp[j], else 0.
    # outer() builds the whole matrix at once and, unlike a
    # `for (i in 1:length(cp))` loop, stays correctly shaped when cp is empty.
    A <- 1 * outer(history$t, cp, ">=")
    expect_true(all(A == mat))
  })
  # With n.changepoints = 0, the changepoint matrix is expected to collapse to
  # a single column of ones covering every training row.
  test_that("get_zero_changepoints", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    model <- prophet(train, n.changepoints = 0, fit = FALSE)
    prepared <- prophet:::setup_dataframe(
      model, train, initialize_scales = TRUE
    )
    model <- prepared$m
    hist_df <- prepared$df
    model$history <- hist_df
    model <- prophet:::set_changepoints(model)
    cps <- model$changepoints.t
    mat <- get_changepoint_matrix(hist_df$t, cps, nrow(hist_df), length(cps))
    expect_equal(nrow(mat), floor(N / 2))
    expect_equal(ncol(mat), 1)
    expect_true(all(mat == 1))
  })
  # Piecewise-linear trend: slope k = 1, offset m = 0, and one changepoint at
  # t = 5 whose delta of 0.5 makes the series rise by 1.5 per step afterwards.
  test_that("linear_trend", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    k <- 1.0
    m <- 0
    deltas <- c(0.5)
    changepoint.ts <- c(5)
    t <- seq(0, 10)
    y.true <- c(0, 1, 2, 3, 4, 5, 6.5, 8, 9.5, 11, 12.5)

    A <- get_changepoint_matrix(t, changepoint.ts, length(t), 1)
    expect_equal(linear_trend(k, m, deltas, t, A, changepoint.ts), y.true)

    # Repeat on the tail of the series (t = 7..10), which lies entirely after
    # the changepoint.
    tail_idx <- seq(8, length(t))
    t <- t[tail_idx]
    y.true <- y.true[tail_idx]
    A <- get_changepoint_matrix(t, changepoint.ts, length(t), 1)
    expect_equal(linear_trend(k, m, deltas, t, A, changepoint.ts), y.true)
  })
  # Piecewise-logistic trend with constant capacity 10, rate k = 1, offset
  # m = 0 and one changepoint at t = 5 (delta = 0.5), checked against
  # precomputed values to 1e-6.
  test_that("piecewise_logistic", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    k <- 1.0
    m <- 0
    deltas <- c(0.5)
    changepoint.ts <- c(5)
    t <- seq(0, 10)
    cap <- rep(10, length(t))
    y.true <- c(5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071,
                9.984988, 9.996646, 9.999252, 9.999833, 9.999963)

    # NOTE(review): the trailing 1 is presumably the number of changepoints
    # (length(changepoint.ts)) -- confirm against prophet.stan.
    A <- get_changepoint_matrix(t, changepoint.ts, length(t), 1)
    y <- logistic_trend(k, m, deltas, t, cap, A, changepoint.ts, 1)
    expect_equal(y, y.true, tolerance = 1e-6)

    # Repeat on the tail of the series (t = 7..10).
    keep <- seq(8, length(t))
    t <- t[keep]
    cap <- cap[keep]
    y.true <- y.true[keep]
    A <- get_changepoint_matrix(t, changepoint.ts, length(t), 1)
    y <- logistic_trend(k, m, deltas, t, cap, A, changepoint.ts, 1)
    expect_equal(y, y.true, tolerance = 1e-6)
  })