llama_mmlu_pro.py

import json
import re
import typing as t

import dspy
from datasets import load_dataset

from .datatypes import TaskDatasets
from .helpers import train_val_test_split


def datasets(
    train_size: float = 0.1,
    validation_size: float = 0.2,
) -> TaskDatasets:
    """
    Load the MMLU Pro details split of the Llama 3.3 70B Instruct evals.

    The object handed to train_val_test_split must be a datasets.Dataset
    (not a DatasetDict); split the data yourself first if you need different
    handling.
    """
    dataset = load_dataset(
        "meta-llama/Llama-3.3-70B-Instruct-evals",
        "Llama-3.3-70B-Instruct-evals__mmlu_pro__details",
    )
    return train_val_test_split(
        dataset["latest"],
        _task_doc_example,
        train_size,
        validation_size,
    )
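
# Usage sketch (illustrative, commented out): train_val_test_split comes from
# .helpers, and the attribute names on the returned TaskDatasets bundle are
# assumptions, not verified against .datatypes.
#
#   splits = datasets(train_size=0.1, validation_size=0.2)
#   print(len(splits.train), len(splits.validation), len(splits.test))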


class TaskDoc(t.TypedDict):
    task_type: str
    task_name: str
    subtask_name: str
    input_question: str
    input_choice_list: dict
    input_final_prompts: list
    input_correct_responses: list
    output_prediction_text: list
    output_parsed_answer: str
    output_choice_completions: t.Optional[dict]
    output_choice_negative_log_likelihoods: t.Optional[dict]
    output_metrics: dict
    is_correct: bool
    input_question_hash: str
    input_final_prompts_hash: list
    benchmark_label: str
    eval_config: dict


inputs = ["input_question", "input_choice_list"]
outputs = ["output_parsed_answer"]


class CustomJSONAdapter(dspy.JSONAdapter):
    def parse(self, signature, completion):
        try:
            try:
                fields = json.loads(completion)
            except (json.JSONDecodeError, TypeError):
                # Not valid JSON: treat the raw text as the reasoning.
                fields = {"reasoning": completion, "answer": ""}
            if isinstance(fields, list):
                fields = fields[0] if fields else {"reasoning": "", "answer": ""}
            if "reasoning" not in fields:
                fields["reasoning"] = ""
            if not fields.get("answer"):
                # No usable answer field: fall back to pulling a letter answer
                # out of the reasoning text.
                reasoning = fields.get("reasoning", "")
                match = re.search(
                    r"\b([A-J])\b|answer\s+is\s+([A-J])\b", reasoning, re.IGNORECASE
                )
                fields["answer"] = (
                    (match.group(1) or match.group(2)).upper() if match else ""
                )
            return fields
        except Exception:
            return {"reasoning": "", "answer": ""}
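
# Behavior sketch (illustrative, commented out): when the completion is not
# JSON, parse() keeps the raw text as the reasoning and recovers the answer
# letter via the regex fallback.
#
#   adapter = CustomJSONAdapter()
#   adapter.parse(signature(), "The answer is C because ...")
#   # -> {"reasoning": "The answer is C because ...", "answer": "C"}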


def signature(instructions: str = "") -> dspy.Signature:
    """Define the signature for the MMLU Pro task."""

    class MMLUPro(dspy.Signature):
        """Multiple choice question answering with reasoning."""

        question: str = dspy.InputField(desc="The question to be answered")
        options: dict = dspy.InputField(desc="Dictionary of answer choices")
        reasoning: str = dspy.OutputField(
            desc="Step by step reasoning to arrive at the answer"
        )
        answer: str = dspy.OutputField(desc="The correct answer letter (A-J)")

    dspy.settings.configure(adapter=CustomJSONAdapter())
    return MMLUPro
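
# Usage sketch (commented out; assumes an LM has been configured elsewhere,
# e.g. via dspy.configure(lm=...)): the question and options below are
# placeholders.
#
#   MMLUPro = signature()
#   predict = dspy.Predict(MMLUPro)
#   result = predict(question="What is 2 + 2?", options={"A": "3", "B": "4"})
#   print(result.answer)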


def _task_doc_example(doc: TaskDoc) -> dspy.Example:
    """Create an example with proper input/output key configuration."""
    example = dspy.Example(
        question=doc["input_question"],
        options=doc["input_choice_list"],
        reasoning="",  # Initialize empty reasoning
        answer=doc["output_parsed_answer"] if doc["output_parsed_answer"] else "",
    )
    example._input_keys = {"question", "options"}
    example._output_keys = {"reasoning", "answer"}
    return example
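
# Mapping sketch (illustrative, commented out): only the three fields read
# above are shown; a real TaskDoc row carries the full schema.
#
#   doc = {
#       "input_question": "What is 2 + 2?",
#       "input_choice_list": {"A": "3", "B": "4"},
#       "output_parsed_answer": "B",
#   }
#   ex = _task_doc_example(doc)
#   # ex.question == "What is 2 + 2?", ex.answer == "B"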


def metric(gold: dspy.Example, pred: dspy.Example, trace=False) -> bool:
    """
    Compare gold and predicted answers while handling various response formats.

    If the prediction lacks an answer field, try to extract a letter answer
    from its reasoning text before comparing.
    """
    try:
        pred_dict = pred if isinstance(pred, dict) else pred.__dict__
        reasoning = pred_dict.get("reasoning", "")
        if isinstance(reasoning, str) and "answer" not in pred_dict:
            match = re.search(
                r"\b([A-J])\b|answer\s+is\s+([A-J])\b", reasoning, re.IGNORECASE
            )
            if match:
                answer = match.group(1) or match.group(2)
                pred_dict["answer"] = answer.upper()
        pred_answer = pred_dict.get("answer", "")
        if isinstance(pred_answer, str):
            pred_answer = pred_answer.strip().upper()
            if len(pred_answer) > 1:
                # Keep only the leading letter of a longer response.
                pred_answer = pred_answer[0]
        gold_answer = gold.answer if hasattr(gold, "answer") else ""
        if isinstance(gold_answer, str):
            gold_answer = gold_answer.strip().upper()
        # Treat a missing answer on either side as incorrect.
        if not gold_answer or not pred_answer:
            return False
        return gold_answer == pred_answer
    except Exception as e:
        if trace:
            print(f"Error in metric: {str(e)}")
        return False
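
# Evaluation sketch (illustrative, commented out): the metric plugs into
# dspy.Evaluate; `devset` stands in for a split produced by datasets() and is
# an assumption about how this module gets wired together.
#
#   evaluator = dspy.Evaluate(devset=devset, metric=metric, num_threads=4)
#   evaluator(dspy.Predict(signature()))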