llama_mmlu_pro.py

import typing as t

import dspy
from datasets import load_dataset

from .datatypes import TaskDatasets
from .helpers import fixed_split, train_val_test_split


def datasets(
    train_size: float = 0.1,
    validation_size: float = 0.2,
) -> TaskDatasets:
    """
    Load the dataset. The split passed to fixed_split should be a
    datasets.Dataset (NOT a DatasetDict); otherwise split the dataset
    yourself however you want.
    """
    # NOTE: train_size and validation_size are currently unused; the split is
    # delegated to fixed_split.
    dataset = load_dataset(
        "meta-llama/Llama-3.3-70B-Instruct-evals",
        "Llama-3.3-70B-Instruct-evals__mmlu_pro__details",
    )
    return fixed_split(dataset["latest"], _task_doc_example)


class TaskDoc(t.TypedDict):
    task_type: str
    task_name: str
    subtask_name: str
    input_question: str
    input_choice_list: dict
    input_final_prompts: list
    input_correct_responses: list
    output_prediction_text: list
    output_parsed_answer: str
    output_choice_completions: t.Optional[dict]
    output_choice_negative_log_likelihoods: t.Optional[dict]
    output_metrics: dict
    is_correct: bool
    input_question_hash: str
    input_final_prompts_hash: list
    benchmark_label: str
    eval_config: dict


# Raw TaskDoc columns used as model inputs and outputs.
inputs = ["input_question", "input_choice_list"]
outputs = ["output_parsed_answer"]


def _task_doc_example(doc: TaskDoc) -> dspy.Example:
    """Convert a raw TaskDoc row into a dspy.Example with input/output keys set."""
    example = dspy.Example(
        question=doc["input_question"],
        options=doc["input_choice_list"],
        answer=doc["output_parsed_answer"],
    )
    example._input_keys = {"question", "options"}
    example._output_keys = {"answer"}
    return example
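

# Illustrative conversion (hypothetical values, for reference only): a raw row
# such as
#   {"input_question": "What is 2 + 2?",
#    "input_choice_list": {"A": "3", "B": "4", "C": "5", "D": "22"},
#    "output_parsed_answer": "B", ...}
# maps to dspy.Example(question=..., options=..., answer="B") with
# input keys {"question", "options"} and output key {"answer"}.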


def signature(instructions: str = "") -> dspy.Signature:
    class MMLUPro(dspy.Signature):
        __doc__ = instructions

        question: str = dspy.InputField(desc="The question to be answered")
        options: dict = dspy.InputField(desc="Dictionary of answer choices")
        answer: str = dspy.OutputField(desc="The correct answer letter")

    return MMLUPro


def metric(gold: dspy.Example, pred: dspy.Example, trace=False) -> bool:
    """Exact match on the parsed answer letter."""
    return gold.answer == pred.answer
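

# Usage sketch (illustrative only, not part of this module). Assumes a recent
# DSPy release where dspy.LM, dspy.configure, and dspy.Predict exist with these
# call signatures; the model name below is a placeholder.
#
#   dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))
#   predictor = dspy.Predict(signature("Answer with the letter of the correct choice."))
#   pred = predictor(
#       question="What is 2 + 2?",
#       options={"A": "3", "B": "4", "C": "5", "D": "22"},
#   )
#   gold = dspy.Example(
#       question="What is 2 + 2?",
#       options={"A": "3", "B": "4", "C": "5", "D": "22"},
#       answer="B",
#   )
#   print(metric(gold, pred))  # True if the predicted letter matches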