import json
import re
import typing as t

import dspy
from datasets import load_dataset

from .datatypes import TaskDatasets
from .helpers import train_val_test_split

def datasets(
    train_size: float = 0.1,
    validation_size: float = 0.2,
) -> TaskDatasets:
    """
    Load the MMLU-Pro details config of the Llama-3.3-70B-Instruct evals.

    The loaded split must be a datasets.Dataset (not a DatasetDict); if you
    need a different partitioning, split the dataset yourself instead of
    relying on train_val_test_split.
    """
    dataset = load_dataset(
        "meta-llama/Llama-3.3-70B-Instruct-evals",
        "Llama-3.3-70B-Instruct-evals__mmlu_pro__details",
    )
    return train_val_test_split(
        dataset["latest"],
        _task_doc_example,
        train_size,
        validation_size,
    )

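# Usage sketch: with the defaults above, 10% of the "latest" split goes to
# train and 20% to validation (assuming train_val_test_split assigns the
# remaining records to test; the attribute names below are assumptions):
#
#   splits = datasets()
#   print(len(splits.train), len(splits.validation), len(splits.test))
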
class TaskDoc(t.TypedDict):
    task_type: str
    task_name: str
    subtask_name: str
    input_question: str
    input_choice_list: dict
    input_final_prompts: list
    input_correct_responses: list
    output_prediction_text: list
    output_parsed_answer: str
    output_choice_completions: t.Optional[dict]
    output_choice_negative_log_likelihoods: t.Optional[dict]
    output_metrics: dict
    is_correct: bool
    input_question_hash: str
    input_final_prompts_hash: list
    benchmark_label: str
    eval_config: dict

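# Illustrative shape of the three fields this module actually consumes
# (values are invented for documentation; real records carry every field
# declared in TaskDoc):
#
#   {
#       "input_question": "Which planet is known as the Red Planet?",
#       "input_choice_list": {"A": "Venus", "B": "Mars", "C": "Jupiter"},
#       "output_parsed_answer": "B",
#   }
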
# TaskDoc fields consumed as model inputs and as the gold output.
inputs = ["input_question", "input_choice_list"]
outputs = ["output_parsed_answer"]

class CustomJSONAdapter(dspy.JSONAdapter):
    def parse(self, signature, completion):
        try:
            try:
                fields = json.loads(completion)
            except json.JSONDecodeError:
                # Non-JSON completion: keep the raw text as reasoning and
                # let the answer extraction below recover a letter from it.
                fields = {"reasoning": completion}
            if isinstance(fields, list):
                fields = fields[0] if fields else {}
            if not isinstance(fields, dict):
                # Non-object JSON (string, number, ...): treat it as text.
                fields = {"reasoning": str(fields)}
            if "reasoning" not in fields:
                fields["reasoning"] = ""
            if "answer" not in fields:
                reasoning = fields.get("reasoning", "")
                # Prefer an explicit "answer is X"; fall back to the first
                # standalone capital letter A-J in the reasoning.
                match = re.search(
                    r"answer\s+is\s+([A-J])\b", reasoning, re.IGNORECASE
                )
                if not match:
                    match = re.search(r"\b([A-J])\b", reasoning)
                fields["answer"] = match.group(1).upper() if match else ""
            return fields
        except Exception:
            # Never raise from parsing; fall back to an empty prediction.
            return {"reasoning": "", "answer": ""}

def signature(instructions: str = "") -> dspy.Signature:
    """Define the signature for the MMLU-Pro task.

    Note: the instructions argument is accepted but not currently applied
    to the signature.
    """

    class MMLUPro(dspy.Signature):
        """Multiple choice question answering with reasoning."""

        question: str = dspy.InputField(desc="The question to be answered")
        options: dict = dspy.InputField(desc="Dictionary of answer choices")
        reasoning: str = dspy.OutputField(
            desc="Step-by-step reasoning to arrive at the answer"
        )
        answer: str = dspy.OutputField(desc="The correct answer letter (A-J)")

    # Install the lenient parser so malformed completions still yield a
    # usable {"reasoning", "answer"} dict.
    dspy.settings.configure(adapter=CustomJSONAdapter())
    return MMLUPro

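# Usage sketch (assumes an LM has already been configured, e.g. via
# dspy.settings.configure(lm=...); the question and options are invented):
#
#   predictor = dspy.Predict(signature())
#   pred = predictor(question="2 + 2 = ?", options={"A": "3", "B": "4"})
#   print(pred.answer)  # expected "B"
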
def _task_doc_example(doc: TaskDoc) -> dspy.Example:
    """Create an example with proper input/output key configuration."""
    example = dspy.Example(
        question=doc["input_question"],
        options=doc["input_choice_list"],
        reasoning="",  # Initialize empty reasoning
        answer=doc["output_parsed_answer"] or "",
    )
    example._input_keys = {"question", "options"}
    example._output_keys = {"reasoning", "answer"}
    return example

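# Setting _input_keys by hand mirrors what
# dspy.Example(...).with_inputs("question", "options") would mark as inputs;
# afterwards example.inputs() exposes only the question/options pair.
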
def metric(gold: dspy.Example, pred: dspy.Example, trace=False) -> bool:
    """
    Compare gold and predicted answers while tolerating varied response
    formats. If the prediction lacks an answer field, try to extract one
    from its reasoning text.
    """
    try:
        pred_dict = pred if isinstance(pred, dict) else pred.__dict__
        reasoning = pred_dict.get("reasoning", "")
        if isinstance(reasoning, str) and "answer" not in pred_dict:
            # Same recovery strategy as CustomJSONAdapter.parse: prefer an
            # explicit "answer is X", then any standalone letter A-J.
            match = re.search(r"answer\s+is\s+([A-J])\b", reasoning, re.IGNORECASE)
            if not match:
                match = re.search(r"\b([A-J])\b", reasoning)
            if match:
                pred_dict["answer"] = match.group(1).upper()
        pred_answer = pred_dict.get("answer", "")
        if isinstance(pred_answer, str):
            pred_answer = pred_answer.strip().upper()
            if len(pred_answer) > 1:
                pred_answer = pred_answer[0]  # keep only the leading letter
        gold_answer = gold.answer if hasattr(gold, "answer") else ""
        if isinstance(gold_answer, str):
            gold_answer = gold_answer.strip().upper()
        # An empty answer on either side counts as a miss.
        if not gold_answer or not pred_answer:
            return False
        return gold_answer == pred_answer
    except Exception as e:
        if trace:
            print(f"Error in metric: {e}")
        return False

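
if __name__ == "__main__":
    # Smoke-test sketch (no LM required): exercise the adapter's non-JSON
    # fallback and the metric on a hand-written gold example. The strings
    # below are illustrative, not taken from the dataset.
    adapter = CustomJSONAdapter()
    parsed = adapter.parse(None, "Nitrogen dominates, so the answer is B.")
    gold = dspy.Example(question="?", options={}, reasoning="", answer="B")
    print(parsed)                # expect answer "B" extracted from the text
    print(metric(gold, parsed))  # expect True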
|