llama_mmlu_pro.py

import typing as t

import dspy
from datasets import load_dataset

from .datatypes import TaskDatasets
from .helpers import train_val_test_split
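

# datasets() pulls the published MMLU-Pro evaluation details for
# Llama 3.3 70B Instruct from the Hugging Face Hub and splits them into
# train/validation/test sets via the train_val_test_split helper.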
def datasets(
    train_size: float = 0.1,
    validation_size: float = 0.2,
) -> TaskDatasets:
    """
    Load the MMLU-Pro eval details and split them into train/validation/test.

    The object handed to train_val_test_split must be a datasets.Dataset, not
    a DatasetDict; if you start from a DatasetDict, pick a split (as done with
    "latest" below) or perform the splitting yourself.
    """
    dataset = load_dataset(
        "meta-llama/Llama-3.3-70B-Instruct-evals",
        "Llama-3.3-70B-Instruct-evals__mmlu_pro__details",
    )
    return train_val_test_split(
        dataset["latest"],
        _task_doc_example,
        train_size,
        validation_size,
    )
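

# TaskDoc mirrors the schema of one row in the evals dataset; only a subset of
# these columns is used when building DSPy examples (see inputs/outputs and
# _task_doc_example below).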
class TaskDoc(t.TypedDict):
    task_type: str
    task_name: str
    subtask_name: str
    input_question: str
    input_choice_list: dict
    input_final_prompts: list
    input_correct_responses: list
    output_prediction_text: list
    output_parsed_answer: str
    output_choice_completions: t.Optional[dict]
    output_choice_negative_log_likelihoods: t.Optional[dict]
    output_metrics: dict
    is_correct: bool
    input_question_hash: str
    input_final_prompts_hash: list
    benchmark_label: str
    eval_config: dict
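

# Dataset columns consumed as example inputs and expected outputs.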
  41. inputs = ["input_question", "input_choice_list"]
  42. outputs = ["output_parsed_answer"]
def _task_doc_example(doc: TaskDoc) -> dspy.Example:
    example = dspy.Example(
        question=doc["input_question"],
        options=doc["input_choice_list"],
        reasoning="",
        answer=doc["output_parsed_answer"],
    )
    example._input_keys = {"question", "options"}
    example._output_keys = {"reasoning", "answer"}
    return example
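

# signature() builds the DSPy Signature for the task; the optional
# instructions string becomes the signature's docstring, which DSPy uses as
# the task instructions in the prompt.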
  53. def signature(instructions: str = "") -> dspy.Signature:
  54. class MMLUPro(dspy.Signature):
  55. __doc__ = instructions
  56. question: str = dspy.InputField(desc="The question to be answered")
  57. options: dict = dspy.InputField(desc="Dictionary of answer choices")
  58. answer: str = dspy.OutputField(desc="The correct answer letter")
  59. return MMLUPro
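

# Exact-match metric: a prediction counts as correct only when its parsed
# answer letter matches the gold answer exactly.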
def metric(gold: dspy.Example, pred: dspy.Example, trace=False) -> bool:
    return gold.answer == pred.answer
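

# Minimal usage sketch (illustrative only, not part of this module). It assumes
# dspy.settings has already been configured with a language model and that the
# TaskDatasets returned by datasets() unpacks into train/validation/test
# sequences of dspy.Example:
#
#   train, validation, test = datasets()
#   predict = dspy.Predict(signature("Answer with the letter of the correct choice."))
#   example = test[0]
#   prediction = predict(question=example.question, options=example.options)
#   print(metric(example, prediction))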