llama_mmlu.py

import typing as t

import dspy
from datasets import load_dataset

from .datatypes import TaskDatasets
from .helpers import fixed_split, train_val_test_split


def datasets(
    train_size: float = 0.1,
    validation_size: float = 0.1,
) -> TaskDatasets:
    """
    Load the Llama 3.3 70B MMLU eval details and split them into train,
    validation, and test sets.

    The splitter expects a datasets.Dataset, not a DatasetDict; if you load a
    DatasetDict, select a split (as done with "latest" below) or split it
    yourself (e.g., with fixed_split).
    """
    dataset = load_dataset(
        "meta-llama/Llama-3.3-70B-Instruct-evals",
        "Llama-3.3-70B-Instruct-evals__mmlu__0_shot__cot__details",
    )
    # Forward the split sizes (assumed to be accepted by helpers.train_val_test_split).
    return train_val_test_split(
        dataset["latest"],
        _task_doc_example,
        train_size=train_size,
        validation_size=validation_size,
    )
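

# Quick way to inspect a raw "latest" record before it is mapped to a
# dspy.Example; a sketch only, but the field names match TaskDoc below:
#
#   raw = load_dataset(
#       "meta-llama/Llama-3.3-70B-Instruct-evals",
#       "Llama-3.3-70B-Instruct-evals__mmlu__0_shot__cot__details",
#   )["latest"]
#   print(raw[0]["input_question"], raw[0]["input_choice_list"])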


class TaskDoc(t.TypedDict):
    """Schema of a single record in the eval-details dataset."""

    task_type: str
    task_name: str
    subtask_name: str
    input_question: str
    input_choice_list: dict
    input_final_prompts: list
    input_correct_responses: list
    output_prediction_text: list
    output_parsed_answer: str
    output_choice_completions: t.Optional[int]
    output_choice_negative_log_likelihoods: t.Optional[int]
    output_metrics: dict
    is_correct: bool
    input_question_hash: str
    input_final_prompts_hash: list
    benchmark_label: str
    eval_config: dict
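

# Illustrative TaskDoc record, trimmed to the three fields consumed below
# (the values are hypothetical, not taken from the dataset):
#
#   {
#       "input_question": "What is the capital of France?",
#       "input_choice_list": {"A": "Paris", "B": "Lyon", "C": "Nice", "D": "Lille"},
#       "output_parsed_answer": "A",
#   }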


def _task_doc_example(doc: TaskDoc) -> dspy.Example:
    # Get reasoning from output_prediction_text if available
    # reasoning = (
    #     doc["output_prediction_text"][0] if doc.get("output_prediction_text") else ""
    # )
    example = dspy.Example(
        question=doc["input_question"],
        options=doc["input_choice_list"],
        answer=doc["output_parsed_answer"],
        # reasoning=reasoning,
    )
    example._input_keys = {"question", "options"}
    example._output_keys = {"answer"}
    return example
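
# With the keys above, example.inputs() carries question/options to the
# predictor while the answer remains available as the label for metric().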


def signature(instructions: str = "") -> t.Type[dspy.Signature]:
    class MMLU(dspy.Signature):
        __doc__ = instructions

        question: str = dspy.InputField(desc="The question to be answered")
        options: dict = dspy.InputField(desc="Dictionary of answer choices")
        # reasoning: str = dspy.OutputField(
        #     desc="Step-by-step reasoning for arriving at the answer"
        # )
        answer: str = dspy.OutputField(desc="The correct answer letter")

    return MMLU


def metric(gold: dspy.Example, pred: dspy.Example, trace=None) -> bool:
    # Score answer accuracy only: exact match on the parsed answer letter.
    return gold.answer == pred.answer
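

if __name__ == "__main__":
    # Minimal end-to-end sketch wiring the pieces together. It assumes a
    # language model has been configured elsewhere, e.g.
    # dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")), and that TaskDatasets
    # exposes a `test` split of dspy.Example objects; both are assumptions.
    splits = datasets()
    predictor = dspy.Predict(signature("Answer the multiple-choice question."))
    example = splits.test[0]  # assumed attribute name on TaskDatasets
    pred = predictor(question=example.question, options=example.options)
    print("predicted:", pred.answer, "correct:", metric(example, pred))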