make_holidays.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree. An additional grant
  6. # of patent rights can be found in the PATENTS file in the same directory.
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. import pandas as pd
  12. import numpy as np
  13. import warnings
  14. import holidays as hdays_part1
  15. import fbprophet.hdays as hdays_part2
  16. def get_holiday_names(country):
  17. """Return all possible holiday names of given country
  18. Parameters
  19. ----------
  20. country: country name
  21. Returns
  22. -------
  23. A set of all possible holiday names of given country
  24. """
  25. years = np.arange(1995, 2045)
  26. try:
  27. with warnings.catch_warnings():
  28. warnings.simplefilter("ignore")
  29. holiday_names = getattr(hdays_part2, country)(years=years).values()
  30. except AttributeError:
  31. try:
  32. holiday_names = getattr(hdays_part1, country)(years=years).values()
  33. except AttributeError:
  34. raise AttributeError(
  35. "Holidays in {} are not currently supported!".format(country))
  36. return set(holiday_names)
  37. def make_holidays_df(year_list, country):
  38. """Make dataframe of holidays for given years and countries
  39. Parameters
  40. ----------
  41. year_list: a list of years
  42. country: country name
  43. Returns
  44. -------
  45. Dataframe with 'ds' and 'holiday', which can directly feed
  46. to 'holidays' params in Prophet
  47. """
  48. try:
  49. holidays = getattr(hdays_part2, country)(years=year_list)
  50. except AttributeError:
  51. try:
  52. holidays = getattr(hdays_part1, country)(years=year_list)
  53. except AttributeError:
  54. raise AttributeError(
  55. "Holidays in {} are not currently supported!".format(country))
  56. holidays_df = pd.DataFrame(list(holidays.items()), columns=['ds', 'holiday'])
  57. holidays_df.reset_index(inplace=True, drop=True)
  58. holidays_df['ds'] = pd.to_datetime(holidays_df['ds'])
  59. return (holidays_df)