test_diagnostics.py 4.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree. An additional grant
  6. # of patent rights can be found in the PATENTS file in the same directory.
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. import os
  12. import numpy as np
  13. import pandas as pd
  14. from unittest import TestCase
  15. from fbprophet import Prophet
  16. from fbprophet import diagnostics
  17. class TestDiagnostics(TestCase):
  18. def __init__(self, *args, **kwargs):
  19. super(TestDiagnostics, self).__init__(*args, **kwargs)
  20. # Use first 100 record in data.csv
  21. self.__df = pd.read_csv(os.path.join(os.path.dirname(__file__), 'data.csv'), parse_dates=['ds']).head(100)
  22. def test_simulated_historical_forecasts(self):
  23. m = Prophet()
  24. m.fit(self.__df)
  25. k = 2
  26. for p in [1, 10]:
  27. for h in [1, 3]:
  28. period = '{} days'.format(p)
  29. horizon = '{} days'.format(h)
  30. df_shf = diagnostics.simulated_historical_forecasts(m, horizon=horizon, k=k, period=period)
  31. # All cutoff dates should be less than ds dates
  32. self.assertTrue((df_shf['cutoff'] < df_shf['ds']).all())
  33. # The unique size of output cutoff should be equal to 'k'
  34. self.assertEqual(len(np.unique(df_shf['cutoff'])), k)
  35. self.assertEqual(max(df_shf['ds'] - df_shf['cutoff']), pd.Timedelta(horizon))
  36. dc = df_shf['cutoff'].diff()
  37. dc = dc[dc > pd.Timedelta(0)].min()
  38. self.assertTrue(dc >= pd.Timedelta(period))
  39. # Each y in df_shf and self.__df with same ds should be equal
  40. df_merged = pd.merge(df_shf, self.__df, 'left', on='ds')
  41. self.assertAlmostEqual(np.sum((df_merged['y_x'] - df_merged['y_y']) ** 2), 0.0)
  42. def test_simulated_historical_forecasts_logistic(self):
  43. m = Prophet(growth='logistic')
  44. df = self.__df.copy()
  45. df['cap'] = 40
  46. m.fit(df)
  47. df_shf = diagnostics.simulated_historical_forecasts(m, horizon='3 days', k=2, period='3 days')
  48. # All cutoff dates should be less than ds dates
  49. self.assertTrue((df_shf['cutoff'] < df_shf['ds']).all())
  50. # The unique size of output cutoff should be equal to 'k'
  51. self.assertEqual(len(np.unique(df_shf['cutoff'])), 2)
  52. # Each y in df_shf and self.__df with same ds should be equal
  53. df_merged = pd.merge(df_shf, df, 'left', on='ds')
  54. self.assertAlmostEqual(np.sum((df_merged['y_x'] - df_merged['y_y']) ** 2), 0.0)
  55. def test_simulated_historical_forecasts_default_value_check(self):
  56. m = Prophet()
  57. m.fit(self.__df)
  58. # Default value of period should be equal to 0.5 * horizon
  59. df_shf1 = diagnostics.simulated_historical_forecasts(m, horizon='10 days', k=1)
  60. df_shf2 = diagnostics.simulated_historical_forecasts(m, horizon='10 days', k=1, period='5 days')
  61. self.assertAlmostEqual(((df_shf1 - df_shf2)**2)[['y', 'yhat']].sum().sum(), 0.0)
  62. def test_cross_validation(self):
  63. m = Prophet()
  64. m.fit(self.__df)
  65. # Calculate the number of cutoff points(k)
  66. te = self.__df['ds'].max()
  67. ts = self.__df['ds'].min()
  68. horizon = pd.Timedelta('4 days')
  69. period = pd.Timedelta('10 days')
  70. initial = pd.Timedelta('90 days')
  71. k = int(np.floor(((te - horizon) - (ts + initial)) / period))
  72. df_cv = diagnostics.cross_validation(m, horizon=horizon, period=period, initial=initial)
  73. # The unique size of output cutoff should be equal to 'k'
  74. self.assertEqual(len(np.unique(df_cv['cutoff'])), k)
  75. self.assertEqual(max(df_cv['ds'] - df_cv['cutoff']), horizon)
  76. dc = df_cv['cutoff'].diff()
  77. dc = dc[dc > pd.Timedelta(0)].min()
  78. self.assertTrue(dc >= period)
  79. def test_cross_validation_default_value_check(self):
  80. m = Prophet()
  81. m.fit(self.__df)
  82. # Default value of initial should be equal to 3 * horizon
  83. df_cv1 = diagnostics.cross_validation(m, horizon='32 days', period='10 days')
  84. df_cv2 = diagnostics.cross_validation(m, horizon='32 days', period='10 days', initial='96 days')
  85. self.assertAlmostEqual(((df_cv1 - df_cv2)**2)[['y', 'yhat']].sum().sum(), 0.0)