|
@@ -0,0 +1,268 @@
|
|
|
+# Most of the code taken from https://github.com/EleutherAI/lm-evaluation-harness/blob/cddce0a148ec1710e2d60546c6f92727dd8a78fd/lm_eval/tasks/leaderboard/math/utils.py
|
|
|
+import re
|
|
|
+import signal
|
|
|
+from typing import Dict, List, Optional
|
|
|
+
|
|
|
+import datasets
|
|
|
+
|
|
|
+from lm_eval.utils import eval_logger
|
|
|
+
|
|
|
+
|
|
|
+try:
|
|
|
+ import sympy
|
|
|
+ from sympy.parsing.latex import parse_latex
|
|
|
+except ModuleNotFoundError:
|
|
|
+ raise ModuleNotFoundError(
|
|
|
+ "`sympy` is required for generating translation task prompt templates. \
|
|
|
+please install sympy via pip install lm-eval[math] or pip install -e .[math]",
|
|
|
+ )
|
|
|
+
|
|
|
+# taken from
|
|
|
+# https://github.com/wellecks/lm-evaluation-harness/blob/master/lm_eval/tasks/minerva_math.py
|
|
|
def doc_to_text(doc: dict) -> str:
    """
    Return the prompt text for a document.

    The document carries its prompts under ``"input_final_prompts"``; the
    first entry of that sequence is used as-is.
    """
    prompts = doc["input_final_prompts"]
    return prompts[0]
|
|
|
+
|
|
|
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """
    Map every document of ``dataset`` to the fields used for scoring.

    Each output row holds:
      * ``"problem"``: the document's ``"input_question"``;
      * ``"answer"``: the last boxed expression of ``"solution"``, unboxed
        and passed through ``normalize_final_answer``;
      * ``"meta_target"``: the document's ``"input_correct_responses"``.
    """

    def _to_scoring_fields(example: dict) -> dict:
        # Fields are read in the same order as the output keys are listed.
        problem = example["input_question"]
        boxed_reference = last_boxed_only_string(example["solution"])
        reference = normalize_final_answer(remove_boxed(boxed_reference))
        meta_target = example["input_correct_responses"]
        return {
            "problem": problem,
            "answer": reference,
            "meta_target": meta_target,
        }

    return dataset.map(_to_scoring_fields)
|
|
|
+
|
|
|
+
|
|
|
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
    """
    Score one model generation against the document's reference answer.

    Only ``results[0]`` is inspected. Its last boxed expression is unboxed
    and normalized, then compared with ``doc["answer"]``: first as stripped
    strings, then (if those differ) through ``is_equiv``.

    Returns ``{"exact_match": 1}`` on a match and ``{"exact_match": 0}``
    otherwise, including when the generation has no boxed expression.
    """
    generation = results[0]
    boxed = last_boxed_only_string(generation)
    if not boxed:
        # Nothing boxed in the generation, so there is nothing to grade.
        return {"exact_match": 0}

    prediction = normalize_final_answer(remove_boxed(boxed))
    target = doc["answer"]

    matched = prediction.strip() == target.strip() or is_equiv(prediction, target)
    return {"exact_match": 1 if matched else 0}
|
|
|
+
|
|
|
+
|
|
|
def last_boxed_only_string(string: str) -> Optional[str]:
    """
    Extract the last boxed expression from ``string``.

    * If the text contains the space form (``\\boxed `` followed by the
      content), the part after its last occurrence, cut at the first ``$``,
      is returned with the ``\\boxed `` prefix re-attached.
    * Otherwise the last ``\\boxed`` (or, failing that, the last ``\\fbox``)
      is located and the span up to the brace that closes its first opening
      brace is returned.

    Returns ``None`` when no such macro exists or its braces never balance.
    """
    spaced_marker = "\\boxed "
    if spaced_marker in string:
        tail = string.split(spaced_marker)[-1]
        return spaced_marker + tail.split("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    # Walk forward from the macro, tracking brace depth until it returns to 0.
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]

    return None
|
|
|
+
|
|
|
+
|
|
|
def remove_boxed(s: str) -> str:
    """
    Strip the outer box macro from a boxed expression and return its content.

    Accepted shapes (these are the shapes ``last_boxed_only_string`` can
    produce):
      * ``\\boxed <content>``  (space form)
      * ``\\boxed{<content>}``
      * ``\\fbox{<content>}``

    Raises:
        AssertionError: if ``s`` has none of the shapes above. The exception
            type matches what the previous ``assert``-based checks raised,
            but is raised explicitly so the validation also runs under
            ``python -O``.
    """
    spaced_prefix = "\\boxed "
    # Prefix check (not substring): a brace-form box may legitimately contain
    # the text "\boxed " somewhere inside its content.
    if s.startswith(spaced_prefix):
        return s[len(spaced_prefix) :]

    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left) and s.endswith("}"):
            return s[len(left) : -1]

    raise AssertionError(f"not a boxed expression: {s!r}")
|
|
|
+
|
|
|
+
|
|
|
class timeout:
    """
    Context manager that interrupts the enclosed block with ``TimeoutError``
    once ``seconds`` seconds have elapsed.

    It works by installing a ``SIGALRM`` handler and arming ``signal.alarm``
    on entry. On exit the alarm is cancelled and the handler that was active
    before entry is put back, so using this class does not leave a stale
    handler installed.

    Args:
        seconds: delay passed to ``signal.alarm``.
        error_message: message of the ``TimeoutError`` raised on expiry.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message
        # Handler that was installed before __enter__; restored in __exit__.
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """Signal handler: turn the alarm into a ``TimeoutError``."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal returns the handler it replaces.
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, type, value, traceback):
        # Cancel the pending alarm first so it cannot fire during restoration.
        signal.alarm(0)
        # A previous handler of None means it was not installed from Python
        # and cannot be re-installed through signal.signal; leave it alone.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
|
|
|
+
|
|
|
+
|
|
|
def is_equiv(x1: str, x2: str) -> bool:
    """
    Check whether two normalized LaTeX strings denote the same value.

    Both strings are parsed with ``parse_latex``; they are equivalent when
    ``sympy.simplify`` reduces their difference to 0. The whole comparison
    runs inside a 5 second ``timeout``.

    Returns:
        True if the difference simplifies to 0. False if it does not, or if
        parsing, subtraction or simplification fails, the comparison times
        out, or any other non-import exception occurs (each case is logged
        at debug level).

    Raises:
        ImportError: logged at error level and re-raised.
    """
    try:
        with timeout(seconds=5):
            try:
                parsed_x1 = parse_latex(x1)
                parsed_x2 = parse_latex(x2)
            except (
                sympy.parsing.latex.errors.LaTeXParsingError,
                sympy.SympifyError,
                TypeError,
            ):
                eval_logger.debug(f"couldn't parse one of {x1} or {x2}")
                return False

            try:
                diff = parsed_x1 - parsed_x2
            except TypeError:
                eval_logger.debug(f"couldn't subtract {x1} and {x2}")
                return False

            try:
                return bool(sympy.simplify(diff) == 0)
            except ValueError:
                eval_logger.debug(
                    f"Had some trouble simplifying when comparing {x1} and {x2}"
                )
                # Previously this branch fell through and returned None;
                # return an explicit bool to honour the annotated contract.
                return False
    except TimeoutError:
        eval_logger.debug(f"Timed out comparing {x1} and {x2}")
        return False
    except ImportError as e:
        # Import problems are surfaced to the caller rather than scored as a
        # mismatch.
        eval_logger.error(e)
        raise
    except Exception as e:
        eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}")
        return False
|
|
|
+
|
|
|
+
|
|
|
def get_unnormalized_answer(text: str) -> str:
    """
    Pull the raw answer out of a "Final Answer: The final answer is ... I hope
    it is correct." style completion.

    The closing phrase is appended to ``text`` before searching, so a
    completion that stops right before that phrase can still match. Returns
    the captured answer with surrounding whitespace stripped, or
    ``"[invalidanswer]"`` when the pattern is not found.
    """
    padded = text + "I hope it is correct."
    found = re.search(
        r"Final Answer: The final answer is(.*?). I hope it is correct.",
        padded,
    )
    if found is None:
        return "[invalidanswer]"
    return found.group(1).strip()
|
|
|
+
|
|
|
+
|
|
|
# Ordered (before, after) pairs that normalize_final_answer applies with
# str.replace, one after another, before REMOVED_EXPRESSIONS is processed.
# Order matters: later pairs see the output of earlier ones (for example,
# spaces are already gone by the time the "\\text{and}" pairs run).
# NOTE: these are plain substring replacements, so a pair such as "an " also
# fires when that text occurs inside a longer word.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
|
|
|
# Substrings that normalize_final_answer deletes (str.replace with "") in list
# order, after every SUBSTITUTIONS pair has been applied. Because spaces were
# removed by then, entries are matched against space-free text (hence entries
# like "childrentickets"). Order matters for overlapping entries, e.g.
# "\\text{}^2" is listed before "\\text{}".
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]
|
|
|
+
|
|
|
+
|
|
|
def normalize_final_answer(final_answer: str) -> str:
    """
    Normalize a final answer to a quantitative reasoning question.

    Follows the procedure from appendix D of Lewkowycz et al. (2022):
    keep the text after the last "=", apply SUBSTITUTIONS and
    REMOVED_EXPRESSIONS, unwrap math/text/box markup, expand shorthand
    fractions and square roots, drop "$", and finally remove thousands
    separators from answers that are otherwise all digits.
    """
    answer = final_answer.split("=")[-1]

    for old, new in SUBSTITUTIONS:
        answer = answer.replace(old, new)
    for unwanted in REMOVED_EXPRESSIONS:
        answer = answer.replace(unwanted, "")

    # Regex passes, applied strictly in this order.
    rewrites = (
        # Extract answer that is in LaTeX math, is bold,
        # is surrounded by a box, etc.
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*)(\})", "\\2"),
        # Normalize shorthand TeX:
        #  \fracab -> \frac{a}{b}
        #  \frac{abc}{bef} -> \frac{abc}{bef}
        #  \fracabc -> \frac{a}{b}c
        #  \sqrta -> \sqrt{a}
        #  \sqrtab -> sqrt{a}b
        (r"(frac)([^{])(.)", "frac{\\2}{\\3}"),
        (r"(sqrt)([^{])", "sqrt{\\2}"),
    )
    for pattern, replacement in rewrites:
        answer = re.sub(pattern, replacement, answer)
    answer = answer.replace("$", "")

    # Normalize 100,000 -> 100000
    without_commas = answer.replace(",", "")
    if without_commas.isdigit():
        return without_commas
    return answer
|