123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
- # Most of the code taken from https://github.com/EleutherAI/lm-evaluation-harness/blob/cddce0a148ec1710e2d60546c6f92727dd8a78fd/lm_eval/tasks/leaderboard/math/utils.py
- import re
- import signal
- from typing import Dict, List, Optional
- import datasets
- from lm_eval.utils import eval_logger
- try:
- import sympy
- from sympy.parsing.latex import parse_latex
- except ModuleNotFoundError:
- raise ModuleNotFoundError(
- "`sympy` is required for generating translation task prompt templates. \
- please install sympy via pip install lm-eval[math] or pip install -e .[math]",
- )
- # taken from
- # https://github.com/wellecks/lm-evaluation-harness/blob/master/lm_eval/tasks/minerva_math.py
def doc_to_text(doc: dict) -> str:
    """Return the prompt text for a document.

    Args:
        doc: A dataset record holding a sequence under the
            ``"input_final_prompts"`` key.

    Returns:
        The first entry of ``doc["input_final_prompts"]``.
    """
    prompts = doc["input_final_prompts"]
    return prompts[0]
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Derive the evaluation fields for every document in ``dataset``.

    Each document is passed through ``dataset.map`` with a function that
    produces three fields:

    * ``"problem"``: copied from ``doc["input_question"]``.
    * ``"answer"``: the last boxed expression of ``doc["solution"]`` with the
      box wrapper removed and the result normalized.
    * ``"meta_target"``: copied from ``doc["input_correct_responses"]``.

    Args:
        dataset: The dataset to transform.

    Returns:
        Whatever ``dataset.map`` returns for the per-document function.
    """

    def _process_doc(doc: dict) -> dict:
        # Field order mirrors the key lookups: question, solution, responses.
        problem = doc["input_question"]
        boxed_solution = last_boxed_only_string(doc["solution"])
        answer = normalize_final_answer(remove_boxed(boxed_solution))
        return {
            "problem": problem,
            "answer": answer,
            "meta_target": doc["input_correct_responses"],
        }

    return dataset.map(_process_doc)
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
    """Score a single model response against the reference answer.

    Only the first element of ``results`` is inspected. Its last boxed
    expression is unwrapped and normalized, then compared with
    ``doc["answer"]`` either by whitespace-stripped string equality or, failing
    that, by ``is_equiv``.

    Args:
        doc: The processed document; must provide an ``"answer"`` string.
        results: Model outputs; ``results[0]`` is the response being scored.

    Returns:
        ``{"exact_match": 1}`` when the answers agree, ``{"exact_match": 0}``
        otherwise (including when the response has no boxed expression).
    """
    boxed = last_boxed_only_string(results[0])
    if not boxed:
        # No boxed string found, so we can't evaluate
        return {"exact_match": 0}

    predicted = normalize_final_answer(remove_boxed(boxed))
    reference = doc["answer"]

    # Cheap string comparison first; symbolic equivalence only as a fallback.
    matched = predicted.strip() == reference.strip() or is_equiv(
        predicted, reference
    )
    return {"exact_match": 1 if matched else 0}
def last_boxed_only_string(string: str) -> Optional[str]:
    """Extract the last ``\\boxed``/``\\fbox`` expression from ``string``.

    Two forms are recognized:

    * The brace-less form ``\\boxed <content>``: if the substring
      ``"\\boxed "`` occurs anywhere, the text after its last occurrence up to
      (not including) the next ``$`` is returned, prefixed with ``"\\boxed "``.
    * The braced form: starting at the last ``\\boxed`` (or, if there is none,
      the last ``\\fbox``), braces are balanced and the span up to the matching
      closing brace is returned.

    Args:
        string: Text to search.

    Returns:
        The boxed expression including its wrapper, or ``None`` when no
        wrapper is present or its braces never balance.
    """
    spaced_marker = "\\boxed "
    if spaced_marker in string:
        tail = string.split(spaced_marker)[-1]
        return spaced_marker + tail.split("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    # Walk forward from the wrapper, tracking brace nesting depth; the
    # expression ends at the "}" that brings the depth back to zero.
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    return None
def remove_boxed(s: str) -> str:
    """Strip the ``\\boxed``/``\\fbox`` wrapper from a boxed expression.

    Accepted shapes, checked as *prefixes* of ``s``:

    * ``"\\boxed <content>"``  -> ``"<content>"``
    * ``"\\boxed{<content>}"`` -> ``"<content>"``
    * ``"\\fbox{<content>}"``  -> ``"<content>"``

    Args:
        s: A boxed expression, e.g. as returned by ``last_boxed_only_string``.

    Returns:
        The content inside the wrapper.

    Raises:
        AssertionError: If ``s`` does not have one of the accepted shapes.
            Raised explicitly (not via ``assert``) so validation still happens
            when Python runs with ``-O``.
    """
    # Brace-less form. Test the prefix rather than mere containment so that a
    # braced expression whose *content* includes "\boxed " is not rejected.
    spaced = "\\boxed "
    if s[: len(spaced)] == spaced:
        return s[len(spaced) :]

    # Braced forms. "\fbox{" is accepted because last_boxed_only_string falls
    # back to "\fbox" when no "\boxed" is present.
    for opener in ("\\boxed{", "\\fbox{"):
        if s[: len(opener)] == opener and s[-1:] == "}":
            return s[len(opener) : -1]

    raise AssertionError(f"not a boxed expression: {s!r}")
class timeout:
    """Context manager that raises ``TimeoutError`` after a number of seconds.

    Implemented with ``signal.SIGALRM`` / ``signal.alarm``. On exit the pending
    alarm is cancelled and the SIGALRM handler that was installed before
    entering is put back, so using this manager does not leave its own handler
    registered.

    Args:
        seconds: Delay passed to ``signal.alarm`` before the timeout fires.
        error_message: Message carried by the raised ``TimeoutError``.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message
        # Handler that was active before __enter__; restored in __exit__.
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """Signal handler: convert the alarm into a ``TimeoutError``."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal returns the handler it replaces; keep it for __exit__.
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, type, value, traceback):
        # Cancel the alarm first so it cannot fire while handlers are swapped.
        signal.alarm(0)
        # A previous handler of None cannot be passed back to signal.signal,
        # so in that case the handler is left as it is.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None
def is_equiv(x1: str, x2: str) -> bool:
    """Check whether two normalized LaTeX strings are symbolically equal.

    Both strings are parsed with ``parse_latex``; their difference is passed to
    ``sympy.simplify`` and compared with zero. The whole comparison runs under
    a 5-second ``timeout``.

    Args:
        x1: First normalized LaTeX string.
        x2: Second normalized LaTeX string.

    Returns:
        ``True`` only when the simplified difference equals zero. ``False`` on
        any parse, subtraction or simplification failure, on timeout, and on
        any other exception except ``ImportError``.

    Raises:
        ImportError: Logged at error level and re-raised. Presumably this
            signals a missing parser dependency; confirm against the sympy
            installation in use.
    """
    try:
        with timeout(seconds=5):
            try:
                parsed_x1 = parse_latex(x1)
                parsed_x2 = parse_latex(x2)
            except (
                sympy.parsing.latex.errors.LaTeXParsingError,
                sympy.SympifyError,
                TypeError,
            ):
                eval_logger.debug(f"couldn't parse one of {x1} or {x2}")
                return False

            try:
                diff = parsed_x1 - parsed_x2
            except TypeError:
                eval_logger.debug(f"couldn't subtract {x1} and {x2}")
                return False

            try:
                return bool(sympy.simplify(diff) == 0)
            except ValueError:
                eval_logger.debug(
                    f"Had some trouble simplifying when comparing {x1} and {x2}"
                )
                # Return False explicitly; falling through would yield None
                # and break the declared bool return type.
                return False
    except TimeoutError:
        eval_logger.debug(f"Timed out comparing {x1} and {x2}")
        return False
    except ImportError as e:
        eval_logger.error(e)
        raise
    except Exception as e:
        # Best-effort comparison: any other failure counts as "not equivalent".
        eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}")
        return False
def get_unnormalized_answer(text: str) -> str:
    """Pull the raw final answer out of a model response.

    The closing phrase ``"I hope it is correct."`` is appended to ``text``
    before searching, so a response that stops right before that phrase can
    still match. The text captured between ``"Final Answer: The final answer
    is"`` and the closing phrase is returned with surrounding whitespace
    removed.

    Note that the ``.`` characters in the pattern are unescaped and therefore
    match any single character (other than a newline).

    Args:
        text: The model response.

    Returns:
        The captured answer, or ``"[invalidanswer]"`` when the pattern is not
        found.
    """
    INVALID_ANSWER = "[invalidanswer]"
    end_seq = "I hope it is correct."

    found = re.search(
        r"Final Answer: The final answer is(.*?). I hope it is correct.",
        text + end_seq,
    )
    return found.group(1).strip() if found else INVALID_ANSWER
# Ordered (before, after) pairs. normalize_final_answer applies them one after
# another with str.replace, so earlier entries affect what later ones see
# (e.g. the article removals run before all remaining spaces are stripped).
SUBSTITUTIONS = [
    # Articles.
    ("an ", ""),
    ("a ", ""),
    # Punctuation, escaped dollar signs and spacing.
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    # Text-command rewrites.
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
# Substrings that normalize_final_answer deletes outright (replaced by ""),
# in this order, after SUBSTITUTIONS has been applied.
REMOVED_EXPRESSIONS = [
    # Plain words.
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    # Leftover \text{...} fragments.
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    # LaTeX decorations, spacing commands and stray punctuation.
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]
def normalize_final_answer(final_answer: str) -> str:
    """Normalize a final answer to a quantitative reasoning question.

    The original docstring attributes this procedure to appendix D of
    Lewkowycz et al. (2022). Steps, in order:

    1. Keep only the text after the last ``=``.
    2. Apply every ``SUBSTITUTIONS`` pair, then delete every
       ``REMOVED_EXPRESSIONS`` entry.
    3. Reduce the string to its first ``$...$`` span (if any) and unwrap
       ``\\text{}``, ``\\textbf{}``, ``\\overline{}`` and ``\\boxed{}``.
    4. Add braces to shorthand ``\\frac`` / ``\\sqrt`` arguments.
    5. Drop ``$`` characters, and drop commas when the comma-free string
       consists only of digits.

    Args:
        final_answer: The raw answer string.

    Returns:
        The normalized answer string.
    """
    # Anything before the last "=" is treated as the left-hand side.
    normalized = final_answer.rpartition("=")[2]

    for old, new in SUBSTITUTIONS:
        normalized = normalized.replace(old, new)
    for unwanted in REMOVED_EXPRESSIONS:
        normalized = normalized.replace(unwanted, "")

    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    normalized = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", normalized)
    for wrapper_pattern in (
        r"(\\text\{)(.*?)(\})",
        r"(\\textbf\{)(.*?)(\})",
        r"(\\overline\{)(.*?)(\})",
        r"(\\boxed\{)(.*)(\})",
    ):
        normalized = re.sub(wrapper_pattern, "\\2", normalized)

    # Normalize shorthand TeX:
    #  \fracab -> \frac{a}{b}
    #  \frac{abc}{bef} -> \frac{abc}{bef}
    #  \fracabc -> \frac{a}{b}c
    #  \sqrta -> \sqrt{a}
    #  \sqrtab -> sqrt{a}b
    normalized = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", normalized)
    normalized = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", normalized)
    normalized = normalized.replace("$", "")

    # Normalize 100,000 -> 100000
    without_commas = normalized.replace(",", "")
    return without_commas if without_commas.isdigit() else normalized
|