ソースを参照

Functional daily seasonality (#239)

Ben Letham 8 年 前
コミット
825108b226
1 ファイル変更、44 行追加、19 行削除
  1. 44 19
      python/fbprophet/forecaster.py

+ 44 - 19
python/fbprophet/forecaster.py

@@ -50,8 +50,12 @@ class Prophet(object):
         if input `changepoints` is supplied. If `changepoints` is not supplied,
         then n_changepoints potential changepoints are selected uniformly from
         the first 80 percent of the history.
-    yearly_seasonality: Fit yearly seasonality. Can be 'auto', True, or False.
-    weekly_seasonality: Fit weekly seasonality. Can be 'auto', True, or False.
+    yearly_seasonality: Fit yearly seasonality.
+        Can be 'auto', True, False, or a number of Fourier terms to generate.
+    weekly_seasonality: Fit weekly seasonality.
+        Can be 'auto', True, False, or a number of Fourier terms to generate.
+    daily_seasonality: Fit daily seasonality.
+        Can be 'auto', True, False, or a number of Fourier terms to generate.
     holidays: pd.DataFrame with columns holiday (string) and ds (date type)
         and optionally columns lower_window and upper_window which specify a
         range of days around the date to be included as holidays.
@@ -84,6 +88,7 @@ class Prophet(object):
             n_changepoints=25,
             yearly_seasonality='auto',
             weekly_seasonality='auto',
+            daily_seasonality='auto',
             holidays=None,
             seasonality_prior_scale=10.0,
             holidays_prior_scale=10.0,
@@ -154,8 +159,8 @@ class Prophet(object):
             for h in self.holidays['holiday'].unique():
                 if '_delim_' in h:
                     raise ValueError('Holiday name cannot contain "_delim_"')
-                if h in ['zeros', 'yearly', 'weekly', 'yhat', 'seasonal',
-                         'trend']:
+                if h in ['zeros', 'yearly', 'weekly', 'daily', 'yhat',
+                         'seasonal', 'trend']:
                     raise ValueError('Holiday name {} reserved.'.format(h))
 
     def setup_dataframe(self, df, initialize_scales=False):
@@ -259,8 +264,9 @@ class Prophet(object):
         # convert to days since epoch
         t = np.array(
             (dates - pd.datetime(1970, 1, 1))
-            .dt.total_seconds()/(24*3600)
-        )
+            .dt.total_seconds()
+            .astype(np.float)
+        ) / (3600 * 24.)
         return np.column_stack([
             fun((2.0 * (i + 1) * np.pi * t / period))
             for i in range(series_order)
@@ -354,28 +360,28 @@ class Prophet(object):
         ]
 
         # Seasonality features
-        if self.yearly_seasonality:
+        if self.yearly_seasonality > 0:
             seasonal_features.append(self.make_seasonality_features(
                 df['ds'],
                 365.25,
-                10,
+                self.yearly_seasonality,
                 'yearly',
             ))
 
-        if self.weekly_seasonality:
+        if self.weekly_seasonality > 0:
             seasonal_features.append(self.make_seasonality_features(
                 df['ds'],
                 7,
-                3,
+                self.weekly_seasonality,
                 'weekly',
             ))
 
-        if self.daily_seasonality:
+        if self.daily_seasonality > 0:
             seasonal_features.append(self.make_seasonality_features(
                 df['ds'],
                 1,
-                3,
-                'daily'
+                self.daily_seasonality,
+                'daily',
             ))
 
         if self.holidays is not None:
@@ -393,21 +399,40 @@ class Prophet(object):
         last = self.history['ds'].max()
         if self.yearly_seasonality == 'auto':
             if last - first < pd.Timedelta(days=730):
-                self.yearly_seasonality = False
+                self.yearly_seasonality = 0
                 logger.info('Disabling yearly seasonality. Run prophet with '
-                            'yearly_seasonality=True to override this.')
+                      'yearly_seasonality=True to override this.')
             else:
-                self.yearly_seasonality = True
+                self.yearly_seasonality = 10
+        elif self.yearly_seasonality is True:
+            self.yearly_seasonality = 10
+                
         if self.weekly_seasonality == 'auto':
             dt = self.history['ds'].diff()
             min_dt = dt.iloc[dt.nonzero()[0]].min()
             if ((last - first < pd.Timedelta(weeks=2)) or
                     (min_dt >= pd.Timedelta(weeks=1))):
-                self.weekly_seasonality = False
+                self.weekly_seasonality = 0
                 logger.info('Disabling weekly seasonality. Run prophet with '
-                            'weekly_seasonality=True to override this.')
+                      'weekly_seasonality=True to override this.')
             else:
-                self.weekly_seasonality = True
+                self.weekly_seasonality = 3
+        elif self.weekly_seasonality is True:
+            self.weekly_seasonality = 3
+                
+        if self.daily_seasonality == 'auto':
+            # Disabled by default; enabled when the smallest nonzero gap
+            # between consecutive rows of history['ds'] is < 1 day.
+            dt = self.history['ds'].diff()
+            min_dt = dt.iloc[dt.nonzero()[0]].min()
+            if (min_dt< pd.Timedelta(days=1)):
+                self.daily_seasonality = 4
+                logger.info('Enabling daily seasonality. Run prophet with '
+                      'daily_seasonality=False to override this.')
+            else:
+                self.daily_seasonality = 0
+        elif self.daily_seasonality is True:
+            self.daily_seasonality = 4
 
     @staticmethod
     def linear_growth_init(df):