llama_mmlu_pro.py

import typing as t

import dspy
from datasets import load_dataset

from .datatypes import TaskDatasets
from .helpers import train_val_test_split


def datasets(
    train_size: float = 0.1,
    validation_size: float = 0.2,
) -> TaskDatasets:
  10. """
  11. Load dataset, dataset should be datasets.Dataset type (NOT DatasetDict, OR split the dataset yourself how you want)
  12. """
    dataset = load_dataset(
        "meta-llama/Llama-3.3-70B-Instruct-evals",
        "Llama-3.3-70B-Instruct-evals__mmlu_pro__details",
    )
    return train_val_test_split(
        dataset["latest"],
        _task_doc_example,
        train_size,
        validation_size,
    )
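
# Usage sketch (hypothetical attribute names): this assumes TaskDatasets
# exposes train/validation/test collections of the dspy.Example objects
# built by _task_doc_example below; check .datatypes for the actual fields.
#
#   splits = datasets(train_size=0.1, validation_size=0.2)
#   print(len(splits.train), len(splits.validation), len(splits.test))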


class TaskDoc(t.TypedDict):
    task_type: str
    task_name: str
    subtask_name: str
    input_question: str
    input_choice_list: dict
    input_final_prompts: list
    input_correct_responses: list
    output_prediction_text: list
    output_parsed_answer: str
    output_choice_completions: t.Optional[dict]
    output_choice_negative_log_likelihoods: t.Optional[dict]
    output_metrics: dict
    is_correct: bool
    input_question_hash: str
    input_final_prompts_hash: list
    benchmark_label: str
    eval_config: dict


inputs = ["input_question", "input_choice_list"]
outputs = ["output_parsed_answer"]


def _task_doc_example(doc: TaskDoc) -> dspy.Example:
    example = dspy.Example(
        question=doc["input_question"],
        options=doc["input_choice_list"],
        answer=doc["output_parsed_answer"],
    )
    # Mark input vs. output fields directly on the example; for the inputs,
    # the public equivalent is example.with_inputs("question", "options").
    example._input_keys = {"question", "options"}
    example._output_keys = {"answer"}
    return example


def signature(instructions: str = "") -> t.Type[dspy.Signature]:
    class MMLUPro(dspy.Signature):
        __doc__ = instructions

        question: str = dspy.InputField(desc="The question to be answered")
        options: dict = dspy.InputField(desc="Dictionary of answer choices")
        answer: str = dspy.OutputField(desc="The correct answer letter")

    return MMLUPro
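
# Usage sketch: the returned signature class plugs into any dspy module.
# Assumes an LM has been configured first (e.g. dspy.configure(lm=...)).
#
#   predictor = dspy.Predict(signature("Answer with the option letter only."))
#   result = predictor(question="What is 2 + 2?", options={"A": "3", "B": "4"})
#   print(result.answer)  # expected: "B"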


def metric(gold: dspy.Example, pred: dspy.Example, trace=False) -> bool:
    # Exact match on the parsed answer letter.
    return gold.answer == pred.answer
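

if __name__ == "__main__":
    # Minimal end-to-end smoke test; a sketch, not part of the task module.
    # Assumptions: an LM is already configured, e.g. via
    # dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")), and TaskDatasets
    # exposes a `train` sequence (hypothetical attribute; see .datatypes).
    splits = datasets()
    predictor = dspy.Predict(signature("Answer with the letter of the correct option."))
    gold = splits.train[0]
    pred = predictor(question=gold.question, options=gold.options)
    print("correct:", metric(gold, pred))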