Просмотр исходного кода

Implement cross-validation of time series (a rolling forecast origin) (#261)

* Resolve conflict

* Change comments and add error column to output DataFrame

* Change file structure

* Update

* Modified diagnostics

* Update diagnostics.py following the advice on Github

* Add tests and documentation

* Change copy method into Prophet class and reflect comments
Nagi Teramo 8 лет назад
Родитель
Commit
79d0793ce4

+ 131 - 0
python/fbprophet/diagnostics.py

@@ -0,0 +1,131 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree. An additional grant
+# of patent rights can be found in the PATENTS file in the same directory.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import logging
+
+# Module-level logger, named after this module; _cutoffs uses it to emit a warning.
+logger = logging.getLogger(__name__)
+
+import numpy as np
+import pandas as pd
+from functools import reduce
+
+
+def _cutoffs(df, horizon, k, period):
+    """Generate cutoff dates
+
+    Cutoffs are generated backwards in time: the first one is
+    'latest date in data - horizon' and each following candidate lies
+    `period` earlier. When the forecast window (cutoff, cutoff + horizon] of a
+    candidate holds no data, the candidate is replaced by
+    'closest date at or before the candidate - horizon'. Generation stops
+    early, with a warning, as soon as a candidate falls before the first date
+    in the data or no earlier data exists to derive it from.
+
+    Parameters
+    ----------
+    df: pd.DataFrame with historical data. Must have a 'ds' column.
+    horizon: pd.Timedelta.
+        Forecast horizon
+    k: Int number.
+        The number of forecast points requested.
+    period: pd.Timedelta.
+        Spacing between two consecutive candidate cutoffs.
+
+    Returns
+    -------
+    list of pd.Timestamp in ascending order. It always holds at least one
+    cutoff and never more than max(k, 1).
+    """
+    dates = df['ds']
+    first_date = dates.min()
+    # Last cutoff is 'latest date in data - horizon' date
+    cutoff = dates.max() - horizon
+    result = [cutoff]
+
+    for _ in range(1, k):
+        cutoff -= period
+        # If data does not exist in data range (cutoff, cutoff + horizon]
+        if not ((dates > cutoff) & (dates <= cutoff + horizon)).any():
+            # Next cutoff point is 'closest date before cutoff in data - horizon'.
+            # max() of an empty selection is NaT; that case is caught below.
+            closest_date = dates[dates <= cutoff].max()
+            cutoff = closest_date - horizon
+        # NaT compares False against any date, so test for it explicitly
+        # instead of letting it slip through the '<' check into the result.
+        if pd.isnull(cutoff) or cutoff < first_date:
+            # Report how many cutoffs are actually used, not the requested k
+            logger.warning(
+                'Not enough data for requested number of cutoffs! Using {}.'
+                .format(len(result)))
+            break
+        result.append(cutoff)
+
+    # Cutoffs were collected newest-first; hand them back in ascending order
+    return result[::-1]
+
+
+def simulated_historical_forecasts(model, horizon, k, period=None):
+    """Simulated Historical Forecasts.
+        If you would like to know it in detail, read the original paper
+        https://facebookincubator.github.io/prophet/static/prophet_paper_20170113.pdf
+
+    For every cutoff date produced by _cutoffs, a copy of the model is fitted
+    on the history up to and including the cutoff and then used to predict
+    the dates in (cutoff, cutoff + horizon] that are present in the history.
+
+    Parameters
+    ----------
+    model: Prophet class object.
+        Fitted Prophet model
+    horizon: string which has pd.Timedelta compatible style.
+        Forecast horizon ('5 days', '3 hours', '10 seconds' etc)
+    k: Int number.
+        The number of forecasts point.
+    period: string which has pd.Timedelta compatible style or None, default None.
+        Simulated Forecast will be done at every this period.
+        0.5 * horizon is used when it is None.
+
+    Returns
+    -------
+    A pd.DataFrame with the forecast, actual value and cutoff
+    (columns 'ds', 'yhat', 'yhat_lower', 'yhat_upper', 'y' and 'cutoff').
+    """
+    df = model.history.copy().reset_index(drop=True)
+    horizon = pd.Timedelta(horizon)
+    period = 0.5 * horizon if period is None else pd.Timedelta(period)
+    predicts = []
+    for cutoff in _cutoffs(df, horizon, k, period):
+        # Generate new object with copying fitting options
+        m = model.copy(cutoff)
+        # Train model on the data available at the cutoff
+        m.fit(df[df['ds'] <= cutoff])
+        # Calculate yhat for the dates inside the forecast window
+        index_predicted = (df['ds'] > cutoff) & (df['ds'] <= cutoff + horizon)
+        columns = ['ds'] + (['cap'] if m.growth == 'logistic' else [])
+        yhat = m.predict(df[index_predicted][columns])
+        # Merge yhat(predicts), y(df, original data) and cutoff side by side;
+        # the index of y is reset so the rows are aligned by position.
+        predicts.append(pd.concat([
+            yhat[['ds', 'yhat', 'yhat_lower', 'yhat_upper']],
+            df[index_predicted][['y']].reset_index(drop=True),
+            pd.DataFrame({'cutoff': [cutoff] * len(yhat)})
+        ], axis=1))
+
+    # Combine all predicted pd.DataFrame into one pd.DataFrame in a single
+    # concatenation (no repeated copying, no use of DataFrame.append).
+    return pd.concat(predicts, ignore_index=True)
+
+
+def cross_validation(model, horizon, period, initial=None):
+    """Cross-Validation for time-series.
+        When period equals the time interval of the data, this matches the
+        time series cross-validation described in
+        https://robjhyndman.com/hyndsight/tscv/
+
+    The number of cutoffs handed to simulated_historical_forecasts is
+    floor(((latest date - horizon) - (earliest date + initial)) / period).
+
+    Parameters
+    ----------
+    model: Prophet class object. Fitted Prophet model
+    horizon: string which has pd.Timedelta compatible style.
+        Forecast horizon ('5 days', '3 hours', '10 seconds' etc)
+    period: string which has pd.Timedelta compatible style.
+        Spacing between two consecutive simulated forecasts.
+    initial: string which has pd.Timedelta compatible style or None, default None.
+        Length of the first training period.
+        3 * horizon is used when it is None.
+
+    Returns
+    -------
+    A pd.DataFrame with the forecast, actual value and cutoff.
+    """
+    history_end = model.history['ds'].max()
+    history_start = model.history['ds'].min()
+    horizon = pd.Timedelta(horizon)
+    period = pd.Timedelta(period)
+    if initial is None:
+        initial = 3 * horizon
+    else:
+        initial = pd.Timedelta(initial)
+    # Span between the end of the first training period and the last cutoff
+    # that still leaves a full horizon of data after it
+    latest_cutoff = history_end - horizon
+    earliest_cutoff = history_start + initial
+    n_periods = np.floor((latest_cutoff - earliest_cutoff) / period)
+    k = int(n_periods)
+    return simulated_historical_forecasts(model, horizon, k, period)

+ 36 - 1
python/fbprophet/forecaster.py

@@ -36,7 +36,6 @@ except ImportError:
 # fb-block 2
 
 
-
 class Prophet(object):
     """Prophet forecaster.
 
@@ -1395,3 +1394,39 @@ class Prophet(object):
         ax.set_xlabel('ds')
         ax.set_ylabel(name)
         return artists
+
+    def copy(self, cutoff=None):
+        """Copy Prophet object
+
+        Builds a new Prophet object from the constructor settings stored on
+        this one. The object being copied is not modified.
+
+        Parameters
+        ----------
+        cutoff: pd.Timestamp or None, default None.
+            Cutoff Timestamp for the changepoints member variable.
+            Only changepoints with 'changepoints <= cutoff' are kept in the
+            copy. Ignored when no changepoints are set.
+
+        Returns
+        -------
+        Prophet class object with the same parameters as this object
+        """
+        changepoints = self.changepoints
+        n_changepoints = self.n_changepoints
+        if changepoints is not None and cutoff is not None:
+            # Filter change points '<= cutoff'. Work on local names so that
+            # repeated calls with different cutoffs all start from the full
+            # set of changepoints held by this object.
+            changepoints = changepoints[changepoints <= cutoff]
+            n_changepoints = len(changepoints)
+
+        return Prophet(
+            growth=self.growth,
+            n_changepoints=n_changepoints,
+            changepoints=changepoints,
+            yearly_seasonality=self.yearly_seasonality,
+            weekly_seasonality=self.weekly_seasonality,
+            daily_seasonality=self.daily_seasonality,
+            holidays=self.holidays,
+            seasonality_prior_scale=self.seasonality_prior_scale,
+            changepoint_prior_scale=self.changepoint_prior_scale,
+            holidays_prior_scale=self.holidays_prior_scale,
+            mcmc_samples=self.mcmc_samples,
+            interval_width=self.interval_width,
+            uncertainty_samples=self.uncertainty_samples
+        )
+
+

+ 88 - 0
python/fbprophet/tests/test_diagnostics.py

@@ -0,0 +1,88 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree. An additional grant
+# of patent rights can be found in the PATENTS file in the same directory.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import os
+import numpy as np
+import pandas as pd
+
+from unittest import TestCase
+from fbprophet import Prophet
+from fbprophet import diagnostics
+
+
+class TestDiagnostics(TestCase):
+    """Tests for the simulated historical forecast / cross-validation helpers."""
+
+    def setUp(self):
+        # Use first 100 record in data.csv; loaded per test so that every
+        # test starts from its own DataFrame.
+        self.__df = pd.read_csv(
+            os.path.join(os.path.dirname(__file__), 'data.csv'),
+            parse_dates=['ds'],
+        ).head(100)
+
+    def test_simulated_historical_forecasts(self):
+        m = Prophet()
+        m.fit(self.__df)
+        k = 3
+        for p in [1, 10]:
+            for h in [1, 3]:
+                period = '{} days'.format(p)
+                horizon = '{} days'.format(h)
+                df_shf = diagnostics.simulated_historical_forecasts(m, horizon=horizon, k=k, period=period)
+                # All cutoff dates should be less than ds dates
+                self.assertTrue((df_shf['cutoff'] < df_shf['ds']).all())
+                # The unique size of output cutoff should be equal to 'k'
+                self.assertEqual(len(np.unique(df_shf['cutoff'])), k)
+                # Each y in df_shf and self.__df with same ds should be equal
+                df_merged = pd.merge(df_shf, self.__df, 'left', on='ds')
+                self.assertAlmostEqual(np.sum((df_merged['y_x'] - df_merged['y_y']) ** 2), 0.0)
+
+    def test_simulated_historical_forecasts_logistic(self):
+        m = Prophet(growth='logistic')
+        df = self.__df.copy()
+        df['cap'] = 40
+        m.fit(df)
+        df_shf = diagnostics.simulated_historical_forecasts(m, horizon='3 days', k=2, period='3 days')
+        # All cutoff dates should be less than ds dates
+        self.assertTrue((df_shf['cutoff'] < df_shf['ds']).all())
+        # The unique size of output cutoff should be equal to 'k'
+        self.assertEqual(len(np.unique(df_shf['cutoff'])), 2)
+        # Each y in df_shf and df with same ds should be equal
+        df_merged = pd.merge(df_shf, df, 'left', on='ds')
+        self.assertAlmostEqual(np.sum((df_merged['y_x'] - df_merged['y_y']) ** 2), 0.0)
+
+    def test_simulated_historical_forecasts_default_value_check(self):
+        m = Prophet()
+        m.fit(self.__df)
+        # Default value of period should be equal to 0.5 * horizon
+        df_shf1 = diagnostics.simulated_historical_forecasts(m, horizon='10 days', k=1)
+        df_shf2 = diagnostics.simulated_historical_forecasts(m, horizon='10 days', k=1, period='5 days')
+        # Compare only the numeric columns; the frames also hold datetime
+        # columns ('ds', 'cutoff') that must stay out of the arithmetic.
+        numeric = ['y', 'yhat']
+        self.assertAlmostEqual(((df_shf1[numeric] - df_shf2[numeric]) ** 2).sum().sum(), 0.0)
+
+    def test_cross_validation(self):
+        m = Prophet()
+        m.fit(self.__df)
+        # Calculate the number of cutoff points(k)
+        te = self.__df['ds'].max()
+        ts = self.__df['ds'].min()
+        horizon = pd.Timedelta('4 days')
+        period = pd.Timedelta('1 days')
+        initial = pd.Timedelta('90 days')
+        k = int(np.floor(((te - horizon) - (ts + initial)) / period))
+        df_cv = diagnostics.cross_validation(m, horizon=horizon, period=period, initial=initial)
+        # The unique size of output cutoff should be equal to 'k'
+        self.assertEqual(len(np.unique(df_cv['cutoff'])), k)
+
+    def test_cross_validation_default_value_check(self):
+        m = Prophet()
+        m.fit(self.__df)
+        # Default value of initial should be equal to 3 * horizon
+        df_cv1 = diagnostics.cross_validation(m, horizon='32 days', period='1 days')
+        df_cv2 = diagnostics.cross_validation(m, horizon='32 days', period='1 days', initial='96 days')
+        # Compare only the numeric columns; the frames also hold datetime
+        # columns ('ds', 'cutoff') that must stay out of the arithmetic.
+        numeric = ['y', 'yhat']
+        self.assertAlmostEqual(((df_cv1[numeric] - df_cv2[numeric]) ** 2).sum().sum(), 0.0)

+ 47 - 0
python/fbprophet/tests/test_prophet.py

@@ -15,6 +15,7 @@ import pandas as pd
 
 # fb-block 1 start
 import os
+import itertools
 from unittest import TestCase
 from fbprophet import Prophet
 
@@ -421,3 +422,49 @@ class TestProphet(TestCase):
             fcst['yhat'][0],
             fcst['trend'][0] + fcst['seasonal'][0],
         )
+
+    def test_copy(self):
+        """copy() must reproduce every constructor setting; a cutoff trims changepoints."""
+        # One candidate list per positional constructor argument; every value
+        # differs from the default of its argument.
+        holidays = pd.DataFrame({'ds': pd.to_datetime(['2016-12-25']), 'holiday': ['x']})
+        flags = [True, False]
+        products = itertools.product(
+            ['linear', 'logistic'],  # growth
+            [None, pd.to_datetime(['2016-12-25'])],  # changepoints
+            [3],  # n_changepoints
+            flags,  # yearly_seasonality
+            flags,  # weekly_seasonality
+            flags,  # daily_seasonality
+            [None, holidays],  # holidays
+            [1.1],  # seasonality_prior_scale
+            [1.1],  # holidays_prior_scale
+            [0.1],  # changepoint_prior_scale
+            [100],  # mcmc_samples
+            [0.9],  # interval_width
+            [200]  # uncertainty_samples
+        )
+        # Attributes compared with assertEqual, split around the holidays
+        # check so the comparison order stays the same.
+        attrs_before_holidays = (
+            'growth',
+            'n_changepoints',
+            'changepoints',
+            'yearly_seasonality',
+            'weekly_seasonality',
+            'daily_seasonality',
+        )
+        attrs_after_holidays = (
+            'seasonality_prior_scale',
+            'changepoint_prior_scale',
+            'holidays_prior_scale',
+            'mcmc_samples',
+            'interval_width',
+            'uncertainty_samples',
+        )
+        # Values should be copied correctly
+        for args in products:
+            source = Prophet(*args)
+            duplicate = source.copy()
+            for attr in attrs_before_holidays:
+                self.assertEqual(getattr(source, attr), getattr(duplicate, attr))
+            if source.holidays is not None:
+                # DataFrames are compared element-wise
+                self.assertTrue((source.holidays == duplicate.holidays).values.all())
+            else:
+                self.assertEqual(source.holidays, duplicate.holidays)
+            for attr in attrs_after_holidays:
+                self.assertEqual(getattr(source, attr), getattr(duplicate, attr))
+
+        # Check for cutoff: only changepoints up to the cutoff survive the copy
+        all_changepoints = pd.date_range('2016-12-15', '2017-01-15')
+        cutoff = pd.Timestamp('2016-12-25')
+        source = Prophet(changepoints=all_changepoints)
+        duplicate = source.copy(cutoff=cutoff)
+        expected = all_changepoints[all_changepoints <= cutoff]
+        self.assertTrue((expected == duplicate.changepoints).all())

+ 2 - 1
python/setup.py

@@ -94,6 +94,7 @@ class TestCommand(test_command):
             sys.modules.update(old_modules)
             working_set.__init__()
 
+
 setup(
     name='fbprophet',
     version='0.1.1',
@@ -117,7 +118,7 @@ setup(
         'develop': DevelopCommand,
         'test': TestCommand,
     },
-    test_suite='fbprophet.tests.test_prophet',
+    test_suite='fbprophet.tests',
     long_description="""
 Implements a procedure for forecasting time series data based on an additive model where non-linear trends are fit with yearly and weekly seasonality, plus holidays.  It works best with daily periodicity data with at least one year of historical data.  Prophet is robust to missing data, shifts in the trend, and large outliers.
 """