Procházet zdrojové kódy

Add RMSE as cross validation metric

Ben Letham před 7 roky
rodič
revize
7179ae3a38

+ 9 - 2
python/fbprophet/diagnostics.py

@@ -202,6 +202,7 @@ def performance_metrics(df, metrics=None, rolling_window=0.05):
     Computes a suite of performance metrics on the output of cross-validation.
     By default the following metrics are included:
     'mse': mean squared error
+    'rmse': root mean squared error
     'mae': mean absolute error
     'mape': mean percent error
     'coverage': coverage of the upper and lower intervals
@@ -226,7 +227,7 @@ def performance_metrics(df, metrics=None, rolling_window=0.05):
     ----------
     df: The dataframe returned by cross_validation.
     metrics: A list of performance metrics to compute. If not provided, will
-        use ['mse', 'mae', 'mape', 'coverage'].
+        use ['mse', 'mae', 'mape', 'coverage', 'rmse'].
     rolling_window: Proportion of data to use in each rolling window for
         computing the metrics.
 
@@ -234,7 +235,7 @@ def performance_metrics(df, metrics=None, rolling_window=0.05):
     -------
     Dataframe with a column for each metric, and column 'horizon'
     """
-    valid_metrics = ['mse', 'mae', 'mape', 'coverage']
+    valid_metrics = ['mse', 'rmse', 'mae', 'mape', 'coverage']
     if metrics is None:
         metrics = valid_metrics
     if len(set(metrics)) != len(metrics):
@@ -277,6 +278,12 @@ def mse(df, w):
     return rolling_mean(se.values, w)
 
 
+def rmse(df, w):
+    """Root mean squared error
+    """
+    return np.sqrt(mse(df, w))
+
+
 def mae(df, w):
     """Mean absolute error
     """

+ 1 - 1
python/fbprophet/tests/test_diagnostics.py

@@ -145,7 +145,7 @@ class TestDiagnostics(TestCase):
         df_none = diagnostics.performance_metrics(df_cv, rolling_window=0)
         self.assertEqual(
             set(df_none.columns),
-            {'horizon', 'coverage', 'mae', 'mape', 'mse'},
+            {'horizon', 'coverage', 'mae', 'mape', 'mse', 'rmse'},
         )
         self.assertEqual(df_none.shape[0], 14)
         # Aggregation level 0.2