# test_prophet.R (10.0 KB)
  library(prophet)
  context("Prophet tests")

  # File-level fixtures used by the tests below.
  DATA <- read.csv("data.csv")
  N <- nrow(DATA)
  # First half of DATA is used for fitting, the remaining rows as
  # out-of-sample dates. Both boundaries are derived from floor(N / 2) so that
  # the halves are contiguous: previously `future` started at
  # ceiling(N / 2) + 1, which skipped the middle row whenever N was odd.
  train <- DATA[1:floor(N / 2), ]
  future <- DATA[(floor(N / 2) + 1):N, ]

  # Second fixture, used by the daily-seasonality and sub-daily holiday tests.
  DATA2 <- read.csv("data2.csv")
  test_that("fit_predict", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Fit on the first half with all defaults; forecasting the second half
    # must not raise.
    model <- prophet(train)
    expect_error(predict(model, future), NA)
  })
  test_that("fit_predict_no_seasons", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Weekly and yearly components are both switched off explicitly.
    model <- prophet(
      train,
      weekly.seasonality = FALSE,
      yearly.seasonality = FALSE
    )
    expect_error(predict(model, future), NA)
  })
  test_that("fit_predict_no_changepoints", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Fit with n.changepoints = 0; forecasting must still succeed.
    no_cp_model <- prophet(train, n.changepoints = 0)
    expect_error(predict(no_cp_model, future), NA)
  })
  test_that("fit_predict_changepoint_not_in_history", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Remove all of 2013 from the history, then place a manual changepoint
    # inside that gap so it coincides with no observed date.
    gap_start <- prophet:::set_date("2013-01-01")
    gap_end <- prophet:::set_date("2014-01-01")
    history_with_gap <- dplyr::mutate(DATA, ds = prophet:::set_date(ds))
    history_with_gap <- dplyr::filter(
      history_with_gap,
      (ds < gap_start) | (ds > gap_end)
    )
    all_dates <- data.frame(ds = DATA$ds)
    model <- prophet(history_with_gap, changepoints = c("2013-06-06"))
    expect_error(predict(model, all_dates), NA)
  })
  test_that("fit_predict_duplicates", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Every ds appears twice: once with the original y, once shifted by 10.
    shifted <- train
    shifted$y <- shifted$y + 10
    doubled <- rbind(train, shifted)
    model <- prophet(doubled)
    expect_error(predict(model, future), NA)
  })
  test_that("setup_dataframe", {
    model <- prophet(train, fit = FALSE)
    prepared <- prophet:::setup_dataframe(
      model, train, initialize_scales = TRUE
    )$df
    # Time column is expected to span exactly [0, 1] over the history.
    expect_true("t" %in% colnames(prepared))
    expect_equal(min(prepared$t), 0)
    expect_equal(max(prepared$t), 1)
    # Scaled response is expected to peak at exactly 1.
    expect_true("y_scaled" %in% colnames(prepared))
    expect_equal(max(prepared$y_scaled), 1)
  })
  test_that("get_changepoints", {
    history <- train
    m <- prophet(history, fit = FALSE)
    out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    m$history <- history
    m <- prophet:::set_changepoints(m)

    cp <- m$changepoints.t
    expect_equal(length(cp), m$n.changepoints)
    # changepoints.t is presumably on the same scaled time axis as history$t,
    # which setup_dataframe maps onto [0, 1]. The upper bound is therefore
    # taken from that axis. The previous bound, N, is a row count, so the
    # check could never fail.
    expect_true(min(cp) > 0)
    expect_true(max(cp) < max(history$t))

    # Changepoint matrix: one row per history observation, one column per
    # changepoint.
    mat <- prophet:::get_changepoint_matrix(m)
    expect_equal(nrow(mat), nrow(history))
    expect_equal(ncol(mat), m$n.changepoints)
  })
  test_that("get_zero_changepoints", {
    model <- prophet(train, n.changepoints = 0, fit = FALSE)
    prepared <- prophet:::setup_dataframe(
      model, train, initialize_scales = TRUE
    )
    model <- prepared$m
    model$history <- prepared$df
    model <- prophet:::set_changepoints(model)

    # With zero changepoints requested, a single entry equal to 0 is expected.
    cp <- model$changepoints.t
    expect_equal(length(cp), 1)
    expect_equal(cp[1], 0)

    # One row per training row and a single column.
    cp_matrix <- prophet:::get_changepoint_matrix(model)
    expect_equal(nrow(cp_matrix), floor(N / 2))
    expect_equal(ncol(cp_matrix), 1)
  })
  test_that("override_n_changepoints", {
    # With only 20 rows of history, the default changepoint count is
    # expected to be capped at 15.
    short_history <- train[1:20, ]
    model <- prophet(short_history, fit = FALSE)
    prepared <- prophet:::setup_dataframe(
      model, short_history, initialize_scales = TRUE
    )
    model <- prepared$m
    model$history <- prepared$df
    model <- prophet:::set_changepoints(model)
    expect_equal(model$n.changepoints, 15)
    expect_equal(length(model$changepoints.t), 15)
  })
  test_that("fourier_series_weekly", {
    # Period 7, order 3: two columns per Fourier term, six in total.
    mat <- prophet:::fourier_series(DATA$ds, 7, 3)
    true.values <- c(0.9165623, 0.3998920, 0.7330519, -0.6801727, -0.3302791,
                     -0.9438833)
    # testthat convention is expect_equal(object, expected). The computed row
    # goes first so that failure messages label the values correctly.
    expect_equal(mat[1, ], true.values, tolerance = 1e-6)
  })
  test_that("fourier_series_yearly", {
    # Period 365.25, order 3: two columns per Fourier term, six in total.
    mat <- prophet:::fourier_series(DATA$ds, 365.25, 3)
    true.values <- c(0.69702635, -0.71704551, -0.99959923, 0.02830854,
                     0.73648994, 0.67644849)
    # testthat convention is expect_equal(object, expected). The computed row
    # goes first so that failure messages label the values correctly.
    expect_equal(mat[1, ], true.values, tolerance = 1e-6)
  })
  test_that("growth_init", {
    capped_history <- DATA[1:468, ]
    # Logistic growth needs a cap column; use the observed maximum.
    capped_history$cap <- max(capped_history$y)
    model <- prophet(capped_history, growth = "logistic", fit = FALSE)
    scaled <- prophet:::setup_dataframe(
      model, capped_history, initialize_scales = TRUE
    )$df

    # Reference values for the first two parameters returned by each
    # growth initializer.
    linear_init <- prophet:::linear_growth_init(scaled)
    expect_equal(linear_init[1], 0.3055671, tolerance = 1e-6)
    expect_equal(linear_init[2], 0.5307511, tolerance = 1e-6)

    logistic_init <- prophet:::logistic_growth_init(scaled)
    expect_equal(logistic_init[1], 1.507925, tolerance = 1e-6)
    expect_equal(logistic_init[2], -0.08167497, tolerance = 1e-6)
  })
  test_that("piecewise_linear", {
    time_points <- seq(0, 10)
    offset <- 0
    rate <- 1.0
    deltas <- c(0.5)
    changepoint.ts <- c(5)
    # Slope 1 up to t = 5, then slope 1.5 afterwards.
    expected <- c(0, 1, 2, 3, 4, 5, 6.5, 8, 9.5, 11, 12.5)

    full_curve <- prophet:::piecewise_linear(
      time_points, deltas, rate, offset, changepoint.ts
    )
    expect_equal(full_curve, expected)

    # Evaluating only the tail (t >= 7) must reproduce the same values.
    tail_idx <- 8:length(time_points)
    tail_curve <- prophet:::piecewise_linear(
      time_points[tail_idx], deltas, rate, offset, changepoint.ts
    )
    expect_equal(tail_curve, expected[tail_idx])
  })
  test_that("piecewise_logistic", {
    time_points <- seq(0, 10)
    capacity <- rep(10, 11)
    offset <- 0
    rate <- 1.0
    deltas <- c(0.5)
    changepoint.ts <- c(5)
    expected <- c(5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071,
                  9.984988, 9.996646, 9.999252, 9.999833, 9.999963)

    full_curve <- prophet:::piecewise_logistic(
      time_points, capacity, deltas, rate, offset, changepoint.ts
    )
    expect_equal(full_curve, expected, tolerance = 1e-6)

    # Evaluating only the tail (t >= 7), with the matching slice of the
    # capacity, must reproduce the same values.
    tail_idx <- 8:length(time_points)
    tail_curve <- prophet:::piecewise_logistic(
      time_points[tail_idx], capacity[tail_idx], deltas, rate, offset,
      changepoint.ts
    )
    expect_equal(tail_curve, expected[tail_idx], tolerance = 1e-6)
  })
  test_that("holidays", {
    # Window from one day before Christmas through Christmas itself.
    holidays <- data.frame(ds = c('2016-12-25'),
                           holiday = c('xmas'),
                           lower_window = c(-1),
                           upper_window = c(0))
    # Twelve consecutive days, 2016-12-20 through 2016-12-31.
    df <- data.frame(
      ds = seq(prophet:::set_date('2016-12-20'),
               prophet:::set_date('2016-12-31'), by = 'd'))
    m <- prophet(train, holidays = holidays, fit = FALSE)
    feats <- prophet:::make_holiday_features(m, df$ds)
    expect_equal(nrow(feats), nrow(df))
    expect_equal(ncol(feats), 2)
    # Each of the two window columns should sum to 1 over `df`. The check is
    # per column: the previous sum-of-differences form also passed for
    # wrong sums such as c(2, 0).
    expect_equal(unname(colSums(feats)), c(1, 1))

    # Widening the window to ten days after gives 12 columns (-1 .. +10).
    holidays <- data.frame(ds = c('2016-12-25'),
                           holiday = c('xmas'),
                           lower_window = c(-1),
                           upper_window = c(10))
    m <- prophet(train, holidays = holidays, fit = FALSE)
    feats <- prophet:::make_holiday_features(m, df$ds)
    expect_equal(nrow(feats), nrow(df))
    expect_equal(ncol(feats), 12)
  })
  test_that("fit_with_holidays", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # The same named holiday on two dates; each window covers the date and
    # the day after it.
    bdays <- data.frame(
      ds = c("2012-06-06", "2013-06-06"),
      holiday = c("seans-bday", "seans-bday"),
      lower_window = c(0, 0),
      upper_window = c(1, 1)
    )
    model <- prophet(DATA, holidays = bdays, uncertainty.samples = 0)
    expect_error(predict(model), NA)
  })
  test_that("make_future_dataframe", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    model <- prophet(DATA[1:234, ])

    # Three daily steps past the end of the training history.
    by_day <- make_future_dataframe(
      model, periods = 3, freq = "day", include_history = FALSE
    )
    expect_equal(
      by_day$ds,
      prophet:::set_date(c("2013-04-26", "2013-04-27", "2013-04-28"))
    )

    # Three monthly steps past the end, on the same day of the month.
    by_month <- make_future_dataframe(
      model, periods = 3, freq = "month", include_history = FALSE
    )
    expect_equal(
      by_month$ds,
      prophet:::set_date(c("2013-05-25", "2013-06-25", "2013-07-25"))
    )
  })
  test_that("auto_weekly_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # 15 rows: 'auto' should resolve to weekly seasonality being enabled.
    fifteen_rows <- DATA[1:15, ]
    model <- prophet(fifteen_rows, fit = FALSE)
    expect_equal(model$weekly.seasonality, "auto")
    model <- prophet:::fit.prophet(model, fifteen_rows)
    expect_true("weekly" %in% names(model$seasonalities))
    expect_equal(model$seasonalities[["weekly"]], c(7, 3))

    # 9 rows is too short a history, unless weekly is forced on.
    nine_rows <- DATA[1:9, ]
    model <- prophet(nine_rows)
    expect_false("weekly" %in% names(model$seasonalities))
    model <- prophet(nine_rows, weekly.seasonality = TRUE)
    expect_true("weekly" %in% names(model$seasonalities))

    # Every seventh row only: weekly-spaced data gets no weekly component.
    every_seventh <- DATA[seq(1, nrow(DATA), 7), ]
    model <- prophet(every_seventh)
    expect_false("weekly" %in% names(model$seasonalities))

    # A number in place of TRUE/FALSE sets the Fourier order directly.
    model <- prophet(DATA, weekly.seasonality = 2)
    expect_equal(model$seasonalities[["weekly"]], c(7, 2))
  })
  test_that("auto_yearly_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Full DATA: 'auto' should resolve to yearly seasonality being enabled.
    model <- prophet(DATA, fit = FALSE)
    expect_equal(model$yearly.seasonality, "auto")
    model <- prophet:::fit.prophet(model, DATA)
    expect_true("yearly" %in% names(model$seasonalities))
    expect_equal(model$seasonalities[["yearly"]], c(365.25, 10))

    # The first 240 rows are too short a history, unless yearly is forced on.
    first_240 <- DATA[1:240, ]
    model <- prophet(first_240)
    expect_false("yearly" %in% names(model$seasonalities))
    model <- prophet(first_240, yearly.seasonality = TRUE)
    expect_true("yearly" %in% names(model$seasonalities))

    # A number in place of TRUE/FALSE sets the Fourier order directly.
    model <- prophet(DATA, yearly.seasonality = 7)
    expect_equal(model$seasonalities[["yearly"]], c(365.25, 7))
  })
  test_that("auto_daily_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Full DATA2: 'auto' should resolve to daily seasonality being enabled.
    model <- prophet(DATA2, fit = FALSE)
    expect_equal(model$daily.seasonality, "auto")
    model <- prophet:::fit.prophet(model, DATA2)
    expect_true("daily" %in% names(model$seasonalities))
    expect_equal(model$seasonalities[["daily"]], c(1, 4))

    # The first 430 rows of DATA2 are too short, unless daily is forced on.
    first_430 <- DATA2[1:430, ]
    model <- prophet(first_430)
    expect_false("daily" %in% names(model$seasonalities))
    model <- prophet(first_430, daily.seasonality = TRUE)
    expect_true("daily" %in% names(model$seasonalities))

    # A number in place of TRUE/FALSE sets the Fourier order directly.
    model <- prophet(DATA2, daily.seasonality = 7)
    expect_equal(model$seasonalities[["daily"]], c(1, 7))

    # DATA is not expected to get an automatic daily component.
    model <- prophet(DATA)
    expect_false("daily" %in% names(model$seasonalities))
  })
  test_that("test_subdaily_holidays", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Holiday given without lower/upper window columns.
    special <- data.frame(ds = c("2017-01-02"), holiday = c("special_day"))
    model <- prophet(DATA2, holidays = special)
    forecast <- predict(model)
    # Reference count of forecast rows where the holiday effect is zero.
    expect_equal(sum(forecast$special_day == 0), 575)
  })
  test_that("custom_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    special <- data.frame(ds = c("2017-01-02"), holiday = c("special_day"))
    # Model constructed without data; the custom seasonality is registered
    # by name.
    model <- prophet(holidays = special)
    model <- add_seasonality(
      model, name = "monthly", period = 30, fourier.order = 5
    )
    # Stored as c(period, fourier.order).
    expect_equal(model$seasonalities[["monthly"]], c(30, 5))
  })