Ben Letham 8 gadi atpakaļ
vecāks
revīzija
230b2ca6e0

+ 1 - 1
python/fbprophet/models.py

@@ -2,7 +2,7 @@
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree. An additional grant 
+# LICENSE file in the root directory of this source tree. An additional grant
 # of patent rights can be found in the PATENTS file in the same directory.
 
 from __future__ import absolute_import

+ 33 - 15
python/fbprophet/tests/test_diagnostics.py

@@ -10,21 +10,28 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
-import os
 import numpy as np
 import pandas as pd
 
+# fb-block 1 start
+import os
 from unittest import TestCase
 from fbprophet import Prophet
 from fbprophet import diagnostics
 
+DATA = pd.read_csv(
+    os.path.join(os.path.dirname(__file__), 'data.csv'), parse_dates=['ds']
+).head(100)
+# fb-block 1 end
+# fb-block 2
+
 
 class TestDiagnostics(TestCase):
 
     def __init__(self, *args, **kwargs):
         super(TestDiagnostics, self).__init__(*args, **kwargs)
         # Use first 100 record in data.csv
-        self.__df = pd.read_csv(os.path.join(os.path.dirname(__file__), 'data.csv'), parse_dates=['ds']).head(100)
+        self.__df = DATA
 
     def test_simulated_historical_forecasts(self):
         m = Prophet()
@@ -34,47 +41,55 @@ class TestDiagnostics(TestCase):
             for h in [1, 3]:
                 period = '{} days'.format(p)
                 horizon = '{} days'.format(h)
-                df_shf = diagnostics.simulated_historical_forecasts(m, horizon=horizon, k=k, period=period)
+                df_shf = diagnostics.simulated_historical_forecasts(
+                    m, horizon=horizon, k=k, period=period)
                 # All cutoff dates should be less than ds dates
                 self.assertTrue((df_shf['cutoff'] < df_shf['ds']).all())
                 # The unique size of output cutoff should be equal to 'k'
                 self.assertEqual(len(np.unique(df_shf['cutoff'])), k)
-                self.assertEqual(max(df_shf['ds'] - df_shf['cutoff']), pd.Timedelta(horizon))
+                self.assertEqual(
+                    max(df_shf['ds'] - df_shf['cutoff']),
+                    pd.Timedelta(horizon),
+                )
                 dc = df_shf['cutoff'].diff()
                 dc = dc[dc > pd.Timedelta(0)].min()
                 self.assertTrue(dc >= pd.Timedelta(period))
                 # Each y in df_shf and self.__df with same ds should be equal
                 df_merged = pd.merge(df_shf, self.__df, 'left', on='ds')
-                self.assertAlmostEqual(np.sum((df_merged['y_x'] - df_merged['y_y']) ** 2), 0.0)
+                self.assertAlmostEqual(
+                    np.sum((df_merged['y_x'] - df_merged['y_y']) ** 2), 0.0)
 
     def test_simulated_historical_forecasts_logistic(self):
         m = Prophet(growth='logistic')
         df = self.__df.copy()
         df['cap'] = 40
         m.fit(df)
-        df_shf = diagnostics.simulated_historical_forecasts(m, horizon='3 days', k=2, period='3 days')
+        df_shf = diagnostics.simulated_historical_forecasts(
+            m, horizon='3 days', k=2, period='3 days')
         # All cutoff dates should be less than ds dates
         self.assertTrue((df_shf['cutoff'] < df_shf['ds']).all())
         # The unique size of output cutoff should be equal to 'k'
         self.assertEqual(len(np.unique(df_shf['cutoff'])), 2)
         # Each y in df_shf and self.__df with same ds should be equal
         df_merged = pd.merge(df_shf, df, 'left', on='ds')
-        self.assertAlmostEqual(np.sum((df_merged['y_x'] - df_merged['y_y']) ** 2), 0.0)
+        self.assertAlmostEqual(
+            np.sum((df_merged['y_x'] - df_merged['y_y']) ** 2), 0.0)
 
     def test_simulated_historical_forecasts_default_value_check(self):
         m = Prophet()
         m.fit(self.__df)
         # Default value of period should be equal to 0.5 * horizon
-        df_shf1 = diagnostics.simulated_historical_forecasts(m, horizon='10 days', k=1)
-        df_shf2 = diagnostics.simulated_historical_forecasts(m, horizon='10 days', k=1, period='5 days')
-        self.assertAlmostEqual(((df_shf1 - df_shf2)**2)[['y', 'yhat']].sum().sum(), 0.0)
+        df_shf1 = diagnostics.simulated_historical_forecasts(
+            m, horizon='10 days', k=1)
+        df_shf2 = diagnostics.simulated_historical_forecasts(
+            m, horizon='10 days', k=1, period='5 days')
+        self.assertAlmostEqual(
+            ((df_shf1 - df_shf2)**2)[['y', 'yhat']].sum().sum(), 0.0)
 
     def test_cross_validation(self):
         m = Prophet()
         m.fit(self.__df)
         # Calculate the number of cutoff points(k)
-        te = self.__df['ds'].max()
-        ts = self.__df['ds'].min()
         horizon = pd.Timedelta('4 days')
         period = pd.Timedelta('10 days')
         k = 5
@@ -91,6 +106,9 @@ class TestDiagnostics(TestCase):
         m = Prophet()
         m.fit(self.__df)
         # Default value of initial should be equal to 3 * horizon
-        df_cv1 = diagnostics.cross_validation(m, horizon='32 days', period='10 days')
-        df_cv2 = diagnostics.cross_validation(m, horizon='32 days', period='10 days', initial='96 days')
-        self.assertAlmostEqual(((df_cv1 - df_cv2)**2)[['y', 'yhat']].sum().sum(), 0.0)
+        df_cv1 = diagnostics.cross_validation(
+            m, horizon='32 days', period='10 days')
+        df_cv2 = diagnostics.cross_validation(
+            m, horizon='32 days', period='10 days', initial='96 days')
+        self.assertAlmostEqual(
+            ((df_cv1 - df_cv2)**2)[['y', 'yhat']].sum().sum(), 0.0)

+ 4 - 2
python/fbprophet/tests/test_prophet.py

@@ -10,12 +10,12 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
+import itertools
 import numpy as np
 import pandas as pd
 
 # fb-block 1 start
 import os
-import itertools
 from unittest import TestCase
 from fbprophet import Prophet
 
@@ -551,6 +551,8 @@ class TestProphet(TestCase):
 
     def test_copy(self):
         # These values are created except for its default values
+        holiday = pd.DataFrame(
+            {'ds': pd.to_datetime(['2016-12-25']), 'holiday': ['x']})
         products = itertools.product(
             ['linear', 'logistic'],  # growth
             [None, pd.to_datetime(['2016-12-25'])],  # changepoints
@@ -558,7 +560,7 @@ class TestProphet(TestCase):
             [True, False],  # yearly_seasonality
             [True, False],  # weekly_seasonality
             [True, False],  # daily_seasonality
-            [None, pd.DataFrame({'ds': pd.to_datetime(['2016-12-25']), 'holiday': ['x']})],  # holidays
+            [None, holiday],  # holidays
             [1.1],  # seasonality_prior_scale
             [1.1],  # holidays_prior_scale
             [0.1],  # changepoint_prior_scale