evaluator.py

import json
import os
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List

import dspy


@dataclass
class EvaluationMetrics:
    accuracy: float
    similarity: float
    consistency: float
    individual_scores: List[Dict]  # Store individual test case scores


class PromptEvaluator:
    def __init__(self, source_lm: dspy.LM, target_lm: dspy.LM):
        self.source_lm = source_lm
        self.target_lm = target_lm
        dspy.configure(lm=source_lm)  # Configure DSPy to use source_lm for the judge

    def _create_judge(self):
        """Create an LLM judge to evaluate outputs."""

        class OutputJudge(dspy.Signature):
            """Judge the quality and equivalence of outputs."""

            input_text = dspy.InputField(desc="The coding task")
            source_output = dspy.InputField(desc="Output from source prompt")
            target_output = dspy.InputField(desc="Output from target prompt")
            expected_output = dspy.InputField(desc="Expected output from dataset")
            equivalence = dspy.OutputField(
                desc="Are the outputs functionally equivalent to the expected output? Answer ONLY with 'yes' or 'no'."
            )
            accuracy = dspy.OutputField(
                desc="Rate how well the outputs match the expected output. Provide ONLY a number between 0 and 100, no text."
            )
            consistency = dspy.OutputField(
                desc="Rate how consistent the outputs are with each other. Provide ONLY a number between 0 and 100, no text."
            )
            reasoning = dspy.OutputField(
                desc="Explain your evaluation, focusing on functionality and correctness."
            )

        class Judge(dspy.Module):
            def __init__(self):
                super().__init__()
                self.judge = dspy.ChainOfThought(OutputJudge)

            def forward(self, input_text, source_output, target_output, expected_output):
                try:
                    result = self.judge(
                        input_text=input_text,
                        source_output=source_output,
                        target_output=target_output,
                        expected_output=expected_output,
                    )

                    # Ensure numeric scores: pull the first number out of the
                    # judge's free-form answer, defaulting to 0.0 if none is found.
                    def clean_score(score):
                        try:
                            numbers = re.findall(r"\d+", str(score))
                            return float(numbers[0]) if numbers else 0.0
                        except Exception:
                            return 0.0

                    result.accuracy = clean_score(result.accuracy)
                    result.consistency = clean_score(result.consistency)
                    result.equivalence = str(result.equivalence).lower().strip()
                    return result
                except Exception as e:
                    print(f"Error in judge: {e}")
                    # Return default (failing) scores so evaluation can continue
                    return type("Result", (), {
                        "accuracy": 0.0,
                        "consistency": 0.0,
                        "equivalence": "no",
                        "reasoning": f"Error in evaluation: {e}",
                    })()

        return Judge()

    def _get_model_output(self, prompt: str, input_text: str) -> str:
        """Get output from the target model using the provided prompt."""
        try:
            formatted_prompt = prompt.format(text=input_text)
            response = self.target_lm(formatted_prompt)
            if isinstance(response, list):
                return response[0] if response else ""
            return str(response)
        except Exception as e:
            print(f"Error generating output: {e}")
            return ""

    def _calculate_metrics(self, source_prompt: str, target_prompt: str,
                           test_cases: List[Dict]) -> EvaluationMetrics:
        """Calculate evaluation metrics using the target model for both prompts."""
        total_similarity = 0.0
        total_accuracy = 0.0
        total_consistency = 0.0
        individual_scores = []

        judge = self._create_judge()
        num_cases = len(test_cases)
        if num_cases == 0:
            raise ValueError("test_cases must not be empty")

        for case in test_cases:
            input_text = case["text"]
            expected = case["expected_answer"]

            # Get outputs from the target model using both prompts
            source_output = self._get_model_output(source_prompt, input_text)
            target_output = self._get_model_output(target_prompt, input_text)

            judgment = judge(
                input_text=input_text,
                source_output=source_output,
                target_output=target_output,
                expected_output=expected,
            )

            # Normalize judge scores from 0-100 to 0-1
            accuracy_score = float(judgment.accuracy) / 100
            consistency_score = float(judgment.consistency) / 100
            is_equivalent = judgment.equivalence.lower() == "yes"

            # Store individual scores
            case_scores = {
                "input": input_text,
                "expected": expected,
                "source_output": source_output,
                "target_output": target_output,
                "accuracy": accuracy_score,
                "consistency": consistency_score,
                "equivalent": is_equivalent,
                "reasoning": judgment.reasoning,
            }
            individual_scores.append(case_scores)

            # Update totals
            total_accuracy += accuracy_score
            total_consistency += consistency_score
            total_similarity += float(is_equivalent)

            print(f"\nEvaluation for case: {input_text[:50]}...")
            print(f"Source output: {source_output[:100]}...")
            print(f"Target output: {target_output[:100]}...")
            print(f"Expected: {expected[:100]}...")
            print(f"Judge's reasoning: {judgment.reasoning}")
            print(f"Scores - Accuracy: {accuracy_score:.2f}, "
                  f"Consistency: {consistency_score:.2f}, Equivalent: {is_equivalent}")

        # Calculate final metrics
        metrics = EvaluationMetrics(
            accuracy=total_accuracy / num_cases,
            similarity=total_similarity / num_cases,
            consistency=total_consistency / num_cases,
            individual_scores=individual_scores,
        )

        # Save full results (prompts, aggregate metrics, per-case scores) to JSON
        results = {
            "source_prompt": source_prompt,
            "target_prompt": target_prompt,
            "aggregate_metrics": {
                "accuracy": metrics.accuracy,
                "similarity": metrics.similarity,
                "consistency": metrics.consistency,
            },
            "individual_scores": individual_scores,
        }
        self._save_results(results)

        return metrics

    def evaluate(self,
                 source_prompt: str,
                 target_prompt: str,
                 test_cases: List[Dict]) -> EvaluationMetrics:
        """Evaluate both prompts using the target model."""
        return self._calculate_metrics(source_prompt, target_prompt, test_cases)

    def _save_results(self, results: dict, filename: str = "results.json") -> None:
        """Save results to a JSON file, using a new timestamped name if the file already exists."""
        if os.path.exists(filename):
            # Create a new filename with a timestamp to avoid overwriting
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            base, ext = os.path.splitext(filename)
            filename = f"{base}_{timestamp}{ext}"

        with open(filename, "w") as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to {filename}")
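

# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal example of driving PromptEvaluator end to end. The model name,
# the prompts, and the test case below are assumptions for illustration only;
# substitute whichever dspy.LM-compatible models and dataset you actually use.
if __name__ == "__main__":
    source_lm = dspy.LM("openai/gpt-4o-mini")  # judge model (assumed name)
    target_lm = dspy.LM("openai/gpt-4o-mini")  # model under test (assumed name)

    evaluator = PromptEvaluator(source_lm=source_lm, target_lm=target_lm)

    # Prompts must contain a {text} placeholder, as expected by _get_model_output
    source_prompt = "Solve the following coding task:\n{text}"
    target_prompt = "You are a careful programmer. Complete this coding task:\n{text}"
    test_cases = [
        {
            "text": "Write a function that reverses a string.",
            "expected_answer": "def reverse(s):\n    return s[::-1]",
        },
    ]

    metrics = evaluator.evaluate(source_prompt, target_prompt, test_cases)
    print(f"Accuracy: {metrics.accuracy:.2f}, "
          f"Similarity: {metrics.similarity:.2f}, "
          f"Consistency: {metrics.consistency:.2f}")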