
Eval reproduce recipe using lm-evaluation-harness and our 3.1 evals datasets (#627)

Hamid Shojanazeri, 7 months ago
parent
commit
b5f64c0b69
20 changed files with 1154 additions and 3 deletions
  1. +19 -1    .github/scripts/spellcheck_conf/wordlist.txt
  2. +2 -1     recipes/use_cases/multilingual/README.md
  3. +1 -1     requirements.txt
  4. +4 -0     tools/benchmarks/llm_eval_harness/README.md
  5. +213 -0   tools/benchmarks/llm_eval_harness/meta_eval_reproduce/README.md
  6. +32 -0    tools/benchmarks/llm_eval_harness/meta_eval_reproduce/eval_config.yaml
  7. +28 -0    tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/bbh/bbh_3shot_cot.yaml
  8. +21 -0    tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/bbh/utils.py
  9. +29 -0    tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/gpqa_cot/gpqa_0shot_cot.yaml
  10. +20 -0   tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/gpqa_cot/utils.py
  11. +32 -0   tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/ifeval/ifeval.yaml
  12. +139 -0  tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/ifeval/utils.py
  13. +21 -0   tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/math_hard/math_hard_0shot_cot.yaml
  14. +268 -0  tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/math_hard/utils.py
  15. +6 -0    tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/meta_instruct.yaml
  16. +4 -0    tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/meta_pretrain.yaml
  17. +29 -0   tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/mmlu_pro/mmlu_pro_5shot_cot_instruct.yaml
  18. +28 -0   tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/mmlu_pro/mmlu_pro_5shot_cot_pretrain.yaml
  19. +21 -0   tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/mmlu_pro/utils.py
  20. +237 -0  tools/benchmarks/llm_eval_harness/meta_eval_reproduce/prepare_meta_eval.py

+ 19 - 1
.github/scripts/spellcheck_conf/wordlist.txt

@@ -1432,4 +1432,22 @@ CPUs
 modelUpgradeExample
 guardrailing
 MaaS
-MFU
+MFU
+BBH
+GPQA
+IFEVAL
+IFeval
+bos
+gpqa
+ifeval
+lighteval
+sqrt
+wis
+evals
+mmlu
+parsers
+reproducibility
+openhathi
+sarvam
+subtask
+acc

+ 2 - 1
recipes/use_cases/multilingual/README.md

@@ -1,7 +1,8 @@
 # Extending Llama to a new language
 Authored by : Sarvam team
 In this recipe, we will see how to add a new language to the Llama family of models. The steps are quite general and can be easily adapted to other models as well. Using this recipe, you should be able to replicate the findings of [OpenHathi](https://huggingface.co/sarvamai/OpenHathi-7B-Hi-v0.1-Base).
-Please read more about OpenHathi [here](https://analyticsindiamag.com/industry-insights/ai-startups/indian-startup-sarvam-ai-launches-hindi-llm-openhathi/)
+Please read more about OpenHathi [here](https://web.archive.org/web/20240418103408/https://www.sarvam.ai/blog/announcing-openhathi-series)
+
 ## Data
 The original OpenHathi model uses a combination of [Sangraha](https://huggingface.co/datasets/ai4bharat/sangraha) and Wikipedia as its primary data sources. If the reader is interested in using these sources, they would also have to preprocess the data: clean, filter, and deduplicate. See [Setu](https://github.com/AI4Bharat/setu) for an easy way to do this at scale.
 

+ 1 - 1
requirements.txt

@@ -28,4 +28,4 @@ langchain_openai
 langchain
 langchain_community
 sentence_transformers
-codeshield
+codeshield

+ 4 - 0
tools/benchmarks/llm_eval_harness/README.md

File diff suppressed because it is too large

+ 213 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/README.md

File diff suppressed because it is too large


+ 32 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/eval_config.yaml

@@ -0,0 +1,32 @@
+model_name: "meta-llama/Meta-Llama-3.1-8B-Instruct" # The name of the model to evaluate. This must be a valid Meta Llama 3 based model name on the HuggingFace model hub.
+
+evals_dataset: "meta-llama/Meta-Llama-3.1-8B-Instruct-evals" # The name of the 3.1 evals dataset to evaluate; please make sure this evals dataset corresponds to the loaded model. This must be a valid Meta Llama 3.1 evals dataset name in the Llama 3.1 Evals collection.
+# Must be one of the following ["meta-llama/Meta-Llama-3.1-8B-Instruct-evals","meta-llama/Meta-Llama-3.1-70B-Instruct-evals","meta-llama/Meta-Llama-3.1-405B-Instruct-evals","meta-llama/Meta-Llama-3.1-8B-evals","meta-llama/Meta-Llama-3.1-70B-evals","meta-llama/Meta-Llama-3.1-405B-evals"]
+
+tasks: "meta_instruct" # Available tasks for instruct model: "meta_math_hard", "meta_gpqa", "meta_mmlu_pro_instruct", "meta_ifeval"; or just use "meta_instruct" to run all of them.
+# Available tasks for pretrain model: "meta_bbh", "meta_mmlu_pro_pretrain"; or just use "meta_pretrain" to run all of them.
+
+tensor_parallel_size: 1 # The VLLM argument that specifies the tensor parallel size for the model, e.g. how many GPUs to use for one model copy.
+
+data_parallel_size: 4 # The VLLM argument that specifies the data parallel size for the model, e.g. how many copies of the model will be used.
+
+gpu_memory_utilization: 0.9 # The VLLM argument that specifies the GPU memory utilization; the rest will be reserved for the KV cache.
+
+max_model_len: 8192 # The VLLM argument that specifies the model max length; decrease this value only if a GPU memory issue is encountered. Please make sure the max_gen_toks in the task yamls does not exceed this length.
+
+batch_size: "auto" # Batch size, can be 'auto', 'auto:N', or an integer. It is strongly recommended to use 'auto' for vllm to speed up inference.
+
+output_path: "eval_results" # The output folder to store all the eval results and samples.
+
+#limit: 12 # Limit number of examples per task, set 'null' to run all.
+limit: null # Limit number of examples per task, set 'null' to run all.
+
+verbosity: "INFO" #Logging level: CRITICAL, ERROR, WARNING, INFO, DEBUG.
+
+log_samples: true # If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis.
+
+work_dir: ./work_dir # The work folder where the task template yaml files will be copied and modified, datasets will be downloaded for math_hard, ifeval.
+
+template_dir: ./meta_template #Path to the folder that contains all the meta task templates
+
+show_config: false # If True, shows the full config of all tasks at the end of the evaluation.
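
For orientation: the vLLM-related fields above are not passed to lm_eval one by one; prepare_meta_eval.py (the last file in this commit) folds them into a single --model_args string. A minimal sketch of that mapping, assuming the default values above (it mirrors the script rather than adding anything new):

    import yaml

    # Sketch only: how the vLLM fields in eval_config.yaml become lm_eval's --model_args string.
    with open("eval_config.yaml") as f:
        cfg = yaml.safe_load(f)

    model_args = (
        f"pretrained={cfg['model_name']},"
        f"tensor_parallel_size={cfg['tensor_parallel_size']},dtype=auto,"
        f"gpu_memory_utilization={cfg['gpu_memory_utilization']},"
        f"data_parallel_size={cfg['data_parallel_size']},"
        f"max_model_len={cfg['max_model_len']},add_bos_token=True,seed=42"
    )
    print(model_args)
    # pretrained=meta-llama/Meta-Llama-3.1-8B-Instruct,tensor_parallel_size=1,dtype=auto,...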

+ 28 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/bbh/bbh_3shot_cot.yaml

@@ -0,0 +1,28 @@
+dataset_path: meta-llama/Meta-Llama-3.1-8B-evals
+dataset_name: Meta-Llama-3.1-8B-evals__bbh__details
+task: meta_bbh
+output_type: generate_until
+process_docs: !function utils.process_docs
+test_split: latest
+doc_to_text: !function utils.doc_to_text
+doc_to_target: answer
+filter_list:
+  - name: "strict-match"
+    filter:
+      - function: "regex"
+        regex_pattern: 'the answer is (.*?)\.'
+      - function: "take_first"
+generation_kwargs:
+  until: "\n\nQ: "
+  do_sample: false
+  temperature: 0
+  max_gen_toks: 512
+num_fewshot: 0
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+metadata:
+  version: 1.0
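
To make the strict-match filter above concrete: the regex is applied to the model's chain-of-thought generation, the first capture group becomes the predicted answer, and exact_match then compares it against doc_to_target (the answer field built by utils.process_docs). A rough illustration of what that regex pulls out of a made-up generation (this is not the harness's actual filter code):

    import re

    # Hypothetical BBH-style generation, for illustration only.
    generation = "Let's think step by step. Two statements contradict each other, so the answer is (B)."

    # Same pattern as regex_pattern above; take_first keeps the first match.
    match = re.search(r"the answer is (.*?)\.", generation)
    predicted = match.group(1) if match else ""
    print(predicted)  # -> (B)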

+ 21 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/bbh/utils.py

@@ -0,0 +1,21 @@
+import random
+import re
+
+import datasets
+
+
+
+def doc_to_text(doc: dict) -> str:
+    return doc["input_final_prompts"][0]
+
+def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
+    def _process_doc(doc: dict) -> dict:
+        out_doc = {
+            "problem": doc["input_question"],
+            "answer": doc["input_correct_responses"][0],
+        }
+        return out_doc
+    dataset = dataset.select_columns(["input_question", "input_correct_responses", "input_final_prompts", "is_correct","input_question_hash","output_prediction_text"])
+    dataset = dataset.rename_column("is_correct","previously_is_correct")
+    return dataset.map(_process_doc)

+ 29 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/gpqa_cot/gpqa_0shot_cot.yaml

@@ -0,0 +1,29 @@
+dataset_path: meta-llama/Meta-Llama-3.1-8B-Instruct-evals
+dataset_name: Meta-Llama-3.1-8B-Instruct-evals__gpqa__details
+task: meta_gpqa
+output_type: generate_until
+process_docs: !function utils.process_docs
+test_split: latest
+doc_to_text: !function utils.doc_to_text
+doc_to_target: gold
+filter_list:
+  - name: "strict-match"
+    filter:
+      - function: "regex"
+        group_select: -1
+        regex_pattern: 'best answer is ([A-Z])'
+      - function: "take_first"
+generation_kwargs:
+  until: []
+  do_sample: false
+  temperature: 0
+  max_gen_toks: 2048
+num_fewshot: 0
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+metadata:
+  version: 1.0
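
The GPQA filter differs from the BBH one above in one detail: group_select: -1 keeps the last regex match rather than the first, which matters when a long chain of thought restates "best answer is X" more than once (this is my reading of the harness's regex filter, so treat it as a sketch):

    import re

    # Hypothetical generation that mentions the phrase twice, for illustration only.
    generation = (
        "At first glance the best answer is C, but rechecking the kinetics, "
        "the best answer is B."
    )

    matches = re.findall(r"best answer is ([A-Z])", generation)
    predicted = matches[-1] if matches else ""  # group_select: -1 -> last occurrence
    print(predicted)  # -> B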

+ 20 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/gpqa_cot/utils.py

@@ -0,0 +1,20 @@
+import random
+import re
+
+import datasets
+
+
+
+def doc_to_text(doc: dict) -> str:
+    return doc["input_final_prompts"][0]
+def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
+    def _process_doc(doc: dict) -> dict:
+        out_doc = {
+            "problem": doc["input_question"],
+            "gold": doc["input_correct_responses"][0],
+        }
+        return out_doc
+    dataset = dataset.select_columns(["input_question", "input_correct_responses", "input_final_prompts", "is_correct","input_question_hash","input_choice_list","output_prediction_text"])
+    dataset = dataset.rename_column("is_correct","previously_is_correct")
+    return dataset.map(_process_doc)

+ 32 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/ifeval/ifeval.yaml

@@ -0,0 +1,32 @@
+task: meta_ifeval
+dataset_path: parquet
+dataset_kwargs:
+  data_files: ./work_dir/joined_ifeval.parquet
+output_type: generate_until
+test_split: train
+num_fewshot: 0
+doc_to_text: prompt
+doc_to_target: 0
+generation_kwargs:
+  until: []
+  do_sample: false
+  temperature: 0.0
+  max_gen_toks: 1280
+process_results: !function utils.process_results
+metric_list:
+  - metric: prompt_level_strict_acc
+    aggregation: mean
+    higher_is_better: true
+  - metric: inst_level_strict_acc
+    aggregation: !function utils.agg_inst_level_acc
+    higher_is_better: true
+  - metric: prompt_level_loose_acc
+    aggregation: mean
+    higher_is_better: true
+  - metric: inst_level_loose_acc
+    aggregation: !function utils.agg_inst_level_acc
+    higher_is_better: true
+metadata:
+  version: 2.0
+fewshot_config:
+  sampler: first_n
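
Note that this task does not read a Hugging Face dataset directly: data_files points at ./work_dir/joined_ifeval.parquet, which is produced by get_ifeval_data() in prepare_meta_eval.py (the last file in this commit). A quick way to sanity-check that file before running the eval (a sketch, assuming the preparation step has already been run):

    from datasets import load_dataset

    # Sketch only: inspect the joined parquet that the meta_ifeval task consumes.
    ds = load_dataset("parquet", data_files="./work_dir/joined_ifeval.parquet", split="train")
    print(ds.column_names)  # expect prompt, key, instruction_id_list, kwargs, ... per the join in prepare_meta_eval.py
    print(ds[0]["instruction_id_list"])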

+ 139 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/ifeval/utils.py

@@ -0,0 +1,139 @@
+import dataclasses
+from typing import Dict, Optional, Union
+
+from lm_eval.tasks.ifeval import instructions_registry
+
+
+@dataclasses.dataclass
+class InputExample:
+    key: int
+    instruction_id_list: list[str]
+    prompt: str
+    kwargs: list[Dict[str, Optional[Union[str, int]]]]
+
+
+@dataclasses.dataclass
+class OutputExample:
+    instruction_id_list: list[str]
+    prompt: str
+    response: str
+    follow_all_instructions: bool
+    follow_instruction_list: list[bool]
+
+
+def test_instruction_following_strict(
+    inp,
+    response,
+):
+    """Tests response to see if instructions are followed."""
+    instruction_list = inp.instruction_id_list
+    is_following_list = []
+
+    for index, instruction_id in enumerate(instruction_list):
+        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
+        instruction = instruction_cls(instruction_id)
+
+        # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
+        kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
+        instruction.build_description(**kwargs)
+        args = instruction.get_instruction_args()
+        if args and "prompt" in args:
+            instruction.build_description(prompt=inp.prompt)
+
+        if response.strip() and instruction.check_following(response):
+            is_following_list.append(True)
+        else:
+            is_following_list.append(False)
+
+    return OutputExample(
+        instruction_id_list=inp.instruction_id_list,
+        prompt=inp.prompt,
+        response=response,
+        follow_all_instructions=all(is_following_list),
+        follow_instruction_list=is_following_list,
+    )
+
+
+def test_instruction_following_loose(
+    inp,
+    response,
+):
+    """Tests response for an upper bound for following instructions."""
+    r = response.split("\n")
+    response_remove_first = "\n".join(r[1:]).strip()
+    response_remove_last = "\n".join(r[:-1]).strip()
+    response_remove_both = "\n".join(r[1:-1]).strip()
+    revised_response = response.replace("*", "")
+    revised_response_remove_first = response_remove_first.replace("*", "")
+    revised_response_remove_last = response_remove_last.replace("*", "")
+    revised_response_remove_both = response_remove_both.replace("*", "")
+    all_responses = [
+        response,
+        revised_response,
+        response_remove_first,
+        response_remove_last,
+        response_remove_both,
+        revised_response_remove_first,
+        revised_response_remove_last,
+        revised_response_remove_both,
+    ]
+    instruction_list = inp.instruction_id_list
+    is_following_list = []
+
+    for index, instruction_id in enumerate(instruction_list):
+        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
+        instruction = instruction_cls(instruction_id)
+
+        # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
+        kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
+        instruction.build_description(**kwargs)
+        args = instruction.get_instruction_args()
+        if args and "prompt" in args:
+            instruction.build_description(prompt=inp.prompt)
+
+        is_following = False
+        for r in all_responses:
+            if r.strip() and instruction.check_following(r):
+                is_following = True
+                break
+
+        is_following_list.append(is_following)
+
+    return OutputExample(
+        instruction_id_list=inp.instruction_id_list,
+        prompt=inp.prompt,
+        response=response,
+        follow_all_instructions=all(is_following_list),
+        follow_instruction_list=is_following_list,
+    )
+
+
+def process_results(doc, results):
+    new_kwargs = []
+    for item in doc["kwargs"]:
+        if item["nth_paragraph"]:
+            item["nth_paragraph"] = int(item["nth_paragraph"])
+        new_kwargs.append(item)
+    inp = InputExample(
+        key=doc["key"],
+        instruction_id_list=doc["instruction_id_list"],
+        prompt=doc["prompt"],
+        kwargs=new_kwargs,
+    )
+    response = results[0]
+
+    out_strict = test_instruction_following_strict(inp, response)
+    out_loose = test_instruction_following_loose(inp, response)
+
+    return {
+        "prompt_level_strict_acc": out_strict.follow_all_instructions,
+        "inst_level_strict_acc": out_strict.follow_instruction_list,
+        "prompt_level_loose_acc": out_loose.follow_all_instructions,
+        "inst_level_loose_acc": out_loose.follow_instruction_list,
+    }
+
+
+def agg_inst_level_acc(items):
+    flat_items = [item for sublist in items for item in sublist]
+    inst_level_acc = sum(flat_items) / len(flat_items)
+    return inst_level_acc
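
The metrics returned by process_results above come in two granularities: prompt_level_*_acc is a single boolean per prompt (all instructions followed), while inst_level_*_acc is a list with one boolean per instruction, which is why the yaml wires it to agg_inst_level_acc instead of a plain mean. A small illustration of that aggregation with made-up results:

    # Illustrative only: per-prompt instruction-following results, as returned by process_results.
    inst_level_results = [
        [True, False],       # prompt 1: two instructions, one followed
        [True, True, True],  # prompt 2: three instructions, all followed
    ]

    # This is what agg_inst_level_acc computes: flatten, then take the mean.
    flat = [ok for per_prompt in inst_level_results for ok in per_prompt]
    print(sum(flat) / len(flat))  # -> 0.8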

+ 21 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/math_hard/math_hard_0shot_cot.yaml

@@ -0,0 +1,21 @@
+dataset_path: parquet
+dataset_kwargs:
+  data_files: ./work_dir/joined_math.parquet
+task: meta_math_hard
+process_docs: !function utils.process_docs
+output_type: generate_until
+test_split: train
+doc_to_text:  !function utils.doc_to_text
+process_results: !function utils.process_results
+doc_to_target: answer
+generation_kwargs:
+  until: []
+  do_sample: false
+  temperature: 0
+  max_gen_toks: 5120
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+metadata:
+  version: 1.0

+ 268 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/math_hard/utils.py

@@ -0,0 +1,268 @@
+# Most of the code taken from https://github.com/EleutherAI/lm-evaluation-harness/blob/cddce0a148ec1710e2d60546c6f92727dd8a78fd/lm_eval/tasks/leaderboard/math/utils.py
+import re
+import signal
+from typing import Dict, List, Optional
+
+import datasets
+
+from lm_eval.utils import eval_logger
+
+
+try:
+    import sympy
+    from sympy.parsing.latex import parse_latex
+except ModuleNotFoundError:
+    raise ModuleNotFoundError(
+        "`sympy` is required for checking answer equivalence on the MATH task. \
+please install sympy via pip install lm-eval[math] or pip install -e .[math]",
+    )
+
+# taken from
+# https://github.com/wellecks/lm-evaluation-harness/blob/master/lm_eval/tasks/minerva_math.py
+def doc_to_text(doc: dict) -> str:
+    return doc["input_final_prompts"][0]
+
+def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
+    def _process_doc(doc: dict) -> dict:
+        out_doc = {
+            "problem": doc["input_question"],
+            "answer": normalize_final_answer(
+                 remove_boxed(last_boxed_only_string(doc["solution"]))
+            ),
+            "meta_target": doc["input_correct_responses"]
+        }
+        return out_doc
+    return dataset.map(_process_doc)
+
+
+def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
+    candidates = results[0]
+    last_boxed_string = last_boxed_only_string(candidates)
+    if not last_boxed_string:
+        # No boxed string found, so we can't evaluate
+        return {"exact_match": 0}
+    unnormalized_answer = remove_boxed(last_boxed_string)
+    answer = normalize_final_answer(unnormalized_answer)
+
+    if answer.strip() == doc["answer"].strip() or is_equiv(answer, doc["answer"]):
+        retval = 1
+    else:
+        retval = 0
+
+    results = {
+        "exact_match": retval,
+    }
+    return results
+
+
+def last_boxed_only_string(string: str) -> Optional[str]:
+    idx = string.rfind("\\boxed")
+    if "\\boxed " in string:
+        return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
+    if idx < 0:
+        idx = string.rfind("\\fbox")
+        if idx < 0:
+            return None
+
+    i = idx
+    right_brace_idx = None
+    num_left_braces_open = 0
+    while i < len(string):
+        if string[i] == "{":
+            num_left_braces_open += 1
+        if string[i] == "}":
+            num_left_braces_open -= 1
+            if num_left_braces_open == 0:
+                right_brace_idx = i
+                break
+        i += 1
+
+    if right_brace_idx is None:
+        retval = None
+    else:
+        retval = string[idx : right_brace_idx + 1]
+
+    return retval
+
+
+def remove_boxed(s: str) -> str:
+    if "\\boxed " in s:
+        left = "\\boxed "
+        assert s[: len(left)] == left
+        return s[len(left) :]
+
+    left = "\\boxed{"
+
+    assert s[: len(left)] == left
+    assert s[-1] == "}"
+
+    return s[len(left) : -1]
+
+
+class timeout:
+    def __init__(self, seconds=1, error_message="Timeout"):
+        self.seconds = seconds
+        self.error_message = error_message
+
+    def handle_timeout(self, signum, frame):
+        raise TimeoutError(self.error_message)
+
+    def __enter__(self):
+        signal.signal(signal.SIGALRM, self.handle_timeout)
+        signal.alarm(self.seconds)
+
+    def __exit__(self, type, value, traceback):
+        signal.alarm(0)
+
+
+def is_equiv(x1: str, x2: str) -> bool:
+    """
+    x1 and x2 are normalized latex string
+    """
+    try:
+        with timeout(seconds=5):
+            try:
+                parsed_x1 = parse_latex(x1)
+                parsed_x2 = parse_latex(x2)
+            except (
+                sympy.parsing.latex.errors.LaTeXParsingError,
+                sympy.SympifyError,
+                TypeError,
+            ):
+                eval_logger.debug(f"couldn't parse one of {x1} or {x2}")
+                return False
+
+            try:
+                diff = parsed_x1 - parsed_x2
+            except TypeError:
+                eval_logger.debug(f"couldn't subtract {x1} and {x2}")
+                return False
+
+            try:
+                if sympy.simplify(diff) == 0:
+                    return True
+                else:
+                    return False
+            except ValueError:
+                eval_logger.debug(
+                    f"Had some trouble simplifying when comparing {x1} and {x2}"
+                )
+    except TimeoutError:
+        eval_logger.debug(f"Timed out comparing {x1} and {x2}")
+        return False
+    except ImportError as e:
+        eval_logger.error(e)
+        raise
+    except Exception as e:
+        eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}")
+        return False
+
+
+def get_unnormalized_answer(text: str) -> str:
+    INVALID_ANSWER = "[invalidanswer]"
+    end_seq = "I hope it is correct."
+    text += end_seq
+    match = re.search(
+        r"Final Answer: The final answer is(.*?). I hope it is correct.",
+        text,
+    )
+    if match:
+        return match.group(1).strip()
+    else:
+        return INVALID_ANSWER
+
+
+SUBSTITUTIONS = [
+    ("an ", ""),
+    ("a ", ""),
+    (".$", "$"),
+    ("\\$", ""),
+    (r"\ ", ""),
+    (" ", ""),
+    ("mbox", "text"),
+    (",\\text{and}", ","),
+    ("\\text{and}", ","),
+    ("\\text{m}", "\\text{}"),
+]
+REMOVED_EXPRESSIONS = [
+    "square",
+    "ways",
+    "integers",
+    "dollars",
+    "mph",
+    "inches",
+    "ft",
+    "hours",
+    "km",
+    "units",
+    "\\ldots",
+    "sue",
+    "points",
+    "feet",
+    "minutes",
+    "digits",
+    "cents",
+    "degrees",
+    "cm",
+    "gm",
+    "pounds",
+    "meters",
+    "meals",
+    "edges",
+    "students",
+    "childrentickets",
+    "multiples",
+    "\\text{s}",
+    "\\text{.}",
+    "\\text{\ns}",
+    "\\text{}^2",
+    "\\text{}^3",
+    "\\text{\n}",
+    "\\text{}",
+    r"\mathrm{th}",
+    r"^\circ",
+    r"^{\circ}",
+    r"\;",
+    r",\!",
+    "{,}",
+    '"',
+    "\\dots",
+]
+
+
+def normalize_final_answer(final_answer: str) -> str:
+    """
+    Normalize a final answer to a quantitative reasoning question.
+
+    Copied character for character from appendix D of Lewkowycz et al. (2022)
+    """
+    final_answer = final_answer.split("=")[-1]
+
+    for before, after in SUBSTITUTIONS:
+        final_answer = final_answer.replace(before, after)
+    for expr in REMOVED_EXPRESSIONS:
+        final_answer = final_answer.replace(expr, "")
+
+    # Extract answer that is in LaTeX math, is bold,
+    # is surrounded by a box, etc.
+    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
+    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
+
+    # Normalize shorthand TeX:
+    #  \fracab -> \frac{a}{b}
+    #  \frac{abc}{bef} -> \frac{abc}{bef}
+    #  \fracabc -> \frac{a}{b}c
+    #  \sqrta -> \sqrt{a}
+    #  \sqrtab -> sqrt{a}b
+    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
+    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
+    final_answer = final_answer.replace("$", "")
+
+    # Normalize 100,000 -> 100000
+    if final_answer.replace(",", "").isdigit():
+        final_answer = final_answer.replace(",", "")
+
+    return final_answer
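
Unlike the multiple-choice tasks, MATH-Hard scoring does not use a regex filter: process_results above pulls the last \boxed{...} expression out of the generation, strips the box, normalizes the string, and compares it with the normalized reference, falling back to sympy equivalence via is_equiv. A quick illustration of that pipeline on a made-up generation, assuming the helpers above are importable:

    # Illustrative only: reducing a generation to a comparable answer string.
    generation = "Adding the two roots gives 7/2, so the final answer is $\\boxed{\\frac{7}{2}}$."

    boxed = last_boxed_only_string(generation)            # '\\boxed{\\frac{7}{2}}'
    answer = normalize_final_answer(remove_boxed(boxed))  # '\\frac{7}{2}'
    print(answer == normalize_final_answer("\\frac{7}{2}"))  # -> True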

+ 6 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/meta_instruct.yaml

@@ -0,0 +1,6 @@
+group: meta_instruct
+task:
+  - meta_ifeval
+  - meta_math_hard
+  - meta_gpqa
+  - meta_mmlu_pro_instruct

+ 4 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/meta_pretrain.yaml

@@ -0,0 +1,4 @@
+group: meta_pretrain
+task:
+  - meta_bbh
+  - meta_mmlu_pro_pretrain

+ 29 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/mmlu_pro/mmlu_pro_5shot_cot_instruct.yaml

@@ -0,0 +1,29 @@
+task: meta_mmlu_pro_instruct
+dataset_path: meta-llama/Meta-Llama-3.1-8B-Instruct-evals
+dataset_name: Meta-Llama-3.1-8B-Instruct-evals__mmlu_pro__details
+test_split: latest
+output_type: generate_until
+process_docs: !function utils.process_docs
+doc_to_text: !function utils.doc_to_text
+doc_to_target: gold
+filter_list:
+  - name: "strict-match"
+    filter:
+      - function: "regex"
+        group_select: -1
+        regex_pattern: 'best answer is ([A-Z])'
+      - function: "take_first"
+generation_kwargs:
+  until: []
+  do_sample: false
+  temperature: 0
+  max_gen_toks: 1024
+num_fewshot: 0
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+metadata:
+  version: 1.0

+ 28 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/mmlu_pro/mmlu_pro_5shot_cot_pretrain.yaml

@@ -0,0 +1,28 @@
+task: meta_mmlu_pro_pretrain
+dataset_path: meta-llama/Meta-Llama-3.1-8B-evals
+dataset_name: Meta-Llama-3.1-8B-evals__mmlu_pro__details
+test_split: latest
+output_type: generate_until
+process_docs: !function utils.process_docs
+doc_to_text: !function utils.doc_to_text
+doc_to_target: gold
+filter_list:
+  - name: "strict-match"
+    filter:
+      - function: "regex"
+        regex_pattern: 'answer is \(([A-Z])\)'
+      - function: "take_first"
+generation_kwargs:
+  until: "\n\nQ: "
+  do_sample: false
+  temperature: 0
+  max_gen_toks: 512
+num_fewshot: 0
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+metadata:
+  version: 1.0

+ 21 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/meta_template/mmlu_pro/utils.py

@@ -0,0 +1,21 @@
+import string
+
+
+import datasets
+
+
+
+def doc_to_text(doc: dict) -> str:
+    return doc["input_final_prompts"][0]
+
+def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
+    def _process_doc(doc: dict) -> dict:
+        out_doc = {
+            "problem": doc["input_question"],
+            "gold": doc["input_correct_responses"][0],
+        }
+        return out_doc
+    dataset = dataset.select_columns(["input_question", "input_correct_responses", "input_final_prompts", "is_correct","input_question_hash","input_choice_list","output_prediction_text"])
+    dataset = dataset.rename_column("is_correct","previously_is_correct")
+    return dataset.map(_process_doc)

+ 237 - 0
tools/benchmarks/llm_eval_harness/meta_eval_reproduce/prepare_meta_eval.py

@@ -0,0 +1,237 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+import argparse
+import errno
+import shutil
+import glob
+import os
+from pathlib import Path
+import nltk
+import yaml
+from datasets import Dataset, load_dataset
+
+
+# get the ifeval data from the evals dataset and join it with the original ifeval dataset
+def get_ifeval_data(model_name, output_dir):
+    print(f"preparing the ifeval data using {model_name}'s evals dataset")
+    if model_name not in [
+        "Meta-Llama-3.1-8B-Instruct",
+        "Meta-Llama-3.1-70B-Instruct",
+        "Meta-Llama-3.1-405B-Instruct",
+    ]:
+        raise ValueError(
+            "Only Meta-Llama-3.1-8B-Instruct, Meta-Llama-3.1-70B-Instruct, Meta-Llama-3.1-405B-Instruct models are supported for IFEval"
+        )
+    original_dataset_name = "wis-k/instruction-following-eval"
+    meta_dataset_name = f"meta-llama/{model_name}-evals"
+    meta_data = load_dataset(
+        meta_dataset_name,
+        name=f"{model_name}-evals__ifeval__strict__details",
+        split="latest",
+    )
+    ifeval_data = load_dataset(original_dataset_name, split="train")
+    meta_data = meta_data.map(get_question)
+    meta_df = meta_data.to_pandas()
+    ifeval_df = ifeval_data.to_pandas()
+    ifeval_df = ifeval_df.rename(columns={"prompt": "input_question"})
+    # join the two datasets on the input_question column
+    joined = meta_df.join(ifeval_df.set_index("input_question"), on="input_question")
+    joined = joined.rename(columns={"input_final_prompts": "prompt"})
+    joined = joined.rename(columns={"is_correct": "previous_is_correct"})
+    joined = Dataset.from_pandas(joined)
+    joined = joined.select_columns(
+        [
+            "input_question",
+            "prompt",
+            "previous_is_correct",
+            "instruction_id_list",
+            "kwargs",
+            "output_prediction_text",
+            "key",
+        ]
+    )
+    joined = joined.rename_column("output_prediction_text", "previous_output_prediction_text")
+    joined.to_parquet(output_dir + "/joined_ifeval.parquet")
+
+
+# get the math_hard data from the evals dataset and join it with the original math_hard dataset
+def get_math_data(model_name, output_dir):
+    print(f"preparing the math data using {model_name}'s evals dataset")
+    if model_name not in [
+        "Meta-Llama-3.1-8B-Instruct",
+        "Meta-Llama-3.1-70B-Instruct",
+        "Meta-Llama-3.1-405B-Instruct",
+    ]:
+        raise ValueError(
+            "Only Meta-Llama-3.1-8B-Instruct, Meta-Llama-3.1-70B-Instruct, Meta-Llama-3.1-405B-Instruct models are supported for MATH_hard"
+        )
+    original_dataset_name = "lighteval/MATH-Hard"
+    meta_dataset_name = f"meta-llama/{model_name}-evals"
+    meta_data = load_dataset(
+        meta_dataset_name,
+        name=f"{model_name}-evals__math_hard__details",
+        split="latest",
+    )
+    math_data = load_dataset(original_dataset_name, split="test")
+    meta_df = meta_data.to_pandas()
+    math_df = math_data.to_pandas()
+    math_df = math_df.rename(columns={"problem": "input_question"})
+    # join the two datasets on the input_question column
+    joined = meta_df.join(math_df.set_index("input_question"), on="input_question")
+    joined = Dataset.from_pandas(joined)
+    joined = joined.select_columns(
+        [
+            "input_question",
+            "input_correct_responses",
+            "input_final_prompts",
+            "is_correct",
+            "solution",
+            "output_prediction_text",
+        ]
+    )
+    joined = joined.rename_column("is_correct", "previous_is_correct")
+    joined = joined.rename_column(
+        "output_prediction_text", "previous_output_prediction_text"
+    )
+
+    joined.to_parquet(output_dir + "/joined_math.parquet")
+
+
+# get the question from the ifeval dataset
+def get_question(example):
+    try:
+        example["input_question"] = (
+            eval(
+                example["input_question"]
+                .replace("null", "None")
+                .replace("true", "True")
+                .replace("false", "False")
+            )["dialog"][0]["body"]
+            .replace("Is it True that the first song", "Is it true that the first song")
+            .replace("Is the following True", "Is the following true")
+        )
+        example["input_final_prompts"] = example["input_final_prompts"][0]
+        return example
+    except Exception:
+        print(example["input_question"])
+        return
+
+
+# change the yaml file to use the correct model name
+def change_yaml(args, base_name):
+    for yaml_file in glob.glob(args.template_dir + "**/*/*.yaml", recursive=True):
+        with open(yaml_file, "r") as sources:
+            lines = sources.readlines()
+        output_path = yaml_file.replace(args.template_dir, args.work_dir)
+        print(f"changing {yaml_file} to output_path: {output_path}")
+        path = Path(output_path)
+        yaml_dir = path.parent
+        with open(output_path, "w") as output:
+            for line in lines:
+                output.write(
+                    line.replace("Meta-Llama-3.1-8B", base_name).replace(
+                        "WORK_DIR", str(yaml_dir)
+                    )
+                )
+
+
+# copy the files and change the yaml file to use the correct model name
+def copy_and_prepare(args):
+    # nltk punkt_tab package is needed
+    nltk.download('punkt_tab')
+    if not os.path.exists(args.work_dir):
+        # Copy all the files, including yaml and python files, from the template folder to the work folder
+
+        copy_dir(args.template_dir, args.work_dir)
+    else:
+        print("work_dir already exists, no need to copy files")
+    # Use the template yaml to get the correct model name in work_dir yaml
+    base_name = (
+        args.evals_dataset.split("/")[-1].replace("-evals", "").replace("-Instruct", "")
+    )
+    change_yaml(args, base_name)
+
+
+def parse_eval_args():
+    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
+    parser.add_argument(
+        "--config_path",
+        type=str,
+        default="./eval_config.yaml",
+        help="the config yaml file that contains all the eval parameters",
+    )
+    return parser.parse_args()
+
+
+def prepare_datasets(args):
+    # Prepare the datasets for the IFeval and MATH_Hard tasks, as we need to join the original datasets with the evals dataset on the actual questions.
+    # model_name is derived from the evals_dataset name
+    task_list = args.tasks.split(",")
+    model_name = args.evals_dataset.split("/")[-1].replace("-evals", "")
+    if "meta_instruct" in task_list:
+        get_ifeval_data(model_name, args.work_dir)
+
+        get_math_data(model_name, args.work_dir)
+    else:
+        if "meta_ifeval" in task_list:
+            get_ifeval_data(model_name, args.work_dir)
+        if "meta_math_hard" in task_list:
+            get_math_data(model_name, args.work_dir)
+
+
+# copy the files from src to dst
+def copy_dir(src, dst):
+    try:
+        shutil.copytree(src, dst)
+    except OSError as exc:  # python >2.5
+        if exc.errno in (errno.ENOTDIR, errno.EINVAL):
+            shutil.copy(src, dst)
+        else:
+            raise
+
+
+# load the config yaml file
+def load_config(config_path: str = "./config.yaml"):
+    # Read the YAML configuration file
+    with open(config_path, "r") as file:
+        config = yaml.safe_load(file)
+    return config
+
+
+if __name__ == "__main__":
+    args = parse_eval_args()
+    config = load_config(args.config_path)
+    # Create VLLM model args
+    for k, v in config.items():
+        args.__setattr__(k, v)
+    if not os.path.exists(args.template_dir):
+        raise ValueError("The template_dir does not exist, please check the path")
+    if args.evals_dataset not in [
+        "meta-llama/Meta-Llama-3.1-8B-Instruct-evals",
+        "meta-llama/Meta-Llama-3.1-70B-Instruct-evals",
+        "meta-llama/Meta-Llama-3.1-405B-Instruct-evals",
+        "meta-llama/Meta-Llama-3.1-8B-evals",
+        "meta-llama/Meta-Llama-3.1-70B-evals",
+        "meta-llama/Meta-Llama-3.1-405B-evals",
+    ]:
+        raise ValueError(
+            "The evals dataset is not valid, please double check the name, must use the name in the Llama 3.1 Evals collection"
+        )
+    args.model_args = f"pretrained={args.model_name},tensor_parallel_size={args.tensor_parallel_size},dtype=auto,gpu_memory_utilization={args.gpu_memory_utilization},data_parallel_size={args.data_parallel_size},max_model_len={args.max_model_len},add_bos_token=True,seed=42"
+    # Copy all the files from the template folder to the work folder
+    copy_and_prepare(args)
+    # Prepare the datasets for the IFeval and MATH_Hard tasks as we need to join the original dataset
+    prepare_datasets(args)
+    print(
+        f"preparation for {args.model_name} using {args.evals_dataset} is done, all saved in the work_dir: {args.work_dir}"
+    )
+    command_str = f"lm_eval --model vllm --model_args {args.model_args} --tasks {args.tasks} --batch_size auto --output_path {args.output_path} --include_path {os.path.abspath(args.work_dir)} --seed 42"
+    if args.limit:
+        command_str += f" --limit {args.limit}"
+    if args.log_samples:
+        command_str += " --log_samples "
+    if args.show_config:
+        command_str += " --show_config "
+    print("please use the following command to run the meta reproduce evals:")
+    print(command_str)
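
Putting the pieces together, the intended workflow (as I read this script) is: fill in eval_config.yaml, run prepare_meta_eval.py to copy the templates into work_dir and build the joined parquet files for IFEval and MATH-Hard, then run the lm_eval command that it prints. A rough sketch of driving the preparation step from Python, with the default config path assumed:

    import subprocess

    # Sketch only: step 1 copies templates, prepares the parquet files, and prints the lm_eval command.
    subprocess.run(
        ["python", "prepare_meta_eval.py", "--config_path", "./eval_config.yaml"],
        check=True,
    )

    # Step 2: the script only prints the command, it does not run it. Copy the printed
    # command into your shell, e.g. (defaults from eval_config.yaml, wrapped for readability):
    #   lm_eval --model vllm --model_args <printed model_args> --tasks meta_instruct \
    #     --batch_size auto --output_path eval_results --include_path <absolute path to work_dir> \
    #     --seed 42 --log_samples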