utils.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # Most of the code taken from https://github.com/EleutherAI/lm-evaluation-harness/blob/cddce0a148ec1710e2d60546c6f92727dd8a78fd/lm_eval/tasks/leaderboard/math/utils.py
  2. import re
  3. import signal
  4. from typing import Dict, List, Optional
  5. import datasets
  6. from lm_eval.utils import eval_logger
  7. try:
  8. import sympy
  9. from sympy.parsing.latex import parse_latex
  10. except ModuleNotFoundError:
  11. raise ModuleNotFoundError(
  12. "`sympy` is required for generating translation task prompt templates. \
  13. please install sympy via pip install lm-eval[math] or pip install -e .[math]",
  14. )
  15. # taken from
  16. # https://github.com/wellecks/lm-evaluation-harness/blob/master/lm_eval/tasks/minerva_math.py
  17. def doc_to_text(doc: dict) -> str:
  18. return doc["input_final_prompts"][0]
  19. def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
  20. def _process_doc(doc: dict) -> dict:
  21. out_doc = {
  22. "problem": doc["input_question"],
  23. "answer": normalize_final_answer(
  24. remove_boxed(last_boxed_only_string(doc["solution"]))
  25. ),
  26. "meta_target": doc["input_correct_responses"]
  27. }
  28. return out_doc
  29. return dataset.map(_process_doc)
  30. def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
  31. candidates = results[0]
  32. last_boxed_string = last_boxed_only_string(candidates)
  33. if not last_boxed_string:
  34. # No boxed string found, so we can't evaluate
  35. return {"exact_match": 0}
  36. unnormalized_answer = remove_boxed(last_boxed_string)
  37. answer = normalize_final_answer(unnormalized_answer)
  38. if answer.strip() == doc["answer"].strip() or is_equiv(answer, doc["answer"]):
  39. retval = 1
  40. else:
  41. retval = 0
  42. results = {
  43. "exact_match": retval,
  44. }
  45. return results
  46. def last_boxed_only_string(string: str) -> Optional[str]:
  47. idx = string.rfind("\\boxed")
  48. if "\\boxed " in string:
  49. return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
  50. if idx < 0:
  51. idx = string.rfind("\\fbox")
  52. if idx < 0:
  53. return None
  54. i = idx
  55. right_brace_idx = None
  56. num_left_braces_open = 0
  57. while i < len(string):
  58. if string[i] == "{":
  59. num_left_braces_open += 1
  60. if string[i] == "}":
  61. num_left_braces_open -= 1
  62. if num_left_braces_open == 0:
  63. right_brace_idx = i
  64. break
  65. i += 1
  66. if right_brace_idx is None:
  67. retval = None
  68. else:
  69. retval = string[idx : right_brace_idx + 1]
  70. return retval
  71. def remove_boxed(s: str) -> str:
  72. if "\\boxed " in s:
  73. left = "\\boxed "
  74. assert s[: len(left)] == left
  75. return s[len(left) :]
  76. left = "\\boxed{"
  77. assert s[: len(left)] == left
  78. assert s[-1] == "}"
  79. return s[len(left) : -1]
  80. class timeout:
  81. def __init__(self, seconds=1, error_message="Timeout"):
  82. self.seconds = seconds
  83. self.error_message = error_message
  84. def handle_timeout(self, signum, frame):
  85. raise TimeoutError(self.error_message)
  86. def __enter__(self):
  87. signal.signal(signal.SIGALRM, self.handle_timeout)
  88. signal.alarm(self.seconds)
  89. def __exit__(self, type, value, traceback):
  90. signal.alarm(0)
  91. def is_equiv(x1: str, x2: str) -> bool:
  92. """
  93. x1 and x2 are normalized latex string
  94. """
  95. try:
  96. with timeout(seconds=5):
  97. try:
  98. parsed_x1 = parse_latex(x1)
  99. parsed_x2 = parse_latex(x2)
  100. except (
  101. sympy.parsing.latex.errors.LaTeXParsingError,
  102. sympy.SympifyError,
  103. TypeError,
  104. ):
  105. eval_logger.debug(f"couldn't parse one of {x1} or {x2}")
  106. return False
  107. try:
  108. diff = parsed_x1 - parsed_x2
  109. except TypeError:
  110. eval_logger.debug(f"couldn't subtract {x1} and {x2}")
  111. return False
  112. try:
  113. if sympy.simplify(diff) == 0:
  114. return True
  115. else:
  116. return False
  117. except ValueError:
  118. eval_logger.debug(
  119. f"Had some trouble simplifying when comparing {x1} and {x2}"
  120. )
  121. except TimeoutError:
  122. eval_logger.debug(f"Timed out comparing {x1} and {x2}")
  123. return False
  124. except ImportError as e:
  125. eval_logger.error(e)
  126. raise
  127. except Exception as e:
  128. eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}")
  129. return False
  130. def get_unnormalized_answer(text: str) -> str:
  131. INVALID_ANSWER = "[invalidanswer]"
  132. end_seq = "I hope it is correct."
  133. text += end_seq
  134. match = re.search(
  135. r"Final Answer: The final answer is(.*?). I hope it is correct.",
  136. text,
  137. )
  138. if match:
  139. return match.group(1).strip()
  140. else:
  141. return INVALID_ANSWER
# Ordered (before, after) pairs that normalize_final_answer applies one after
# another with str.replace. Order matters: later entries see the output of the
# earlier ones (e.g. the "\\text{and}" rules only match once all spaces have
# been removed by the (" ", "") rule).
# NOTE(review): these are plain substring replacements, so "an " / "a " are
# also removed when they are the tail of a longer word followed by a space.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),  # drop a period directly before a "$"
    ("\\$", ""),  # drop escaped dollar signs (backslash + "$")
    (r"\ ", ""),  # drop backslash-space
    (" ", ""),  # drop every remaining space
    ("mbox", "text"),  # turns "\mbox" into "\text"
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
# Substrings that normalize_final_answer deletes (str.replace(expr, "")) after
# the SUBSTITUTIONS pass, in the order listed here. Mostly unit words and
# LaTeX spacing/decoration tokens.
# Order matters where entries overlap: "\\text{}^2" and "\\text{}^3" must be
# removed before the bare "\\text{}".
# NOTE(review): removal is by plain substring, so short entries such as "ft",
# "km" or "cm" are also cut out of longer tokens that contain them
# (e.g. "ft" inside "\left") — confirm this is acceptable.
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]
  198. def normalize_final_answer(final_answer: str) -> str:
  199. """
  200. Normalize a final answer to a quantitative reasoning question.
  201. Copied character for character from appendix D of Lewkowycz et al. (2022)
  202. """
  203. final_answer = final_answer.split("=")[-1]
  204. for before, after in SUBSTITUTIONS:
  205. final_answer = final_answer.replace(before, after)
  206. for expr in REMOVED_EXPRESSIONS:
  207. final_answer = final_answer.replace(expr, "")
  208. # Extract answer that is in LaTeX math, is bold,
  209. # is surrounded by a box, etc.
  210. final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
  211. final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
  212. final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
  213. final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
  214. final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
  215. # Normalize shorthand TeX:
  216. # \fracab -> \frac{a}{b}
  217. # \frac{abc}{bef} -> \frac{abc}{bef}
  218. # \fracabc -> \frac{a}{b}c
  219. # \sqrta -> \sqrt{a}
  220. # \sqrtab -> sqrt{a}b
  221. final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
  222. final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
  223. final_answer = final_answer.replace("$", "")
  224. # Normalize 100,000 -> 100000
  225. if final_answer.replace(",", "").isdigit():
  226. final_answer = final_answer.replace(",", "")
  227. return final_answer