diagnostics.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree. An additional grant
  6. # of patent rights can be found in the PATENTS file in the same directory.
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. import logging
  12. logger = logging.getLogger(__name__)
  13. import numpy as np
  14. import pandas as pd
  15. from functools import reduce
  16. def _cutoffs(df, horizon, k, period):
  17. """Generate cutoff dates
  18. Parameters
  19. ----------
  20. df: pd.DataFrame with historical data
  21. horizon: pd.Timedelta.
  22. Forecast horizon
  23. k: Int number.
  24. The number of forecasts point.
  25. period: pd.Timedelta.
  26. Simulated Forecast will be done at every this period.
  27. Returns
  28. -------
  29. list of pd.Timestamp
  30. """
  31. # Last cutoff is 'latest date in data - horizon' date
  32. cutoff = df['ds'].max() - horizon
  33. if cutoff < df['ds'].min():
  34. raise ValueError('Less data than horizon.')
  35. result = [cutoff]
  36. for i in range(1, k):
  37. cutoff -= period
  38. # If data does not exist in data range (cutoff, cutoff + horizon]
  39. if not (((df['ds'] > cutoff) & (df['ds'] <= cutoff + horizon)).any()):
  40. # Next cutoff point is 'last date before cutoff in data - horizon'
  41. closest_date = df[df['ds'] <= cutoff].max()['ds']
  42. cutoff = closest_date - horizon
  43. if cutoff < df['ds'].min():
  44. logger.warning(
  45. 'Not enough data for requested number of cutoffs! '
  46. 'Using {}.'.format(i))
  47. break
  48. result.append(cutoff)
  49. # Sort lines in ascending order
  50. return reversed(result)
  51. def simulated_historical_forecasts(model, horizon, k, period=None):
  52. """Simulated Historical Forecasts.
  53. Make forecasts from k historical cutoff points, working backwards from
  54. (end - horizon) with a spacing of period between each cutoff.
  55. Parameters
  56. ----------
  57. model: Prophet class object.
  58. Fitted Prophet model
  59. horizon: string with pd.Timedelta compatible style, e.g., '5 days',
  60. '3 hours', '10 seconds'.
  61. k: Int number of forecasts point.
  62. period: Optional string with pd.Timedelta compatible style. Simulated
  63. forecast will be done at every this period. If not provided,
  64. 0.5 * horizon is used.
  65. Returns
  66. -------
  67. A pd.DataFrame with the forecast, actual value and cutoff.
  68. """
  69. df = model.history.copy().reset_index(drop=True)
  70. horizon = pd.Timedelta(horizon)
  71. period = 0.5 * horizon if period is None else pd.Timedelta(period)
  72. cutoffs = _cutoffs(df, horizon, k, period)
  73. predicts = []
  74. for cutoff in cutoffs:
  75. # Generate new object with copying fitting options
  76. m = model.copy(cutoff)
  77. # Train model
  78. m.fit(df[df['ds'] <= cutoff])
  79. # Calculate yhat
  80. index_predicted = (df['ds'] > cutoff) & (df['ds'] <= cutoff + horizon)
  81. columns = ['ds']
  82. if m.growth == 'logistic':
  83. columns.append('cap')
  84. if m.logistic_floor:
  85. columns.append('floor')
  86. yhat = m.predict(df[index_predicted][columns])
  87. # Merge yhat(predicts), y(df, original data) and cutoff
  88. predicts.append(pd.concat([
  89. yhat[['ds', 'yhat', 'yhat_lower', 'yhat_upper']],
  90. df[index_predicted][['y']].reset_index(drop=True),
  91. pd.DataFrame({'cutoff': [cutoff] * len(yhat)})
  92. ], axis=1))
  93. # Combine all predicted pd.DataFrame into one pd.DataFrame
  94. return reduce(lambda x, y: x.append(y), predicts).reset_index(drop=True)
  95. def cross_validation(model, horizon, period=None, initial=None):
  96. """Cross-Validation for time series.
  97. Computes forecasts from historical cutoff points. Beginning from initial,
  98. makes cutoffs with a spacing of period up to (end - horizon).
  99. When period is equal to the time interval of the data, this is the
  100. technique described in https://robjhyndman.com/hyndsight/tscv/ .
  101. Parameters
  102. ----------
  103. model: Prophet class object. Fitted Prophet model
  104. horizon: string with pd.Timedelta compatible style, e.g., '5 days',
  105. '3 hours', '10 seconds'.
  106. period: string with pd.Timedelta compatible style. Simulated forecast will
  107. be done at every this period. If not provided, 0.5 * horizon is used.
  108. initial: string with pd.Timedelta compatible style. The first training
  109. period will begin here. If not provided, 3 * horizon is used.
  110. Returns
  111. -------
  112. A pd.DataFrame with the forecast, actual value and cutoff.
  113. """
  114. te = model.history['ds'].max()
  115. ts = model.history['ds'].min()
  116. horizon = pd.Timedelta(horizon)
  117. period = 0.5 * horizon if period is None else pd.Timedelta(period)
  118. initial = 3 * horizon if initial is None else pd.Timedelta(initial)
  119. k = int(np.ceil(((te - horizon) - (ts + initial)) / period))
  120. if k < 1:
  121. raise ValueError(
  122. 'Not enough data for specified horizon, period, and initial.')
  123. return simulated_historical_forecasts(model, horizon, k, period)