test_prophet.R 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685
  1. library(prophet)
  2. context("Prophet tests")
  3. DATA <- read.csv('data.csv')
  4. N <- nrow(DATA)
  5. train <- DATA[1:floor(N / 2), ]
  6. future <- DATA[(ceiling(N/2) + 1):N, ]
  7. DATA2 <- read.csv('data2.csv')
  8. DATA$ds <- prophet:::set_date(DATA$ds)
  9. DATA2$ds <- prophet:::set_date(DATA2$ds)
  10. test_that("fit_predict", {
  11. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  12. m <- prophet(train)
  13. expect_error(predict(m, future), NA)
  14. })
  15. test_that("fit_predict_no_seasons", {
  16. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  17. m <- prophet(train, weekly.seasonality = FALSE, yearly.seasonality = FALSE)
  18. expect_error(predict(m, future), NA)
  19. })
  20. test_that("fit_predict_no_changepoints", {
  21. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  22. expect_warning({
  23. # warning from prophet(), error from predict()
  24. m <- prophet(train, n.changepoints = 0)
  25. expect_error(predict(m, future), NA)
  26. })
  27. })
  28. test_that("fit_predict_changepoint_not_in_history", {
  29. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  30. train_t <- dplyr::mutate(DATA, ds=prophet:::set_date(ds))
  31. train_t <- dplyr::filter(train_t,
  32. (ds < prophet:::set_date('2013-01-01')) |
  33. (ds > prophet:::set_date('2014-01-01')))
  34. future <- data.frame(ds=DATA$ds)
  35. expect_warning({
  36. # warning from prophet(), error from predict()
  37. m <- prophet(train_t, changepoints=c('2013-06-06'))
  38. expect_error(predict(m, future), NA)
  39. })
  40. })
  41. test_that("fit_predict_duplicates", {
  42. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  43. train2 <- train
  44. train2$y <- train2$y + 10
  45. train_t <- rbind(train, train2)
  46. m <- prophet(train_t)
  47. expect_error(predict(m, future), NA)
  48. })
  49. test_that("fit_predict_constant_history", {
  50. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  51. train2 <- train
  52. train2$y <- 20
  53. m <- prophet(train2)
  54. fcst <- predict(m, future)
  55. expect_equal(tail(fcst$yhat, 1), 20)
  56. train2$y <- 0
  57. m <- prophet(train2)
  58. fcst <- predict(m, future)
  59. expect_equal(tail(fcst$yhat, 1), 0)
  60. })
  61. test_that("setup_dataframe", {
  62. history <- train
  63. m <- prophet(history, fit = FALSE)
  64. out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
  65. history <- out$df
  66. expect_true('t' %in% colnames(history))
  67. expect_equal(min(history$t), 0)
  68. expect_equal(max(history$t), 1)
  69. expect_true('y_scaled' %in% colnames(history))
  70. expect_equal(max(history$y_scaled), 1)
  71. })
  72. test_that("logistic_floor", {
  73. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  74. skip_on_os('mac') # Resolves mysterious CRAN build issue
  75. m <- prophet(growth = 'logistic')
  76. history <- train
  77. history$floor <- 10.
  78. history$cap <- 80.
  79. future1 <- future
  80. future1$cap <- 80.
  81. future1$floor <- 10.
  82. m <- fit.prophet(m, history, algorithm = 'Newton')
  83. expect_true(m$logistic.floor)
  84. expect_true('floor' %in% colnames(m$history))
  85. expect_equal(m$history$y_scaled[1], 1., tolerance = 1e-6)
  86. fcst1 <- predict(m, future1)
  87. m2 <- prophet(growth = 'logistic')
  88. history2 <- history
  89. history2$y <- history2$y + 10.
  90. history2$floor <- history2$floor + 10.
  91. history2$cap <- history2$cap + 10.
  92. future1$cap <- future1$cap + 10.
  93. future1$floor <- future1$floor + 10.
  94. m2 <- fit.prophet(m2, history2, algorithm = 'Newton')
  95. expect_equal(m2$history$y_scaled[1], 1., tolerance = 1e-6)
  96. fcst2 <- predict(m, future1)
  97. fcst2$yhat <- fcst2$yhat - 10.
  98. # Check for approximate shift invariance
  99. expect_true(all(abs(fcst1$yhat - fcst2$yhat) < 1))
  100. })
  101. test_that("get_changepoints", {
  102. history <- train
  103. m <- prophet(history, fit = FALSE)
  104. out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
  105. history <- out$df
  106. m <- out$m
  107. m$history <- history
  108. m <- prophet:::set_changepoints(m)
  109. cp <- m$changepoints.t
  110. expect_equal(length(cp), m$n.changepoints)
  111. expect_true(min(cp) > 0)
  112. expect_true(max(cp) <= history$t[ceiling(0.8 * length(history$t))])
  113. })
  114. test_that("set_changepoint_range", {
  115. history <- train
  116. m <- prophet(history, fit = FALSE, changepoint.range = 0.4)
  117. out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
  118. history <- out$df
  119. m <- out$m
  120. m$history <- history
  121. m <- prophet:::set_changepoints(m)
  122. cp <- m$changepoints.t
  123. expect_equal(length(cp), m$n.changepoints)
  124. expect_true(min(cp) > 0)
  125. expect_true(max(cp) <= history$t[ceiling(0.4 * length(history$t))])
  126. expect_error(prophet(history, changepoint.range = -0.1))
  127. expect_error(prophet(history, changepoint.range = 2))
  128. })
  129. test_that("get_zero_changepoints", {
  130. history <- train
  131. m <- prophet(history, n.changepoints = 0, fit = FALSE)
  132. out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
  133. m <- out$m
  134. history <- out$df
  135. m$history <- history
  136. m <- prophet:::set_changepoints(m)
  137. cp <- m$changepoints.t
  138. expect_equal(length(cp), 1)
  139. expect_equal(cp[1], 0)
  140. })
  141. test_that("override_n_changepoints", {
  142. history <- train[1:20,]
  143. m <- prophet(history, fit = FALSE)
  144. out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
  145. m <- out$m
  146. history <- out$df
  147. m$history <- history
  148. m <- prophet:::set_changepoints(m)
  149. expect_equal(m$n.changepoints, 15)
  150. cp <- m$changepoints.t
  151. expect_equal(length(cp), 15)
  152. })
  153. test_that("fourier_series_weekly", {
  154. true.values <- c(0.7818315, 0.6234898, 0.9749279, -0.2225209, 0.4338837,
  155. -0.9009689)
  156. mat <- prophet:::fourier_series(DATA$ds, 7, 3)
  157. expect_equal(true.values, mat[1, ], tolerance = 1e-6)
  158. })
  159. test_that("fourier_series_yearly", {
  160. true.values <- c(0.7006152, -0.7135393, -0.9998330, 0.01827656, 0.7262249,
  161. 0.6874572)
  162. mat <- prophet:::fourier_series(DATA$ds, 365.25, 3)
  163. expect_equal(true.values, mat[1, ], tolerance = 1e-6)
  164. })
  165. test_that("growth_init", {
  166. history <- DATA[1:468, ]
  167. history$cap <- max(history$y)
  168. m <- prophet(history, growth = 'logistic', fit = FALSE)
  169. out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
  170. m <- out$m
  171. history <- out$df
  172. params <- prophet:::linear_growth_init(history)
  173. expect_equal(params[1], 0.3055671, tolerance = 1e-6)
  174. expect_equal(params[2], 0.5307511, tolerance = 1e-6)
  175. params <- prophet:::logistic_growth_init(history)
  176. expect_equal(params[1], 1.507925, tolerance = 1e-6)
  177. expect_equal(params[2], -0.08167497, tolerance = 1e-6)
  178. })
  179. test_that("piecewise_linear", {
  180. t <- seq(0, 10)
  181. m <- 0
  182. k <- 1.0
  183. deltas <- c(0.5)
  184. changepoint.ts <- c(5)
  185. y <- prophet:::piecewise_linear(t, deltas, k, m, changepoint.ts)
  186. y.true <- c(0, 1, 2, 3, 4, 5, 6.5, 8, 9.5, 11, 12.5)
  187. expect_equal(y, y.true)
  188. t <- t[8:length(t)]
  189. y.true <- y.true[8:length(y.true)]
  190. y <- prophet:::piecewise_linear(t, deltas, k, m, changepoint.ts)
  191. expect_equal(y, y.true)
  192. })
  193. test_that("piecewise_logistic", {
  194. t <- seq(0, 10)
  195. cap <- rep(10, 11)
  196. m <- 0
  197. k <- 1.0
  198. deltas <- c(0.5)
  199. changepoint.ts <- c(5)
  200. y <- prophet:::piecewise_logistic(t, cap, deltas, k, m, changepoint.ts)
  201. y.true <- c(5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071,
  202. 9.984988, 9.996646, 9.999252, 9.999833, 9.999963)
  203. expect_equal(y, y.true, tolerance = 1e-6)
  204. t <- t[8:length(t)]
  205. y.true <- y.true[8:length(y.true)]
  206. cap <- cap[8:length(cap)]
  207. y <- prophet:::piecewise_logistic(t, cap, deltas, k, m, changepoint.ts)
  208. expect_equal(y, y.true, tolerance = 1e-6)
  209. })
  210. test_that("holidays", {
  211. holidays <- data.frame(ds = c('2016-12-25'),
  212. holiday = c('xmas'),
  213. lower_window = c(-1),
  214. upper_window = c(0))
  215. df <- data.frame(
  216. ds = seq(prophet:::set_date('2016-12-20'),
  217. prophet:::set_date('2016-12-31'), by='d'))
  218. m <- prophet(train, holidays = holidays, fit = FALSE)
  219. out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
  220. feats <- out$holiday.features
  221. priors <- out$prior.scales
  222. names <- out$holiday.names
  223. expect_equal(nrow(feats), nrow(df))
  224. expect_equal(ncol(feats), 2)
  225. expect_equal(sum(colSums(feats) - c(1, 1)), 0)
  226. expect_true(all(priors == c(10., 10.)))
  227. expect_equal(names, c('xmas'))
  228. holidays <- data.frame(ds = c('2016-12-25'),
  229. holiday = c('xmas'),
  230. lower_window = c(-1),
  231. upper_window = c(10))
  232. m <- prophet(train, holidays = holidays, fit = FALSE)
  233. out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
  234. feats <- out$holiday.features
  235. priors <- out$prior.scales
  236. names <- out$holiday.names
  237. expect_equal(nrow(feats), nrow(df))
  238. expect_equal(ncol(feats), 12)
  239. expect_true(all(priors == rep(10, 12)))
  240. expect_equal(names, c('xmas'))
  241. # Check prior specifications
  242. holidays <- data.frame(
  243. ds = prophet:::set_date(c('2016-12-25', '2017-12-25')),
  244. holiday = c('xmas', 'xmas'),
  245. lower_window = c(-1, -1),
  246. upper_window = c(0, 0),
  247. prior_scale = c(5., 5.)
  248. )
  249. m <- prophet(holidays = holidays, fit = FALSE)
  250. out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
  251. priors <- out$prior.scales
  252. names <- out$holiday.names
  253. expect_true(all(priors == c(5., 5.)))
  254. expect_equal(names, c('xmas'))
  255. # 2 different priors
  256. holidays2 <- data.frame(
  257. ds = prophet:::set_date(c('2012-06-06', '2013-06-06')),
  258. holiday = c('seans-bday', 'seans-bday'),
  259. lower_window = c(0, 0),
  260. upper_window = c(1, 1),
  261. prior_scale = c(8, 8)
  262. )
  263. holidays2 <- rbind(holidays, holidays2)
  264. m <- prophet(holidays = holidays2, fit = FALSE)
  265. out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
  266. priors <- out$prior.scales
  267. names <- out$holiday.names
  268. expect_true(all(priors == c(8, 8, 5, 5)))
  269. expect_true(all(sort(names) == c('seans-bday', 'xmas')))
  270. holidays2 <- data.frame(
  271. ds = prophet:::set_date(c('2012-06-06', '2013-06-06')),
  272. holiday = c('seans-bday', 'seans-bday'),
  273. lower_window = c(0, 0),
  274. upper_window = c(1, 1)
  275. )
  276. # manual coercions to avoid below bind_rows() warning
  277. holidays$holiday <- as.character(holidays$holiday)
  278. holidays2$holiday <- as.character(holidays2$holiday)
  279. holidays2 <- dplyr::bind_rows(holidays, holidays2)
  280. # manual factorizing to avoid above bind_rows() warning
  281. holidays2$holiday <- factor(holidays2$holiday)
  282. m <- prophet(holidays = holidays2, fit = FALSE, holidays.prior.scale = 4)
  283. out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
  284. priors <- out$prior.scales
  285. expect_true(all(priors == c(4, 4, 5, 5)))
  286. # Check incompatible priors
  287. holidays <- data.frame(
  288. ds = prophet:::set_date(c('2016-12-25', '2016-12-27')),
  289. holiday = c('xmasish', 'xmasish'),
  290. lower_window = c(-1, -1),
  291. upper_window = c(0, 0),
  292. prior_scale = c(5., 6.)
  293. )
  294. m <- prophet(holidays = holidays, fit = FALSE)
  295. expect_error(prophet:::make_holiday_features(m, df$ds, m$holidays))
  296. })
  297. test_that("fit_with_holidays", {
  298. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  299. holidays <- data.frame(ds = c('2012-06-06', '2013-06-06'),
  300. holiday = c('seans-bday', 'seans-bday'),
  301. lower_window = c(0, 0),
  302. upper_window = c(1, 1))
  303. m <- prophet(DATA, holidays = holidays, uncertainty.samples = 0)
  304. expect_error(predict(m), NA)
  305. })
  306. test_that("fit_with_country_holidays", {
  307. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  308. holidays <- data.frame(ds = c('2012-06-06', '2013-06-06'),
  309. holiday = c('seans-bday', 'seans-bday'),
  310. lower_window = c(0, 0),
  311. upper_window = c(1, 1))
  312. # Test with holidays and append_holidays
  313. m <- prophet(holidays = holidays, uncertainty.samples = 0)
  314. m <- add_country_holidays(m, 'US')
  315. m <- fit.prophet(m, DATA)
  316. expect_error(predict(m), NA)
  317. # There are training holidays missing in the test set
  318. train2 <- DATA %>% head(155)
  319. future2 <- DATA %>% tail(355)
  320. m <- prophet(uncertainty.samples = 0)
  321. m <- add_country_holidays(m, 'US')
  322. m <- fit.prophet(m, train2)
  323. expect_error(predict(m, future2), NA)
  324. # There are test holidays missing in the training set
  325. train2 <- DATA %>% tail(355)
  326. future2 <- DATA2
  327. m <- prophet(uncertainty.samples = 0)
  328. m <- add_country_holidays(m, 'US')
  329. m <- fit.prophet(m, train2)
  330. expect_error(predict(m, future2), NA)
  331. # Append_holidays with non-existing year
  332. max.year <- generated_holidays %>%
  333. dplyr::filter(country=='US') %>%
  334. dplyr::select(year) %>%
  335. max()
  336. train2 <- data.frame('ds'=c(paste(max.year+1, "-01-01", sep=''),
  337. paste(max.year+1, "-01-02", sep='')),
  338. 'y'=1)
  339. m <- prophet()
  340. m <- add_country_holidays(m, 'US')
  341. expect_warning(m <- fit.prophet(m, train2))
  342. # Append_holidays with non-existing country
  343. m <- prophet()
  344. expect_error(add_country_holidays(m, 'Utopia'))
  345. })
  346. test_that("make_future_dataframe", {
  347. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  348. train.t <- DATA[1:234, ]
  349. m <- prophet(train.t)
  350. future <- make_future_dataframe(m, periods = 3, freq = 'day',
  351. include_history = FALSE)
  352. correct <- prophet:::set_date(c('2013-04-26', '2013-04-27', '2013-04-28'))
  353. expect_equal(future$ds, correct)
  354. future <- make_future_dataframe(m, periods = 3, freq = 'month',
  355. include_history = FALSE)
  356. correct <- prophet:::set_date(c('2013-05-25', '2013-06-25', '2013-07-25'))
  357. expect_equal(future$ds, correct)
  358. })
  359. test_that("auto_weekly_seasonality", {
  360. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  361. # Should be enabled
  362. N.w <- 15
  363. train.w <- DATA[1:N.w, ]
  364. m <- prophet(train.w, fit = FALSE)
  365. expect_equal(m$weekly.seasonality, 'auto')
  366. m <- fit.prophet(m, train.w)
  367. expect_true('weekly' %in% names(m$seasonalities))
  368. true <- list(
  369. period = 7, fourier.order = 3, prior.scale = 10, mode = 'additive')
  370. for (name in names(true)) {
  371. expect_equal(m$seasonalities$weekly[[name]], true[[name]])
  372. }
  373. # Should be disabled due to too short history
  374. N.w <- 9
  375. train.w <- DATA[1:N.w, ]
  376. m <- prophet(train.w)
  377. expect_false('weekly' %in% names(m$seasonalities))
  378. expect_warning({
  379. # prophet warning: non-zero return code in optimizing
  380. m <- prophet(train.w, weekly.seasonality = TRUE)
  381. expect_true('weekly' %in% names(m$seasonalities))
  382. })
  383. # Should be False due to weekly spacing
  384. train.w <- DATA[seq(1, nrow(DATA), 7), ]
  385. m <- prophet(train.w)
  386. expect_false('weekly' %in% names(m$seasonalities))
  387. m <- prophet(DATA, weekly.seasonality = 2, seasonality.prior.scale = 3)
  388. true <- list(
  389. period = 7, fourier.order = 2, prior.scale = 3, mode = 'additive')
  390. for (name in names(true)) {
  391. expect_equal(m$seasonalities$weekly[[name]], true[[name]])
  392. }
  393. })
  394. test_that("auto_yearly_seasonality", {
  395. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  396. # Should be enabled
  397. m <- prophet(DATA, fit = FALSE)
  398. expect_equal(m$yearly.seasonality, 'auto')
  399. m <- fit.prophet(m, DATA)
  400. expect_true('yearly' %in% names(m$seasonalities))
  401. true <- list(
  402. period = 365.25, fourier.order = 10, prior.scale = 10, mode = 'additive')
  403. for (name in names(true)) {
  404. expect_equal(m$seasonalities$yearly[[name]], true[[name]])
  405. }
  406. # Should be disabled due to too short history
  407. N.w <- 240
  408. train.y <- DATA[1:N.w, ]
  409. m <- prophet(train.y)
  410. expect_false('yearly' %in% names(m$seasonalities))
  411. m <- prophet(train.y, yearly.seasonality = TRUE)
  412. expect_true('yearly' %in% names(m$seasonalities))
  413. m <- prophet(DATA, yearly.seasonality = 7, seasonality.prior.scale = 3)
  414. true <- list(
  415. period = 365.25, fourier.order = 7, prior.scale = 3, mode = 'additive')
  416. for (name in names(true)) {
  417. expect_equal(m$seasonalities$yearly[[name]], true[[name]])
  418. }
  419. })
  420. test_that("auto_daily_seasonality", {
  421. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  422. # Should be enabled
  423. m <- prophet(DATA2, fit = FALSE)
  424. expect_equal(m$daily.seasonality, 'auto')
  425. m <- fit.prophet(m, DATA2)
  426. expect_true('daily' %in% names(m$seasonalities))
  427. true <- list(
  428. period = 1, fourier.order = 4, prior.scale = 10, mode = 'additive')
  429. for (name in names(true)) {
  430. expect_equal(m$seasonalities$daily[[name]], true[[name]])
  431. }
  432. # Should be disabled due to too short history
  433. N.d <- 430
  434. train.y <- DATA2[1:N.d, ]
  435. m <- prophet(train.y)
  436. expect_false('daily' %in% names(m$seasonalities))
  437. m <- prophet(train.y, daily.seasonality = TRUE)
  438. expect_true('daily' %in% names(m$seasonalities))
  439. m <- prophet(DATA2, daily.seasonality = 7, seasonality.prior.scale = 3)
  440. true <- list(
  441. period = 1, fourier.order = 7, prior.scale = 3, mode = 'additive')
  442. for (name in names(true)) {
  443. expect_equal(m$seasonalities$daily[[name]], true[[name]])
  444. }
  445. m <- prophet(DATA)
  446. expect_false('daily' %in% names(m$seasonalities))
  447. })
  448. test_that("test_subdaily_holidays", {
  449. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  450. holidays <- data.frame(ds = c('2017-01-02'),
  451. holiday = c('special_day'))
  452. m <- prophet(DATA2, holidays=holidays)
  453. fcst <- predict(m)
  454. expect_equal(sum(fcst$special_day == 0), 575)
  455. })
  456. test_that("custom_seasonality", {
  457. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  458. holidays <- data.frame(ds = c('2017-01-02'),
  459. holiday = c('special_day'),
  460. prior_scale = c(4))
  461. m <- prophet(holidays=holidays)
  462. m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
  463. true <- list(
  464. period = 30, fourier.order = 5, prior.scale = 10, mode = 'additive')
  465. for (name in names(true)) {
  466. expect_equal(m$seasonalities$monthly[[name]], true[[name]])
  467. }
  468. expect_error(
  469. add_seasonality(m, name='special_day', period=30, fourier_order=5)
  470. )
  471. expect_error(
  472. add_seasonality(m, name='trend', period=30, fourier_order=5)
  473. )
  474. m <- add_seasonality(m, name='weekly', period=30, fourier.order=5)
  475. # Test priors
  476. m <- prophet(
  477. holidays = holidays, yearly.seasonality = FALSE,
  478. seasonality.mode = 'multiplicative')
  479. m <- add_seasonality(
  480. m, name='monthly', period=30, fourier.order=5, prior.scale = 2,
  481. mode = 'additive')
  482. m <- fit.prophet(m, DATA)
  483. expect_equal(m$seasonalities$monthly$mode, 'additive')
  484. expect_equal(m$seasonalities$weekly$mode, 'multiplicative')
  485. out <- prophet:::make_all_seasonality_features(m, m$history)
  486. prior.scales <- out$prior.scales
  487. component.cols <- out$component.cols
  488. expect_equal(sum(component.cols$monthly), 10)
  489. expect_equal(sum(component.cols$special_day), 1)
  490. expect_equal(sum(component.cols$weekly), 6)
  491. expect_equal(sum(component.cols$additive_terms), 10)
  492. expect_equal(sum(component.cols$multiplicative_terms), 7)
  493. expect_equal(sum(component.cols$monthly[1:11]), 10)
  494. expect_equal(sum(component.cols$weekly[11:17]), 6)
  495. expect_true(all(prior.scales == c(rep(2, 10), rep(10, 6), 4)))
  496. })
  497. test_that("added_regressors", {
  498. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  499. m <- prophet()
  500. m <- add_regressor(m, 'binary_feature', prior.scale=0.2)
  501. m <- add_regressor(m, 'numeric_feature', prior.scale=0.5)
  502. m <- add_regressor(
  503. m, 'numeric_feature2', prior.scale=0.5, mode = 'multiplicative')
  504. m <- add_regressor(m, 'binary_feature2', standardize=TRUE)
  505. df <- DATA
  506. df$binary_feature <- c(rep(0, 255), rep(1, 255))
  507. df$numeric_feature <- 0:509
  508. df$numeric_feature2 <- 0:509
  509. # Require all regressors in df
  510. expect_error(
  511. fit.prophet(m, df)
  512. )
  513. df$binary_feature2 <- c(rep(1, 100), rep(0, 410))
  514. m <- fit.prophet(m, df)
  515. # Check that standardizations are correctly set
  516. true <- list(
  517. prior.scale = 0.2, mu = 0, std = 1, standardize = 'auto', mode = 'additive'
  518. )
  519. for (name in names(true)) {
  520. expect_equal(true[[name]], m$extra_regressors$binary_feature[[name]])
  521. }
  522. true <- list(prior.scale = 0.5, mu = 254.5, std = 147.368585)
  523. for (name in names(true)) {
  524. expect_equal(true[[name]], m$extra_regressors$numeric_feature[[name]],
  525. tolerance = 1e-5)
  526. }
  527. expect_equal(m$extra_regressors$numeric_feature2$mode, 'multiplicative')
  528. true <- list(prior.scale = 10., mu = 0.1960784, std = 0.3974183)
  529. for (name in names(true)) {
  530. expect_equal(true[[name]], m$extra_regressors$binary_feature2[[name]],
  531. tolerance = 1e-5)
  532. }
  533. # Check that standardization is done correctly
  534. df2 <- prophet:::setup_dataframe(m, df)$df
  535. expect_equal(df2$binary_feature[1], 0)
  536. expect_equal(df2$numeric_feature[1], -1.726962, tolerance = 1e-4)
  537. expect_equal(df2$binary_feature2[1], 2.022859, tolerance = 1e-4)
  538. # Check that feature matrix and prior scales are correctly constructed
  539. out <- prophet:::make_all_seasonality_features(m, df2)
  540. seasonal.features <- out$seasonal.features
  541. prior.scales <- out$prior.scales
  542. component.cols <- out$component.cols
  543. modes <- out$modes
  544. expect_equal(ncol(seasonal.features), 30)
  545. r_names <- c('binary_feature', 'numeric_feature', 'binary_feature2')
  546. true.priors <- c(0.2, 0.5, 10.)
  547. for (i in seq_along(r_names)) {
  548. name <- r_names[i]
  549. expect_true(name %in% colnames(seasonal.features))
  550. expect_equal(sum(component.cols[[name]]), 1)
  551. expect_equal(sum(prior.scales * component.cols[[name]]), true.priors[i])
  552. }
  553. # Check that forecast components are reasonable
  554. future <- data.frame(
  555. ds = c('2014-06-01'),
  556. binary_feature = c(0),
  557. numeric_feature = c(10),
  558. numeric_feature2 = c(10)
  559. )
  560. expect_error(predict(m, future))
  561. future$binary_feature2 <- 0.
  562. fcst <- predict(m, future)
  563. expect_equal(ncol(fcst), 37)
  564. expect_equal(fcst$binary_feature[1], 0)
  565. expect_equal(fcst$extra_regressors_additive[1],
  566. fcst$numeric_feature[1] + fcst$binary_feature2[1])
  567. expect_equal(fcst$extra_regressors_multiplicative[1],
  568. fcst$numeric_feature2[1])
  569. expect_equal(fcst$additive_terms[1],
  570. fcst$yearly[1] + fcst$weekly[1]
  571. + fcst$extra_regressors_additive[1])
  572. expect_equal(fcst$multiplicative_terms[1],
  573. fcst$extra_regressors_multiplicative[1])
  574. expect_equal(
  575. fcst$yhat[1],
  576. fcst$trend[1] * (1 + fcst$multiplicative_terms[1]) + fcst$additive_terms[1]
  577. )
  578. # Check works with constant extra regressor of 0
  579. df$constant_feature <- 0
  580. m <- prophet()
  581. m <- add_regressor(m, 'constant_feature', standardize = TRUE)
  582. m <- fit.prophet(m, df)
  583. expect_equal(m$extra_regressors$constant_feature$std, 1)
  584. })
  585. test_that("set_seasonality_mode", {
  586. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  587. m <- prophet()
  588. expect_equal(m$seasonality.mode, 'additive')
  589. m <- prophet(seasonality.mode = 'multiplicative')
  590. expect_equal(m$seasonality.mode, 'multiplicative')
  591. expect_error(prophet(seasonality.mode = 'batman'))
  592. })
  593. test_that("seasonality_modes", {
  594. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  595. holidays <- data.frame(ds = c('2016-12-25'),
  596. holiday = c('xmas'),
  597. lower_window = c(-1),
  598. upper_window = c(0))
  599. m <- prophet(seasonality.mode = 'multiplicative', holidays = holidays)
  600. m <- add_seasonality(
  601. m, name = 'monthly', period = 30, fourier.order = 3, mode = 'additive')
  602. m <- add_regressor(m, name = 'binary_feature', mode = 'additive')
  603. m <- add_regressor(m, name = 'numeric_feature')
  604. # Construct seasonal features
  605. df <- DATA
  606. df$binary_feature <- c(rep(0, 255), rep(1, 255))
  607. df$numeric_feature <- 0:509
  608. out <- prophet:::setup_dataframe(m, df, initialize_scales = TRUE)
  609. df <- out$df
  610. m <- out$m
  611. m$history <- df
  612. m <- prophet:::set_auto_seasonalities(m)
  613. out <- prophet:::make_all_seasonality_features(m, df)
  614. component.cols <- out$component.cols
  615. modes <- out$modes
  616. expect_equal(sum(component.cols$additive_terms), 7)
  617. expect_equal(sum(component.cols$multiplicative_terms), 29)
  618. expect_equal(
  619. sort(modes$additive),
  620. c('additive_terms', 'binary_feature', 'extra_regressors_additive',
  621. 'monthly')
  622. )
  623. expect_equal(
  624. sort(modes$multiplicative),
  625. c('extra_regressors_multiplicative', 'holidays', 'multiplicative_terms',
  626. 'numeric_feature', 'weekly', 'xmas', 'yearly')
  627. )
  628. })