瀏覽代碼

Show multiplicative seasonality as percent in plots (Py)

Ben Letham 7 年之前
父節點
當前提交
f1e24d3c2c
共有 3 個文件被更改,包括 34 次插入11 次删除
  1. 13 7
      python/fbprophet/forecaster.py
  2. 18 3
      python/fbprophet/plot.py
  3. 3 1
      python/fbprophet/tests/test_prophet.py

+ 13 - 7
python/fbprophet/forecaster.py

@@ -156,6 +156,7 @@ class Prophet(object):
         self.history = None
         self.history = None
         self.history_dates = None
         self.history_dates = None
         self.train_component_cols = None
         self.train_component_cols = None
+        self.component_modes = None
         self.validate_inputs()
         self.validate_inputs()
 
 
     def validate_inputs(self):
     def validate_inputs(self):
@@ -196,8 +197,8 @@ class Prophet(object):
             raise ValueError('Name cannot contain "_delim_"')
             raise ValueError('Name cannot contain "_delim_"')
         reserved_names = [
         reserved_names = [
             'trend', 'additive_terms', 'daily', 'weekly', 'yearly',
             'trend', 'additive_terms', 'daily', 'weekly', 'yearly',
-            'holidays', 'zeros', 'extra_regressors_additive',
-            'extra_regressors_multiplicative', 'yhat',
+            'holidays', 'zeros', 'extra_regressors_additive','yhat',
+            'extra_regressors_multiplicative', 'multiplicative_terms',
         ]
         ]
         rn_l = [n + '_lower' for n in reserved_names]
         rn_l = [n + '_lower' for n in reserved_names]
         rn_u = [n + '_upper' for n in reserved_names]
         rn_u = [n + '_upper' for n in reserved_names]
@@ -686,6 +687,8 @@ class Prophet(object):
             # Add combination components to modes
             # Add combination components to modes
             modes[mode].append(mode + '_terms')
             modes[mode].append(mode + '_terms')
             modes[mode].append('extra_regressors_' + mode)
             modes[mode].append('extra_regressors_' + mode)
+        # After all of the additive/multiplicative groups have been added,
+        modes[self.seasonality_mode].append('holidays')
         # Convert to a binary matrix
         # Convert to a binary matrix
         component_cols = pd.crosstab(
         component_cols = pd.crosstab(
             components['col'], components['component'],
             components['col'], components['component'],
@@ -724,8 +727,10 @@ class Prophet(object):
         Dataframe with components.
         Dataframe with components.
         """
         """
         new_comp = components[components['component'].isin(set(group))].copy()
         new_comp = components[components['component'].isin(set(group))].copy()
-        new_comp['component'] = name
-        components = components.append(new_comp)
+        group_cols = new_comp['col'].unique()
+        if len(group_cols) > 0:
+            new_comp = pd.DataFrame({'component': name, 'col': group_cols})
+            components = components.append(new_comp)
         return components
         return components
 
 
     def parse_seasonality_args(self, name, arg, auto_disable, default_order):
     def parse_seasonality_args(self, name, arg, auto_disable, default_order):
@@ -920,9 +925,10 @@ class Prophet(object):
         history = self.setup_dataframe(history, initialize_scales=True)
         history = self.setup_dataframe(history, initialize_scales=True)
         self.history = history
         self.history = history
         self.set_auto_seasonalities()
         self.set_auto_seasonalities()
-        seasonal_features, prior_scales, component_cols, _ = (
+        seasonal_features, prior_scales, component_cols, modes = (
             self.make_all_seasonality_features(history))
             self.make_all_seasonality_features(history))
         self.train_component_cols = component_cols
         self.train_component_cols = component_cols
+        self.component_modes = modes
 
 
         self.set_changepoints()
         self.set_changepoints()
 
 
@@ -1131,7 +1137,7 @@ class Prophet(object):
         -------
         -------
         Dataframe with seasonal components.
         Dataframe with seasonal components.
         """
         """
-        seasonal_features, _, component_cols, modes = (
+        seasonal_features, _, component_cols, _ = (
             self.make_all_seasonality_features(df)
             self.make_all_seasonality_features(df)
         )
         )
         lower_p = 100 * (1.0 - self.interval_width) / 2
         lower_p = 100 * (1.0 - self.interval_width) / 2
@@ -1143,7 +1149,7 @@ class Prophet(object):
             beta_c = self.params['beta'] * component_cols[component].values
             beta_c = self.params['beta'] * component_cols[component].values
 
 
             comp = np.matmul(X, beta_c.transpose())
             comp = np.matmul(X, beta_c.transpose())
-            if component in modes['additive']:
+            if component in self.component_modes['additive']:
                  comp *= self.y_scale
                  comp *= self.y_scale
             data[component] = np.nanmean(comp, axis=1)
             data[component] = np.nanmean(comp, axis=1)
             data[component + '_lower'] = np.nanpercentile(
             data[component + '_lower'] = np.nanpercentile(

+ 18 - 3
python/fbprophet/plot.py

@@ -186,6 +186,8 @@ def plot_forecast_component(
     ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
     ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
     ax.set_xlabel('ds')
     ax.set_xlabel('ds')
     ax.set_ylabel(name)
     ax.set_ylabel(name)
+    if name in m.component_modes['multiplicative']:
+        ax = set_y_as_percent(ax)
     return artists
     return artists
 
 
 
 
@@ -246,7 +248,9 @@ def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0):
     ax.set_xticks(range(len(days)))
     ax.set_xticks(range(len(days)))
     ax.set_xticklabels(days)
     ax.set_xticklabels(days)
     ax.set_xlabel('Day of week')
     ax.set_xlabel('Day of week')
-    ax.set_ylabel('weekly ({})'.format(m.seasonalities['weekly']['mode']))
+    ax.set_ylabel('weekly')
+    if m.seasonalities['weekly']['mode'] == 'multiplicative':
+        ax = set_y_as_percent(ax)
     return artists
     return artists
 
 
 
 
@@ -288,7 +292,9 @@ def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0):
         lambda x, pos=None: '{dt:%B} {dt.day}'.format(dt=num2date(x))))
         lambda x, pos=None: '{dt:%B} {dt.day}'.format(dt=num2date(x))))
     ax.xaxis.set_major_locator(months)
     ax.xaxis.set_major_locator(months)
     ax.set_xlabel('Day of year')
     ax.set_xlabel('Day of year')
-    ax.set_ylabel('yearly ({})'.format(m.seasonalities['yearly']['mode']))
+    ax.set_ylabel('yearly')
+    if m.seasonalities['yearly']['mode'] == 'multiplicative':
+        ax = set_y_as_percent(ax)
     return artists
     return artists
 
 
 
 
@@ -338,10 +344,19 @@ def plot_seasonality(m, name, ax=None, uncertainty=True):
     ax.xaxis.set_major_formatter(FuncFormatter(
     ax.xaxis.set_major_formatter(FuncFormatter(
         lambda x, pos=None: fmt_str.format(dt=num2date(x))))
         lambda x, pos=None: fmt_str.format(dt=num2date(x))))
     ax.set_xlabel('ds')
     ax.set_xlabel('ds')
-    ax.set_ylabel('{} ({})'.format(name, m.seasonalities[name]['mode']))
+    ax.set_ylabel('{}'.format(name))
+    if m.seasonalities[name]['mode'] == 'multiplicative':
+        ax = set_y_as_percent(ax)
     return artists
     return artists
 
 
 
 
+def set_y_as_percent(ax):
+    yticks = 100 * ax.get_yticks()
+    yticklabels = ['{0:.4g}%'.format(y) for y in yticks]
+    ax.set_yticklabels(yticklabels)
+    return ax
+
+
 def add_changepoints_to_plot(
 def add_changepoints_to_plot(
     ax, m, fcst, threshold=0.01, cp_color='r', cp_linestyle='--', trend=True,
     ax, m, fcst, threshold=0.01, cp_color='r', cp_linestyle='--', trend=True,
 ):
 ):

+ 3 - 1
python/fbprophet/tests/test_prophet.py

@@ -678,5 +678,7 @@ class TestProphet(TestCase):
         self.assertEqual(
         self.assertEqual(
             set(modes['multiplicative']),
             set(modes['multiplicative']),
             {'weekly', 'yearly', 'xmas', 'numeric_feature',
             {'weekly', 'yearly', 'xmas', 'numeric_feature',
-             'multiplicative_terms', 'extra_regressors_multiplicative'},
+             'multiplicative_terms', 'extra_regressors_multiplicative',
+             'holidays',
+            },
         )
         )