utils.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # Most of the code taken from https://github.com/EleutherAI/lm-evaluation-harness/blob/cddce0a148ec1710e2d60546c6f92727dd8a78fd/lm_eval/tasks/leaderboard/math/utils.py
  2. import re
  3. import signal
  4. from typing import Dict, List, Optional
  5. import datasets
  6. from lm_eval.utils import eval_logger
  7. try:
  8. import sympy
  9. from sympy.parsing.latex import parse_latex
  10. except ModuleNotFoundError:
  11. raise ModuleNotFoundError(
  12. "`sympy` is required for generating translation task prompt templates. \
  13. please install sympy via pip install lm-eval[math] or pip install -e .[math]",
  14. )
  15. # taken from
  16. # https://github.com/wellecks/lm-evaluation-harness/blob/master/lm_eval/tasks/minerva_math.py
  17. def doc_to_text(doc: dict) -> str:
  18. return doc["input_final_prompts"][0]
  19. def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
  20. def _process_doc(doc: dict) -> dict:
  21. out_doc = {
  22. "problem": doc["input_question"],
  23. "answer": normalize_final_answer(
  24. remove_boxed(last_boxed_only_string(doc["solution"]))
  25. ),
  26. "meta_target": doc["input_correct_responses"]
  27. }
  28. return out_doc
  29. #dataset = dataset.select_columns(["input_question", "input_correct_responses", "input_final_prompts", "is_correct","output_prediction_text"])
  30. dataset = dataset.rename_column("is_correct","previously_is_correct")
  31. return dataset.map(_process_doc)
  32. def extract_result_from_boxed(answer: str) -> str:
  33. box_start = "\\boxed"
  34. # format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
  35. start = answer.rfind(box_start)
  36. if start < 0:
  37. return ""
  38. answer = answer[start + len(box_start) :].strip()
  39. ends_with_curly = answer.startswith("{")
  40. i = 0
  41. open_braces = 0
  42. while i < len(answer):
  43. if answer[i] == "{":
  44. open_braces += 1
  45. elif answer[i] == "}":
  46. open_braces -= 1
  47. if open_braces == 0:
  48. if ends_with_curly:
  49. answer = answer[: i + 1].strip()
  50. break
  51. elif answer[i] == "$":
  52. answer = answer[:i].strip()
  53. break
  54. i += 1
  55. else:
  56. return ""
  57. # remove extra curly braces
  58. while True:
  59. if answer.startswith("{") and answer.endswith("}"):
  60. answer = answer[1:-1].strip()
  61. else:
  62. break
  63. return answer
  64. def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
  65. candidates = results[0]
  66. unnormalized_answer = get_unnormalized_answer(candidates)
  67. if unnormalized_answer == "[invalidanswer]":
  68. unnormalized_answer = extract_result_from_boxed(candidates)
  69. answer = normalize_final_answer(unnormalized_answer)
  70. if answer.strip() == doc["answer"].strip() or is_equiv(answer, doc["answer"]):
  71. retval = 1
  72. else:
  73. retval = 0
  74. results = {
  75. "exact_match": retval,
  76. }
  77. return results
  78. def last_boxed_only_string(string: str) -> Optional[str]:
  79. idx = string.rfind("\\boxed")
  80. if "\\boxed " in string:
  81. return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
  82. if idx < 0:
  83. idx = string.rfind("\\fbox")
  84. if idx < 0:
  85. return None
  86. i = idx
  87. right_brace_idx = None
  88. num_left_braces_open = 0
  89. while i < len(string):
  90. if string[i] == "{":
  91. num_left_braces_open += 1
  92. if string[i] == "}":
  93. num_left_braces_open -= 1
  94. if num_left_braces_open == 0:
  95. right_brace_idx = i
  96. break
  97. i += 1
  98. if right_brace_idx is None:
  99. retval = None
  100. else:
  101. retval = string[idx : right_brace_idx + 1]
  102. return retval
  103. def remove_boxed(s: str) -> str:
  104. if "\\boxed " in s:
  105. left = "\\boxed "
  106. assert s[: len(left)] == left
  107. return s[len(left) :]
  108. left = "\\boxed{"
  109. assert s[: len(left)] == left
  110. assert s[-1] == "}"
  111. return s[len(left) : -1]
  112. class timeout:
  113. def __init__(self, seconds=1, error_message="Timeout"):
  114. self.seconds = seconds
  115. self.error_message = error_message
  116. def handle_timeout(self, signum, frame):
  117. raise TimeoutError(self.error_message)
  118. def __enter__(self):
  119. signal.signal(signal.SIGALRM, self.handle_timeout)
  120. signal.alarm(self.seconds)
  121. def __exit__(self, type, value, traceback):
  122. signal.alarm(0)
  123. def is_equiv(x1: str, x2: str) -> bool:
  124. """
  125. x1 and x2 are normalized latex string
  126. """
  127. try:
  128. with timeout(seconds=5):
  129. try:
  130. parsed_x1 = parse_latex(x1)
  131. parsed_x2 = parse_latex(x2)
  132. except (
  133. sympy.parsing.latex.errors.LaTeXParsingError,
  134. sympy.SympifyError,
  135. TypeError,
  136. ):
  137. eval_logger.debug(f"couldn't parse one of {x1} or {x2}")
  138. return False
  139. try:
  140. diff = parsed_x1 - parsed_x2
  141. except TypeError:
  142. eval_logger.debug(f"couldn't subtract {x1} and {x2}")
  143. return False
  144. try:
  145. if sympy.simplify(diff) == 0:
  146. return True
  147. else:
  148. return False
  149. except ValueError:
  150. eval_logger.debug(
  151. f"Had some trouble simplifying when comparing {x1} and {x2}"
  152. )
  153. except TimeoutError:
  154. eval_logger.debug(f"Timed out comparing {x1} and {x2}")
  155. return False
  156. except ImportError as e:
  157. eval_logger.error(e)
  158. raise
  159. except Exception as e:
  160. eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}")
  161. return False
  162. def get_unnormalized_answer(text: str) -> str:
  163. INVALID_ANSWER = "[invalidanswer]"
  164. end_seq = "I hope it is correct."
  165. text += end_seq
  166. match = re.search(
  167. r"Final Answer: The final answer is(.*?). I hope it is correct.",
  168. text,
  169. )
  170. if match:
  171. return match.group(1).strip()
  172. else:
  173. return INVALID_ANSWER
  174. SUBSTITUTIONS = [
  175. ("an ", ""),
  176. ("a ", ""),
  177. (".$", "$"),
  178. ("\\$", ""),
  179. (r"\ ", ""),
  180. (" ", ""),
  181. ("mbox", "text"),
  182. (",\\text{and}", ","),
  183. ("\\text{and}", ","),
  184. ("\\text{m}", "\\text{}"),
  185. ]
  186. REMOVED_EXPRESSIONS = [
  187. "square",
  188. "ways",
  189. "integers",
  190. "dollars",
  191. "mph",
  192. "inches",
  193. "ft",
  194. "hours",
  195. "km",
  196. "units",
  197. "\\ldots",
  198. "sue",
  199. "points",
  200. "feet",
  201. "minutes",
  202. "digits",
  203. "cents",
  204. "degrees",
  205. "cm",
  206. "gm",
  207. "pounds",
  208. "meters",
  209. "meals",
  210. "edges",
  211. "students",
  212. "childrentickets",
  213. "multiples",
  214. "\\text{s}",
  215. "\\text{.}",
  216. "\\text{\ns}",
  217. "\\text{}^2",
  218. "\\text{}^3",
  219. "\\text{\n}",
  220. "\\text{}",
  221. r"\mathrm{th}",
  222. r"^\circ",
  223. r"^{\circ}",
  224. r"\;",
  225. r",\!",
  226. "{,}",
  227. '"',
  228. "\\dots",
  229. ]
  230. def normalize_final_answer(final_answer: str) -> str:
  231. """
  232. Normalize a final answer to a quantitative reasoning question.
  233. Copied character for character from appendix D of Lewkowycz et al. (2022)
  234. """
  235. final_answer = final_answer.split("=")[-1]
  236. for before, after in SUBSTITUTIONS:
  237. final_answer = final_answer.replace(before, after)
  238. for expr in REMOVED_EXPRESSIONS:
  239. final_answer = final_answer.replace(expr, "")
  240. # Extract answer that is in LaTeX math, is bold,
  241. # is surrounded by a box, etc.
  242. final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
  243. final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
  244. final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
  245. final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
  246. final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
  247. # Normalize shorthand TeX:
  248. # \fracab -> \frac{a}{b}
  249. # \frac{abc}{bef} -> \frac{abc}{bef}
  250. # \fracabc -> \frac{a}{b}c
  251. # \sqrta -> \sqrt{a}
  252. # \sqrtab -> sqrt{a}b
  253. final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
  254. final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
  255. final_answer = final_answer.replace("$", "")
  256. # Normalize 100,000 -> 100000
  257. if final_answer.replace(",", "").isdigit():
  258. final_answer = final_answer.replace(",", "")
  259. return final_answer