Selaa lähdekoodia

Speed up diagnostics unit tests

bletham 8 vuotta sitten
vanhempi
commit
509666d1d2
1 muutettua tiedostoa jossa 12 lisäystä ja 4 poistoa
  1. 12 4
      python/fbprophet/tests/test_diagnostics.py

+ 12 - 4
python/fbprophet/tests/test_diagnostics.py

@@ -29,7 +29,7 @@ class TestDiagnostics(TestCase):
     def test_simulated_historical_forecasts(self):
         m = Prophet()
         m.fit(self.__df)
-        k = 3
+        k = 2
         for p in [1, 10]:
             for h in [1, 3]:
                 period = '{} days'.format(p)
@@ -39,6 +39,10 @@ class TestDiagnostics(TestCase):
                 self.assertTrue((df_shf['cutoff'] < df_shf['ds']).all())
                 # The unique size of output cutoff should be equal to 'k'
                 self.assertEqual(len(np.unique(df_shf['cutoff'])), k)
+                self.assertEqual(max(df_shf['ds'] - df_shf['cutoff']), pd.Timedelta(horizon))
+                dc = df_shf['cutoff'].diff()
+                dc = dc[dc > pd.Timedelta(0)].min()
+                self.assertTrue(dc >= pd.Timedelta(period))
                 # Each y in df_shf and self.__df with same ds should be equal
                 df_merged = pd.merge(df_shf, self.__df, 'left', on='ds')
                 self.assertAlmostEqual(np.sum((df_merged['y_x'] - df_merged['y_y']) ** 2), 0.0)
@@ -72,17 +76,21 @@ class TestDiagnostics(TestCase):
         te = self.__df['ds'].max()
         ts = self.__df['ds'].min()
         horizon = pd.Timedelta('4 days')
-        period = pd.Timedelta('1 days')
+        period = pd.Timedelta('10 days')
         initial = pd.Timedelta('90 days')
         k = int(np.floor(((te - horizon) - (ts + initial)) / period))
         df_cv = diagnostics.cross_validation(m, horizon=horizon, period=period, initial=initial)
         # The unique size of output cutoff should be equal to 'k'
         self.assertEqual(len(np.unique(df_cv['cutoff'])), k)
+        self.assertEqual(max(df_cv['ds'] - df_cv['cutoff']), horizon)
+        dc = df_cv['cutoff'].diff()
+        dc = dc[dc > pd.Timedelta(0)].min()
+        self.assertTrue(dc >= period)
 
     def test_cross_validation_default_value_check(self):
         m = Prophet()
         m.fit(self.__df)
         # Default value of initial should be equal to 3 * horizon
-        df_cv1 = diagnostics.cross_validation(m, horizon='32 days', period='1 days')
-        df_cv2 = diagnostics.cross_validation(m, horizon='32 days', period='1 days', initial='96 days')
+        df_cv1 = diagnostics.cross_validation(m, horizon='32 days', period='10 days')
+        df_cv2 = diagnostics.cross_validation(m, horizon='32 days', period='10 days', initial='96 days')
         self.assertAlmostEqual(((df_cv1 - df_cv2)**2)[['y', 'yhat']].sum().sum(), 0.0)