test_prophet.R 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  library(prophet)
  context("Prophet tests")

  # Shared fixture read from data.csv (columns `ds` and `y`). It is split into
  # two disjoint halves of floor(N / 2) rows each: `train` is the first half
  # and `future` is the last half. When N is odd, the middle row belongs to
  # neither half. head()/tail() (unlike `1:k` indexing) stay correct for
  # N < 2 and always return data frames.
  DATA <- read.csv('data.csv')
  DATA$ds <- as.Date(DATA$ds)
  N <- nrow(DATA)
  train <- head(DATA, floor(N / 2))
  future <- tail(DATA, floor(N / 2))
  test_that("load_models", {
    # The compiled Stan model for each growth type must load without error.
    for (growth in c('linear', 'logistic')) {
      expect_error(prophet:::get_prophet_stan_model(growth), NA)
    }
  })
  test_that("fit_predict", {
    # Fitting is skipped on the 32-bit i386 sub-architecture.
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    fit <- prophet(train)
    # Forecasting the held-out half must not raise.
    expect_error(predict(fit, future), NA)
  })
  test_that("fit_predict_no_seasons", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    # Turn off both seasonal components; the trend-only model must still
    # fit and forecast.
    trend_only <- prophet(train, yearly.seasonality = FALSE,
                          weekly.seasonality = FALSE)
    expect_error(predict(trend_only, future), NA)
  })
  test_that("fit_predict_no_changepoints", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    # A model without any trend changepoints must still fit and forecast.
    fit_no_cp <- prophet(train, n.changepoints = 0)
    expect_error(predict(fit_no_cp, future), NA)
  })
  test_that("fit_predict_changepoint_not_in_history", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    # Remove every observation from 2013-01-01 through 2014-01-01 (inclusive)
    # so that the manually supplied changepoint, 2013-06-06, falls inside a
    # gap with no data. DATA$ds is already a Date, so it is not re-converted.
    train_t <- dplyr::filter(
      DATA, (ds < as.Date('2013-01-01')) | (ds > as.Date('2014-01-01')))
    # Forecast over every fixture date. A distinct name avoids shadowing the
    # file-level `future` fixture.
    all_dates <- data.frame(ds = DATA$ds)
    m <- prophet(train_t, changepoints = c('2013-06-06'))
    expect_error(predict(m, all_dates), NA)
  })
  test_that("fit_predict_duplicates", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    # Each ds value appears twice: once as-is and once with y shifted up by
    # 10. Fitting and forecasting must cope with duplicated timestamps.
    shifted <- train
    shifted$y <- train$y + 10
    doubled <- rbind(train, shifted)
    fit <- prophet(doubled)
    expect_error(predict(fit, future), NA)
  })
  test_that("setup_dataframe", {
    proto <- prophet(train, fit = FALSE)
    prepared <- prophet:::setup_dataframe(proto, train,
                                          initialize_scales = TRUE)$df
    # A scaled time column `t` spanning exactly [0, 1] must be added.
    expect_true('t' %in% colnames(prepared))
    expect_equal(range(prepared$t), c(0, 1))
    # The target is rescaled so that its maximum is 1.
    expect_true('y_scaled' %in% colnames(prepared))
    expect_equal(max(prepared$y_scaled), 1)
  })
  test_that("get_changepoints", {
    m <- prophet(train, fit = FALSE)
    out <- prophet:::setup_dataframe(m, train, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    m$history <- history
    m <- prophet:::set_changepoints(m)
    cp <- m$changepoints.t
    expect_equal(length(cp), m$n.changepoints)
    # m$changepoints.t appears to use the same scaled time axis as
    # history$t, which setup_dataframe maps onto [0, 1]. The previous bound
    # `max(cp) < N` (N = fixture row count) therefore could never fail.
    # Instead, require the changepoints to lie after the first observation
    # and no later than the last one.
    expect_true(min(cp) > 0)
    expect_true(max(cp) <= max(history$t))
    # The changepoint design matrix has one row per history observation and
    # one column per changepoint.
    mat <- prophet:::get_changepoint_matrix(m)
    expect_equal(nrow(mat), nrow(train))
    expect_equal(ncol(mat), m$n.changepoints)
  })
  test_that("get_zero_changepoints", {
    proto <- prophet(train, n.changepoints = 0, fit = FALSE)
    prepared <- prophet:::setup_dataframe(proto, train,
                                          initialize_scales = TRUE)
    model <- prepared$m
    model$history <- prepared$df
    model <- prophet:::set_changepoints(model)
    # With zero requested changepoints, the test expects exactly one
    # placeholder changepoint located at t = 0.
    expect_equal(length(model$changepoints.t), 1)
    expect_equal(model$changepoints.t[1], 0)
    # The matrix therefore has one row per training observation and a single
    # column.
    cp_mat <- prophet:::get_changepoint_matrix(model)
    expect_equal(dim(cp_mat), c(floor(N / 2), 1))
  })
  test_that("fourier_series_weekly", {
    # Fourier basis for a 7-day period with order 3, evaluated on the
    # fixture dates; only the first row is pinned to reference values.
    mat <- prophet:::fourier_series(DATA$ds, 7, 3)
    expected <- c(0.7818315, 0.6234898, 0.9749279, -0.2225209, 0.4338837,
                  -0.9009689)
    # expect_equal() takes (object, expected). Keep that order so failure
    # output labels the actual and reference values correctly.
    expect_equal(mat[1, ], expected, tolerance = 1e-6)
  })
  test_that("fourier_series_yearly", {
    # Fourier basis for a 365.25-day period with order 3, evaluated on the
    # fixture dates; only the first row is pinned to reference values.
    mat <- prophet:::fourier_series(DATA$ds, 365.25, 3)
    expected <- c(0.7006152, -0.7135393, -0.9998330, 0.01827656, 0.7262249,
                  0.6874572)
    # expect_equal() takes (object, expected). Keep that order so failure
    # output labels the actual and reference values correctly.
    expect_equal(mat[1, ], expected, tolerance = 1e-6)
  })
  test_that("growth_init", {
    # Use the full fixture with logistic growth configured and a constant
    # `cap` column equal to the observed maximum of y.
    history <- DATA
    history$cap <- max(history$y)
    m <- prophet(history, growth = 'logistic', fit = FALSE)
    out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
    # Only the prepared data frame is needed. Neither initializer takes the
    # model, so out$m is not kept (it was previously assigned but unused).
    history <- out$df
    # Pin the first two elements returned by each initializer to reference
    # values (presumably the initial rate and offset -- confirm in source).
    params <- prophet:::linear_growth_init(history)
    expect_equal(params[1], 0.3055671, tolerance = 1e-6)
    expect_equal(params[2], 0.5307511, tolerance = 1e-6)
    params <- prophet:::logistic_growth_init(history)
    expect_equal(params[1], 1.507925, tolerance = 1e-6)
    expect_equal(params[2], -0.08167497, tolerance = 1e-6)
  })
  test_that("piecewise_linear", {
    # Arguments are passed positionally as (t, deltas, k, m, changepoint.ts).
    # The expected values rise by 1 per step up to t = 5, then by 1.5 per
    # step after the changepoint at t = 5.
    time_points <- seq(0, 10)
    offset <- 0
    rate <- 1.0
    rate_changes <- c(0.5)
    cps <- c(5)
    expected <- c(0, 1, 2, 3, 4, 5, 6.5, 8, 9.5, 11, 12.5)
    expect_equal(
      prophet:::piecewise_linear(time_points, rate_changes, rate, offset, cps),
      expected)
    # Evaluating only the points from index 8 onward (t = 7..10, all past the
    # changepoint) must reproduce the same values.
    tail_idx <- 8:length(time_points)
    expect_equal(
      prophet:::piecewise_linear(time_points[tail_idx], rate_changes, rate,
                                 offset, cps),
      expected[tail_idx])
  })
  test_that("piecewise_logistic", {
    # Arguments are passed positionally as
    # (t, cap, deltas, k, m, changepoint.ts), with a constant capacity of 10.
    time_points <- seq(0, 10)
    capacity <- rep(10, 11)
    offset <- 0
    rate <- 1.0
    rate_changes <- c(0.5)
    cps <- c(5)
    expected <- c(5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071,
                  9.984988, 9.996646, 9.999252, 9.999833, 9.999963)
    expect_equal(
      prophet:::piecewise_logistic(time_points, capacity, rate_changes, rate,
                                   offset, cps),
      expected, tolerance = 1e-6)
    # Evaluating only the points from index 8 onward (with matching capacity
    # entries) must reproduce the same values.
    tail_idx <- 8:length(time_points)
    expect_equal(
      prophet:::piecewise_logistic(time_points[tail_idx], capacity[tail_idx],
                                   rate_changes, rate, offset, cps),
      expected[tail_idx], tolerance = 1e-6)
  })
  test_that("holidays", {
    # Window [-1, 0] around 2016-12-25: the test expects one feature column
    # per day in the window (Dec 24 and Dec 25).
    holidays <- data.frame(ds = as.Date(c('2016-12-25')),
                           holiday = c('xmas'),
                           lower_window = c(-1),
                           upper_window = c(0))
    # Twelve consecutive days, 2016-12-20 through 2016-12-31.
    df <- data.frame(
      ds = seq(as.Date('2016-12-20'), as.Date('2016-12-31'), by = 'd'))
    m <- prophet(train, holidays = holidays, fit = FALSE)
    feats <- prophet:::make_holiday_features(m, df$ds)
    expect_equal(nrow(feats), nrow(df))
    expect_equal(ncol(feats), 2)
    # Each window day occurs exactly once in `df`, so every column must sum
    # to 1. Compare column by column: summing the differences (the previous
    # check) let errors in opposite directions cancel, e.g. sums of 2 and 0.
    expect_equal(as.numeric(colSums(feats)), c(1, 1))
    # Widening the window to [-1, 10] spans 12 days, hence 12 columns.
    holidays$upper_window <- c(10)
    m <- prophet(train, holidays = holidays, fit = FALSE)
    feats <- prophet:::make_holiday_features(m, df$ds)
    expect_equal(nrow(feats), nrow(df))
    expect_equal(ncol(feats), 12)
  })
  test_that("fit_with_holidays", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    # One recurring holiday on two dates, each covering the day itself and
    # the day after (window [0, 1]).
    bday_dates <- zoo::as.Date(c('2012-06-06', '2013-06-06'))
    holidays <- data.frame(ds = bday_dates,
                           holiday = rep('seans-bday', 2),
                           lower_window = c(0, 0),
                           upper_window = c(1, 1))
    fit <- prophet(DATA, holidays = holidays, uncertainty.samples = 0)
    # With no newdata argument, predict() presumably forecasts over the
    # fitted history; it must not raise.
    expect_error(predict(fit), NA)
  })
  test_that("make_future_dataframe", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    fit <- prophet(train)
    # Three future periods with the history excluded. The expected dates
    # imply that the training history ends on 2013-04-25.
    daily <- make_future_dataframe(fit, periods = 3, freq = 'd',
                                   include_history = FALSE)
    expect_equal(daily$ds,
                 as.Date(c('2013-04-26', '2013-04-27', '2013-04-28')))
    monthly <- make_future_dataframe(fit, periods = 3, freq = 'm',
                                     include_history = FALSE)
    expect_equal(monthly$ds,
                 as.Date(c('2013-05-25', '2013-06-25', '2013-07-25')))
  })