utils.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. # Most of the code taken from https://github.com/EleutherAI/lm-evaluation-harness/blob/cddce0a148ec1710e2d60546c6f92727dd8a78fd/lm_eval/tasks/leaderboard/math/utils.py
  2. import re
  3. import signal
  4. from typing import Dict, List, Optional
  5. import datasets
  6. from lm_eval.utils import eval_logger
  7. try:
  8. import sympy
  9. from sympy.parsing.latex import parse_latex
  10. except ModuleNotFoundError:
  11. raise ModuleNotFoundError(
  12. "`sympy` is required for generating translation task prompt templates. \
  13. please install sympy via pip install lm-eval[math] or pip install -e .[math]",
  14. )
  15. # taken from
  16. # https://github.com/wellecks/lm-evaluation-harness/blob/master/lm_eval/tasks/minerva_math.py
  17. def doc_to_text(doc: dict) -> str:
  18. return doc["input_final_prompts"][0]
  19. def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
  20. def _process_doc(doc: dict) -> dict:
  21. out_doc = {
  22. "problem": doc["input_question"],
  23. "answer": normalize_final_answer(
  24. remove_boxed(last_boxed_only_string(doc["solution"]))
  25. ),
  26. "meta_target": doc["input_correct_responses"]
  27. }
  28. return out_doc
  29. #dataset = dataset.select_columns(["input_question", "input_correct_responses", "input_final_prompts", "is_correct","output_prediction_text"])
  30. dataset = dataset.rename_column("is_correct","previously_is_correct")
  31. return dataset.map(_process_doc)
  32. def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
  33. candidates = results[0]
  34. unnormalized_answer = remove_boxed(last_boxed_only_string(candidates))
  35. answer = normalize_final_answer(unnormalized_answer)
  36. if answer.strip() == doc["answer"].strip() or is_equiv(answer, doc["answer"]):
  37. retval = 1
  38. else:
  39. retval = 0
  40. results = {
  41. "exact_match": retval,
  42. }
  43. return results
  44. def last_boxed_only_string(string: str) -> Optional[str]:
  45. idx = string.rfind("\\boxed")
  46. if "\\boxed " in string:
  47. return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
  48. if idx < 0:
  49. idx = string.rfind("\\fbox")
  50. if idx < 0:
  51. return None
  52. i = idx
  53. right_brace_idx = None
  54. num_left_braces_open = 0
  55. while i < len(string):
  56. if string[i] == "{":
  57. num_left_braces_open += 1
  58. if string[i] == "}":
  59. num_left_braces_open -= 1
  60. if num_left_braces_open == 0:
  61. right_brace_idx = i
  62. break
  63. i += 1
  64. if right_brace_idx is None:
  65. retval = None
  66. else:
  67. retval = string[idx : right_brace_idx + 1]
  68. return retval
  69. def remove_boxed(s: str) -> str:
  70. if "\\boxed " in s:
  71. left = "\\boxed "
  72. assert s[: len(left)] == left
  73. return s[len(left) :]
  74. left = "\\boxed{"
  75. assert s[: len(left)] == left
  76. assert s[-1] == "}"
  77. return s[len(left) : -1]
  78. class timeout:
  79. def __init__(self, seconds=1, error_message="Timeout"):
  80. self.seconds = seconds
  81. self.error_message = error_message
  82. def handle_timeout(self, signum, frame):
  83. raise TimeoutError(self.error_message)
  84. def __enter__(self):
  85. signal.signal(signal.SIGALRM, self.handle_timeout)
  86. signal.alarm(self.seconds)
  87. def __exit__(self, type, value, traceback):
  88. signal.alarm(0)
  89. def is_equiv(x1: str, x2: str) -> bool:
  90. """
  91. x1 and x2 are normalized latex string
  92. """
  93. try:
  94. with timeout(seconds=5):
  95. try:
  96. parsed_x1 = parse_latex(x1)
  97. parsed_x2 = parse_latex(x2)
  98. except (
  99. sympy.parsing.latex.errors.LaTeXParsingError,
  100. sympy.SympifyError,
  101. TypeError,
  102. ):
  103. eval_logger.debug(f"couldn't parse one of {x1} or {x2}")
  104. return False
  105. try:
  106. diff = parsed_x1 - parsed_x2
  107. except TypeError:
  108. eval_logger.debug(f"couldn't subtract {x1} and {x2}")
  109. return False
  110. try:
  111. if sympy.simplify(diff) == 0:
  112. return True
  113. else:
  114. return False
  115. except ValueError:
  116. eval_logger.debug(
  117. f"Had some trouble simplifying when comparing {x1} and {x2}"
  118. )
  119. except TimeoutError:
  120. eval_logger.debug(f"Timed out comparing {x1} and {x2}")
  121. return False
  122. except ImportError as e:
  123. eval_logger.error(e)
  124. raise
  125. except Exception as e:
  126. eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}")
  127. return False
  128. def get_unnormalized_answer(text: str) -> str:
  129. INVALID_ANSWER = "[invalidanswer]"
  130. end_seq = "I hope it is correct."
  131. text += end_seq
  132. match = re.search(
  133. r"Final Answer: The final answer is(.*?). I hope it is correct.",
  134. text,
  135. )
  136. if match:
  137. return match.group(1).strip()
  138. else:
  139. return INVALID_ANSWER
# Ordered (before, after) pairs that normalize_final_answer applies with
# str.replace, in list order, before deleting REMOVED_EXPRESSIONS.
# NOTE(review): these are plain substring replacements, so "an " / "a " also
# match inside words ending in "an"/"a", and (" ", "") drops all spaces.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    # Order matters: ",\text{and}" must be handled before "\text{and}".
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
# Substrings that normalize_final_answer deletes (replace with ""), in list
# order, after applying SUBSTITUTIONS. Order matters, e.g. "\text{}^2" and
# "\text{}^3" must come before the bare "\text{}".
# NOTE(review): deletion is plain substring matching, so short entries also
# hit unrelated text, e.g. "ft" turns "\left" into "\le". Kept as-is to match
# the reference normalization.
REMOVED_EXPRESSIONS = [
    # Unit / noun words.
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    # LaTeX text fragments; "\text{\ns}" and "\text{\n}" contain a real newline.
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    # Degree signs, spacing commands and thousands-separator markup.
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]
  196. def normalize_final_answer(final_answer: str) -> str:
  197. """
  198. Normalize a final answer to a quantitative reasoning question.
  199. Copied character for character from appendix D of Lewkowycz et al. (2022)
  200. """
  201. final_answer = final_answer.split("=")[-1]
  202. for before, after in SUBSTITUTIONS:
  203. final_answer = final_answer.replace(before, after)
  204. for expr in REMOVED_EXPRESSIONS:
  205. final_answer = final_answer.replace(expr, "")
  206. # Extract answer that is in LaTeX math, is bold,
  207. # is surrounded by a box, etc.
  208. final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
  209. final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
  210. final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
  211. final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
  212. final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
  213. # Normalize shorthand TeX:
  214. # \fracab -> \frac{a}{b}
  215. # \frac{abc}{bef} -> \frac{abc}{bef}
  216. # \fracabc -> \frac{a}{b}c
  217. # \sqrta -> \sqrt{a}
  218. # \sqrtab -> sqrt{a}b
  219. final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
  220. final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
  221. final_answer = final_answer.replace("$", "")
  222. # Normalize 100,000 -> 100000
  223. if final_answer.replace(",", "").isdigit():
  224. final_answer = final_answer.replace(",", "")
  225. return final_answer