test_prophet.R 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  # Shared fixtures for the Prophet tests. Both CSV files are read relative to
  # the working directory testthat runs this file from.
  library(prophet)
  context("Prophet tests")

  # Main series: the first half is used for training, and `future` holds the
  # rows immediately after it.
  DATA <- read.csv("data.csv")
  N <- nrow(DATA)
  n_train <- floor(N / 2)
  train <- DATA[seq_len(n_train), ]
  # Start right after the last training row. The previous
  # `ceiling(N / 2) + 1` skipped one row whenever N was odd (identical for
  # even N).
  future <- DATA[(n_train + 1):N, ]
  # Second series, used by the daily-seasonality and sub-daily holiday tests.
  DATA2 <- read.csv("data2.csv")
  test_that("fit_predict", {
    # Fitting is skipped on the /i386 architecture.
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Fit on the training half; predicting the following rows must not error.
    fitted_model <- prophet(train)
    expect_error(predict(fitted_model, future), NA)
  })
  test_that("fit_predict_no_seasons", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Weekly and yearly seasonality are both switched off explicitly.
    no_season_model <- prophet(
      train,
      weekly.seasonality = FALSE,
      yearly.seasonality = FALSE
    )
    expect_error(predict(no_season_model, future), NA)
  })
  test_that("fit_predict_no_changepoints", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # A model with zero changepoints must still fit and predict.
    flat_trend_model <- prophet(train, n.changepoints = 0)
    expect_error(predict(flat_trend_model, future), NA)
  })
  test_that("fit_predict_changepoint_not_in_history", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Remove every row from 2013-01-01 through 2014-01-01 (inclusive) so the
    # user-specified changepoint falls inside a gap in the history.
    gap_start <- prophet:::set_date("2013-01-01")
    gap_end <- prophet:::set_date("2014-01-01")
    with_dates <- dplyr::mutate(DATA, ds = prophet:::set_date(ds))
    train_gap <- dplyr::filter(with_dates, (ds < gap_start) | (ds > gap_end))
    # Predict over every original date, including the removed ones.
    all_dates <- data.frame(ds = DATA$ds)
    gap_model <- prophet(train_gap, changepoints = c("2013-06-06"))
    expect_error(predict(gap_model, all_dates), NA)
  })
  test_that("fit_predict_duplicates", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Each ds appears twice: once with the original y, once shifted up by 10.
    shifted <- transform(train, y = y + 10)
    doubled <- rbind(train, shifted)
    dup_model <- prophet(doubled)
    expect_error(predict(dup_model, future), NA)
  })
  test_that("fit_predict_constant_history", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # For a constant history (20, then the all-zero case), the last forecast
    # value must equal that constant.
    for (level in c(20, 0)) {
      flat <- train
      flat$y <- level
      flat_model <- prophet(flat)
      fcst <- predict(flat_model, future)
      expect_equal(tail(fcst$yhat, 1), level)
    }
  })
  test_that("setup_dataframe", {
    unfitted <- prophet(train, fit = FALSE)
    prepared <- prophet:::setup_dataframe(
      unfitted, train, initialize_scales = TRUE
    )$df
    # Time must be rescaled to span exactly [0, 1].
    expect_true("t" %in% colnames(prepared))
    expect_equal(min(prepared$t), 0)
    expect_equal(max(prepared$t), 1)
    # The scaled response must peak at 1.
    expect_true("y_scaled" %in% colnames(prepared))
    expect_equal(max(prepared$y_scaled), 1)
  })
  test_that("logistic_floor", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # First model: logistic growth with a constant floor (10) and cap (80).
    m <- prophet(growth = "logistic")
    history <- train
    history$floor <- 10.
    history$cap <- 80.
    future1 <- future
    future1$cap <- 80.
    future1$floor <- 10.
    m <- fit.prophet(m, history)
    expect_true(m$logistic.floor)
    expect_true("floor" %in% colnames(m$history))
    expect_equal(m$history$y_scaled[1], 1., tolerance = 1e-6)
    fcst1 <- predict(m, future1)

    # Second model: y, floor and cap all shifted up by 10.
    m2 <- prophet(growth = "logistic")
    history2 <- history
    history2$y <- history2$y + 10.
    history2$floor <- history2$floor + 10.
    history2$cap <- history2$cap + 10.
    future2 <- future1
    future2$cap <- future2$cap + 10.
    future2$floor <- future2$floor + 10.
    m2 <- fit.prophet(m2, history2)
    expect_equal(m2$history$y_scaled[1], 1., tolerance = 1e-6)
    # Forecast with the shifted model m2. The previous predict(m, ...) call
    # never used m2, so the invariance check below did not test it.
    fcst2 <- predict(m2, future2)
    fcst2$yhat <- fcst2$yhat - 10.
    # Check for approximate shift invariance
    expect_true(all(abs(fcst1$yhat - fcst2$yhat) < 1))
  })
  test_that("get_changepoints", {
    history <- train
    m <- prophet(history, fit = FALSE)
    out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    m$history <- history
    m <- prophet:::set_changepoints(m)
    cp <- m$changepoints.t
    expect_equal(length(cp), m$n.changepoints)
    # changepoints.t is expected to be on the same rescaled axis as history$t,
    # which spans [0, 1] (see the setup_dataframe test). Bound it by that axis:
    # the old `max(cp) < N` compared against the row count and could not fail.
    expect_true(min(cp) > 0)
    expect_true(max(cp) < max(history$t))
    mat <- prophet:::get_changepoint_matrix(m)
    expect_equal(nrow(mat), floor(N / 2))
    expect_equal(ncol(mat), m$n.changepoints)
  })
  test_that("get_zero_changepoints", {
    model <- prophet(train, n.changepoints = 0, fit = FALSE)
    prepared <- prophet:::setup_dataframe(model, train, initialize_scales = TRUE)
    model <- prepared$m
    model$history <- prepared$df
    model <- prophet:::set_changepoints(model)
    # With zero changepoints requested, a single changepoint at t = 0 is
    # expected.
    cp_times <- model$changepoints.t
    expect_equal(length(cp_times), 1)
    expect_equal(cp_times[1], 0)
    cp_matrix <- prophet:::get_changepoint_matrix(model)
    expect_equal(nrow(cp_matrix), floor(N / 2))
    expect_equal(ncol(cp_matrix), 1)
  })
  test_that("override_n_changepoints", {
    # With only 20 rows, set_changepoints is expected to lower n.changepoints
    # to 15.
    short_history <- train[1:20, ]
    model <- prophet(short_history, fit = FALSE)
    prepared <- prophet:::setup_dataframe(
      model, short_history, initialize_scales = TRUE
    )
    model <- prepared$m
    model$history <- prepared$df
    model <- prophet:::set_changepoints(model)
    expect_equal(model$n.changepoints, 15)
    expect_equal(length(model$changepoints.t), 15)
  })
  test_that("fourier_series_weekly", {
    # Order-3 Fourier features for a 7-day period (six columns), checked on
    # the first date of DATA.
    features <- prophet:::fourier_series(DATA$ds, 7, 3)
    expected <- c(
      0.9165623, 0.3998920, 0.7330519,
      -0.6801727, -0.3302791, -0.9438833
    )
    expect_equal(expected, features[1, ], tolerance = 1e-6)
  })
  test_that("fourier_series_yearly", {
    # Order-3 Fourier features for a 365.25-day period (six columns), checked
    # on the first date of DATA.
    features <- prophet:::fourier_series(DATA$ds, 365.25, 3)
    expected <- c(
      0.69702635, -0.71704551, -0.99959923,
      0.02830854, 0.73648994, 0.67644849
    )
    expect_equal(expected, features[1, ], tolerance = 1e-6)
  })
  test_that("growth_init", {
    # First 468 rows, with a constant cap at the observed maximum of y.
    subset_468 <- DATA[1:468, ]
    subset_468$cap <- max(subset_468$y)
    model <- prophet(subset_468, growth = "logistic", fit = FALSE)
    prepared <- prophet:::setup_dataframe(
      model, subset_468, initialize_scales = TRUE
    )
    model <- prepared$m
    scaled_history <- prepared$df
    # Check the first two values returned by each growth initializer.
    linear_init <- prophet:::linear_growth_init(scaled_history)
    expect_equal(linear_init[1], 0.3055671, tolerance = 1e-6)
    expect_equal(linear_init[2], 0.5307511, tolerance = 1e-6)
    logistic_init <- prophet:::logistic_growth_init(scaled_history)
    expect_equal(logistic_init[1], 1.507925, tolerance = 1e-6)
    expect_equal(logistic_init[2], -0.08167497, tolerance = 1e-6)
  })
  test_that("piecewise_linear", {
    # Slope 1 up to the changepoint at t = 5, slope 1.5 afterwards.
    t <- seq(0, 10)
    offset <- 0
    slope <- 1.0
    slope_changes <- c(0.5)
    changepoint_times <- c(5)
    expected <- c(0, 1, 2, 3, 4, 5, 6.5, 8, 9.5, 11, 12.5)
    actual <- prophet:::piecewise_linear(
      t, slope_changes, slope, offset, changepoint_times
    )
    expect_equal(actual, expected)
    # Evaluating only t = 7..10 (past the changepoint) must give the same
    # values.
    keep <- 8:length(t)
    actual <- prophet:::piecewise_linear(
      t[keep], slope_changes, slope, offset, changepoint_times
    )
    expect_equal(actual, expected[keep])
  })
  test_that("piecewise_logistic", {
    # Capacity 10 throughout; the rate changes by 0.5 at t = 5.
    t <- seq(0, 10)
    capacity <- rep(10, 11)
    offset <- 0
    rate <- 1.0
    rate_changes <- c(0.5)
    changepoint_times <- c(5)
    expected <- c(5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071,
                  9.984988, 9.996646, 9.999252, 9.999833, 9.999963)
    actual <- prophet:::piecewise_logistic(
      t, capacity, rate_changes, rate, offset, changepoint_times
    )
    expect_equal(actual, expected, tolerance = 1e-6)
    # Evaluating only t = 7..10, with the capacity sliced to match, must give
    # the same values.
    keep <- 8:length(t)
    actual <- prophet:::piecewise_logistic(
      t[keep], capacity[keep], rate_changes, rate, offset, changepoint_times
    )
    expect_equal(actual, expected[keep], tolerance = 1e-6)
  })
  test_that("holidays", {
    # Christmas with a one-day lower window: two feature columns.
    holidays <- data.frame(ds = c("2016-12-25"),
                           holiday = c("xmas"),
                           lower_window = c(-1),
                           upper_window = c(0))
    df <- data.frame(
      ds = seq(prophet:::set_date("2016-12-20"),
               prophet:::set_date("2016-12-31"), by = "d"))
    m <- prophet(train, holidays = holidays, fit = FALSE)
    out <- prophet:::make_holiday_features(m, df$ds)
    feats <- out$holiday.features
    priors <- out$prior.scales
    expect_equal(nrow(feats), nrow(df))
    expect_equal(ncol(feats), 2)
    # Each window column must mark exactly one day. Compare the sums directly:
    # summing their differences from 1 would also accept e.g. c(2, 0).
    expect_equal(unname(colSums(feats)), c(1, 1))
    # expect_equal() also checks the length, which all(x == y) lets through
    # via recycling. as.numeric() drops any names on the vector.
    expect_equal(as.numeric(priors), c(10., 10.))

    # Ten-day upper window: 12 columns in total.
    holidays <- data.frame(ds = c("2016-12-25"),
                           holiday = c("xmas"),
                           lower_window = c(-1),
                           upper_window = c(10))
    m <- prophet(train, holidays = holidays, fit = FALSE)
    out <- prophet:::make_holiday_features(m, df$ds)
    feats <- out$holiday.features
    priors <- out$prior.scales
    expect_equal(nrow(feats), nrow(df))
    expect_equal(ncol(feats), 12)
    expect_equal(as.numeric(priors), rep(10, 12))

    # Check prior specifications
    holidays <- data.frame(
      ds = prophet:::set_date(c("2016-12-25", "2017-12-25")),
      holiday = c("xmas", "xmas"),
      lower_window = c(-1, -1),
      upper_window = c(0, 0),
      prior_scale = c(5., 5.)
    )
    m <- prophet(holidays = holidays, fit = FALSE)
    out <- prophet:::make_holiday_features(m, df$ds)
    priors <- out$prior.scales
    expect_equal(as.numeric(priors), c(5., 5.))

    # 2 different priors
    holidays2 <- data.frame(
      ds = prophet:::set_date(c("2012-06-06", "2013-06-06")),
      holiday = c("seans-bday", "seans-bday"),
      lower_window = c(0, 0),
      upper_window = c(1, 1),
      prior_scale = c(8, 8)
    )
    holidays2 <- rbind(holidays, holidays2)
    m <- prophet(holidays = holidays2, fit = FALSE)
    out <- prophet:::make_holiday_features(m, df$ds)
    priors <- out$prior.scales
    expect_equal(as.numeric(priors), c(8, 8, 5, 5))

    # Holidays without prior_scale fall back to holidays.prior.scale.
    holidays2 <- data.frame(
      ds = prophet:::set_date(c("2012-06-06", "2013-06-06")),
      holiday = c("seans-bday", "seans-bday"),
      lower_window = c(0, 0),
      upper_window = c(1, 1)
    )
    holidays2 <- dplyr::bind_rows(holidays, holidays2)
    m <- prophet(holidays = holidays2, fit = FALSE, holidays.prior.scale = 4)
    out <- prophet:::make_holiday_features(m, df$ds)
    priors <- out$prior.scales
    expect_equal(as.numeric(priors), c(4, 4, 5, 5))

    # Check incompatible priors
    holidays <- data.frame(
      ds = prophet:::set_date(c("2016-12-25", "2016-12-27")),
      holiday = c("xmasish", "xmasish"),
      lower_window = c(-1, -1),
      upper_window = c(0, 0),
      prior_scale = c(5., 6.)
    )
    m <- prophet(holidays = holidays, fit = FALSE)
    expect_error(prophet:::make_holiday_features(m, df$ds))
  })
  test_that("fit_with_holidays", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # A yearly birthday whose upper window also covers the following day.
    bday <- data.frame(
      ds = c("2012-06-06", "2013-06-06"),
      holiday = c("seans-bday", "seans-bday"),
      lower_window = c(0, 0),
      upper_window = c(1, 1)
    )
    holiday_model <- prophet(DATA, holidays = bday, uncertainty.samples = 0)
    # predict() with no new data must not error.
    expect_error(predict(holiday_model), NA)
  })
  test_that("make_future_dataframe", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    first_234 <- DATA[1:234, ]
    fitted_model <- prophet(first_234)
    # Three future periods per frequency, excluding the history itself.
    expected <- list(
      day = c("2013-04-26", "2013-04-27", "2013-04-28"),
      month = c("2013-05-25", "2013-06-25", "2013-07-25")
    )
    for (freq in names(expected)) {
      future_only <- make_future_dataframe(fitted_model, periods = 3,
                                           freq = freq,
                                           include_history = FALSE)
      expect_equal(future_only$ds, prophet:::set_date(expected[[freq]]))
    }
  })
  test_that("auto_weekly_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Compare every listed field of the weekly seasonality with `expected`.
    check_weekly <- function(model, expected) {
      for (field in names(expected)) {
        expect_equal(model$seasonalities$weekly[[field]], expected[[field]])
      }
    }

    # 15 rows: 'auto' should switch weekly seasonality on.
    first_15 <- DATA[1:15, ]
    model <- prophet(first_15, fit = FALSE)
    expect_equal(model$weekly.seasonality, "auto")
    model <- fit.prophet(model, first_15)
    expect_true("weekly" %in% names(model$seasonalities))
    check_weekly(model, list(period = 7, fourier.order = 3, prior.scale = 10))

    # 9 rows is too short for 'auto', but TRUE forces it on.
    first_9 <- DATA[1:9, ]
    model <- prophet(first_9)
    expect_false("weekly" %in% names(model$seasonalities))
    model <- prophet(first_9, weekly.seasonality = TRUE)
    expect_true("weekly" %in% names(model$seasonalities))

    # Observations spaced a week apart: 'auto' should leave it off.
    every_7th <- DATA[seq(1, nrow(DATA), 7), ]
    model <- prophet(every_7th)
    expect_false("weekly" %in% names(model$seasonalities))

    # A number sets the Fourier order directly.
    model <- prophet(DATA, weekly.seasonality = 2,
                     seasonality.prior.scale = 3)
    check_weekly(model, list(period = 7, fourier.order = 2, prior.scale = 3))
  })
  test_that("auto_yearly_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Compare every listed field of the yearly seasonality with `expected`.
    check_yearly <- function(model, expected) {
      for (field in names(expected)) {
        expect_equal(model$seasonalities$yearly[[field]], expected[[field]])
      }
    }

    # Full DATA: 'auto' should switch yearly seasonality on.
    model <- prophet(DATA, fit = FALSE)
    expect_equal(model$yearly.seasonality, "auto")
    model <- fit.prophet(model, DATA)
    expect_true("yearly" %in% names(model$seasonalities))
    check_yearly(model,
                 list(period = 365.25, fourier.order = 10, prior.scale = 10))

    # 240 rows is too short for 'auto', but TRUE forces it on.
    first_240 <- DATA[1:240, ]
    model <- prophet(first_240)
    expect_false("yearly" %in% names(model$seasonalities))
    model <- prophet(first_240, yearly.seasonality = TRUE)
    expect_true("yearly" %in% names(model$seasonalities))

    # A number sets the Fourier order directly.
    model <- prophet(DATA, yearly.seasonality = 7,
                     seasonality.prior.scale = 3)
    check_yearly(model,
                 list(period = 365.25, fourier.order = 7, prior.scale = 3))
  })
  test_that("auto_daily_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Compare every listed field of the daily seasonality with `expected`.
    check_daily <- function(model, expected) {
      for (field in names(expected)) {
        expect_equal(model$seasonalities$daily[[field]], expected[[field]])
      }
    }

    # Full DATA2: 'auto' should switch daily seasonality on.
    model <- prophet(DATA2, fit = FALSE)
    expect_equal(model$daily.seasonality, "auto")
    model <- fit.prophet(model, DATA2)
    expect_true("daily" %in% names(model$seasonalities))
    check_daily(model, list(period = 1, fourier.order = 4, prior.scale = 10))

    # 430 rows is too short for 'auto', but TRUE forces it on.
    first_430 <- DATA2[1:430, ]
    model <- prophet(first_430)
    expect_false("daily" %in% names(model$seasonalities))
    model <- prophet(first_430, daily.seasonality = TRUE)
    expect_true("daily" %in% names(model$seasonalities))

    # A number sets the Fourier order directly.
    model <- prophet(DATA2, daily.seasonality = 7,
                     seasonality.prior.scale = 3)
    check_daily(model, list(period = 1, fourier.order = 7, prior.scale = 3))

    # The main DATA series should not get daily seasonality under 'auto'.
    model <- prophet(DATA)
    expect_false("daily" %in% names(model$seasonalities))
  })
  test_that("test_subdaily_holidays", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    special <- data.frame(ds = c("2017-01-02"), holiday = c("special_day"))
    subdaily_model <- prophet(DATA2, holidays = special)
    fcst <- predict(subdaily_model)
    # Count the forecast rows where the holiday component is exactly zero.
    zero_rows <- sum(fcst$special_day == 0)
    expect_equal(zero_rows, 575)
  })
  test_that("custom_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    holidays <- data.frame(ds = c("2017-01-02"),
                           holiday = c("special_day"),
                           prior_scale = c(4))
    m <- prophet(holidays = holidays)
    m <- add_seasonality(m, name = "monthly", period = 30, fourier.order = 5)
    expected <- list(period = 30, fourier.order = 5, prior.scale = 10)
    for (name in names(expected)) {
      expect_equal(m$seasonalities$monthly[[name]], expected[[name]])
    }
    # Names already used by a holiday or by the trend must be rejected. The
    # argument is `fourier.order`. The old `fourier_order` spelling raised an
    # "unused argument" error, so these expectations passed without reaching
    # the name check.
    expect_error(
      add_seasonality(m, name = "special_day", period = 30, fourier.order = 5)
    )
    expect_error(
      add_seasonality(m, name = "trend", period = 30, fourier.order = 5)
    )
    # Re-adding the built-in weekly seasonality must not error.
    expect_error(
      add_seasonality(m, name = "weekly", period = 30, fourier.order = 5),
      NA
    )

    # Test priors
    m <- prophet(holidays = holidays, yearly.seasonality = FALSE)
    m <- add_seasonality(
      m, name = "monthly", period = 30, fourier.order = 5, prior.scale = 2)
    m <- fit.prophet(m, DATA)
    prior.scales <- prophet:::make_all_seasonality_features(
      m, m$history)$prior.scales
    # Expected: 10 monthly columns (prior 2), 6 weekly columns (prior 10), then
    # the holiday (prior 4). expect_equal() also checks the length, which
    # all(x == y) does not.
    expect_equal(as.numeric(prior.scales), c(rep(2, 10), rep(10, 6), 4))
  })
  test_that("added_regressors", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Compare each field of `expected` with the settings stored for one extra
    # regressor. Extra arguments (e.g. tolerance) are forwarded to
    # expect_equal().
    check_regressor <- function(model, reg_name, expected, ...) {
      for (field in names(expected)) {
        expect_equal(expected[[field]],
                     model$extra_regressors[[reg_name]][[field]], ...)
      }
    }

    model <- prophet()
    model <- add_regressor(model, "binary_feature", prior.scale = 0.2)
    model <- add_regressor(model, "numeric_feature", prior.scale = 0.5)
    model <- add_regressor(model, "binary_feature2", standardize = TRUE)
    history <- DATA
    history$binary_feature <- c(rep(0, 255), rep(1, 255))
    history$numeric_feature <- 0:509
    # Fitting must fail while binary_feature2 is still missing.
    expect_error(
      fit.prophet(model, history)
    )
    history$binary_feature2 <- c(rep(1, 100), rep(0, 410))
    model <- fit.prophet(model, history)

    # Standardization settings recorded for each regressor.
    check_regressor(model, "binary_feature",
                    list(prior.scale = 0.2, mu = 0, std = 1,
                         standardize = "auto"))
    check_regressor(model, "numeric_feature",
                    list(prior.scale = 0.5, mu = 254.5, std = 147.368585),
                    tolerance = 1e-5)
    check_regressor(model, "binary_feature2",
                    list(prior.scale = 10., mu = 0.1960784, std = 0.3974183),
                    tolerance = 1e-5)

    # The columns as transformed by setup_dataframe().
    standardized <- prophet:::setup_dataframe(model, history)$df
    expect_equal(standardized$binary_feature[1], 0)
    expect_equal(standardized$numeric_feature[1], -1.726962, tolerance = 1e-4)
    expect_equal(standardized$binary_feature2[1], 2.022859, tolerance = 1e-4)

    # Feature matrix and prior scales.
    feature_out <- prophet:::make_all_seasonality_features(model, standardized)
    seasonal_features <- feature_out$seasonal.features
    prior_scales <- feature_out$prior.scales
    for (col in c("binary_feature", "numeric_feature", "binary_feature2")) {
      expect_true(col %in% colnames(seasonal_features))
    }
    expect_equal(ncol(seasonal_features), 29)
    expect_true(all(sort(prior_scales[27:29]) == c(0.2, 0.5, 10.)))

    # predict() must fail until every regressor is present in the new data.
    one_day <- data.frame(
      ds = c("2014-06-01"), binary_feature = c(0), numeric_feature = c(10))
    expect_error(predict(model, one_day))
    one_day$binary_feature2 <- 0.
    fcst <- predict(model, one_day)
    expect_equal(ncol(fcst), 31)
    expect_equal(fcst$binary_feature[1], 0)
    # The forecast components must add up consistently.
    expect_equal(fcst$extra_regressors[1],
                 fcst$numeric_feature[1] + fcst$binary_feature2[1])
    expect_equal(fcst$seasonalities[1], fcst$yearly[1] + fcst$weekly[1])
    expect_equal(fcst$seasonal[1],
                 fcst$seasonalities[1] + fcst$extra_regressors[1])
    expect_equal(fcst$yhat[1], fcst$trend[1] + fcst$seasonal[1])
  })
  test_that("copy", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Argument values that prophet_copy() must carry over unchanged.
    # NULL cannot live inside a vector (c(NULL, x) is just x), so the "not
    # given" cases for changepoints and holidays use NA and are mapped back
    # to NULL inside the loop.
    inputs <- list(
      growth = c("linear", "logistic"),
      changepoints = c(NA, "2016-12-25"),
      n.changepoints = c(3),
      yearly.seasonality = c(TRUE, FALSE),
      weekly.seasonality = c(TRUE, FALSE),
      daily.seasonality = c(TRUE, FALSE),
      holidays = c(NA, "insert_dataframe"),
      seasonality.prior.scale = c(1.1),
      holidays.prior.scale = c(1.1),
      changepoints.prior.scale = c(0.1),
      mcmc.samples = c(100),
      interval.width = c(0.9),
      uncertainty.samples = c(200)
    )
    # expand.grid() turns strings into factors unless told otherwise.
    products <- expand.grid(inputs, stringsAsFactors = FALSE)
    # Iterate over rows. length() of a data frame is its column count, so
    # 1:length(products) visited only the first 13 combinations.
    for (i in seq_len(nrow(products))) {
      cp_arg <- products$changepoints[i]
      if (is.na(cp_arg)) {
        cp_arg <- NULL
      }
      if (is.na(products$holidays[i])) {
        holidays <- NULL
      } else {
        holidays <- data.frame(ds = c("2016-12-25"), holiday = c("x"))
      }
      m1 <- prophet(
        growth = products$growth[i],
        changepoints = cp_arg,
        n.changepoints = products$n.changepoints[i],
        yearly.seasonality = products$yearly.seasonality[i],
        weekly.seasonality = products$weekly.seasonality[i],
        daily.seasonality = products$daily.seasonality[i],
        holidays = holidays,
        seasonality.prior.scale = products$seasonality.prior.scale[i],
        holidays.prior.scale = products$holidays.prior.scale[i],
        changepoints.prior.scale = products$changepoints.prior.scale[i],
        mcmc.samples = products$mcmc.samples[i],
        interval.width = products$interval.width[i],
        uncertainty.samples = products$uncertainty.samples[i],
        fit = FALSE
      )
      m2 <- prophet:::prophet_copy(m1)
      # Values should be copied correctly
      for (arg in names(inputs)) {
        expect_equal(m1[[arg]], m2[[arg]])
      }
    }
    # Check for cutoff: only changepoints on or before it are kept.
    changepoints <- seq.Date(as.Date("2012-06-15"), as.Date("2012-09-15"),
                             by = "d")
    cutoff <- as.Date("2012-07-25")
    m1 <- prophet(DATA, changepoints = changepoints)
    m2 <- prophet:::prophet_copy(m1, cutoff)
    changepoints <- changepoints[changepoints <= cutoff]
    expect_equal(prophet:::set_date(changepoints), m2$changepoints)
  })