evaluator.py

import dspy
from typing import Dict, List
from dataclasses import dataclass


@dataclass
class EvaluationMetrics:
    accuracy: float
    similarity: float
    consistency: float


class PromptEvaluator:
    def __init__(self, source_lm: dspy.OpenAI, target_lm: dspy.LM):
        self.source_lm = source_lm
        self.target_lm = target_lm

    def _create_judge(self):
        """Create an LLM judge to evaluate prompt outputs."""
        class FactJudge(dspy.Signature):
            """Judge if the migrated prompt produces equivalent outputs."""
            source_output = dspy.InputField(desc="Output from source model")
            target_output = dspy.InputField(desc="Output from target model")
            factually_correct = dspy.OutputField(
                desc="Is the target output equivalent to the source output in terms of content and intent?",
                prefix="Factual[Yes/No]:"
            )
            reasoning = dspy.OutputField(desc="Explanation for the judgment")

        return dspy.ChainOfThought(FactJudge)

    def _get_model_output(self, model, text: str) -> str:
        """Get a single completion from a model, tolerating different client interfaces."""
        try:
            # DSPy model interfaces vary across versions and providers.
            if callable(model):
                result = model(text)
            elif hasattr(model, "generate"):
                result = model.generate(text)
            elif hasattr(model, "complete"):
                result = model.complete(text)
            else:
                raise AttributeError(f"Model {type(model)} has no supported generation method")
            # DSPy clients typically return a list of completions; take the first one.
            if isinstance(result, list):
                return result[0] if result else ""
            return str(result)
        except Exception as e:
            print(f"Error generating output with {type(model)}: {e}")
            return ""

    def _calculate_metrics(self, test_cases: List[Dict]) -> EvaluationMetrics:
        """Calculate evaluation metrics using an LLM as judge."""
        if not test_cases:
            raise ValueError("test_cases must not be empty")

        total_similarity = 0.0
        total_accuracy = 0.0
        total_consistency = 0.0
        judge = self._create_judge()

        for case in test_cases:
            source_output = self._get_model_output(self.source_lm, case["text"])
            target_output = self._get_model_output(self.target_lm, case["text"])

            judgment = judge(
                source_output=source_output,
                target_output=target_output,
            )

            # The judge is prompted for a Yes/No verdict; tolerate trailing text or punctuation.
            is_equivalent = judgment.factually_correct.strip().lower().startswith("yes")

            # Accuracy is an exact (case-insensitive) match against the expected summary;
            # similarity and consistency both come from the judge's equivalence verdict.
            accuracy = float(target_output.strip().lower() == case["expected_summary"].strip().lower())
            similarity = float(is_equivalent)
            consistency = float(is_equivalent)

            total_similarity += similarity
            total_accuracy += accuracy
            total_consistency += consistency
            print(f"\nJudge's reasoning: {judgment.reasoning}")

        n = len(test_cases)
        return EvaluationMetrics(
            accuracy=total_accuracy / n,
            similarity=total_similarity / n,
            consistency=total_consistency / n,
        )

    def evaluate(
        self,
        source_prompt: str,
        target_prompt: str,
        test_cases: List[Dict],
    ) -> EvaluationMetrics:
        """Evaluate the quality of a prompt migration using an LLM as judge.

        The prompt strings are kept for interface compatibility; metrics are
        computed by running the test cases against the source and target models.
        """
        return self._calculate_metrics(test_cases)
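

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring up PromptEvaluator. The model names, test case,
# and prompts below are hypothetical; adjust them to whatever your DSPy version
# and providers expose, and make sure the relevant API keys are configured.
if __name__ == "__main__":
    # Assumes a DSPy version that exposes the unified dspy.LM client.
    source_lm = dspy.LM("openai/gpt-4o-mini")
    target_lm = dspy.LM("ollama_chat/llama3.1")
    dspy.configure(lm=target_lm)  # the ChainOfThought judge runs on the configured default LM

    evaluator = PromptEvaluator(source_lm=source_lm, target_lm=target_lm)
    test_cases = [
        {
            "text": "Summarize: The quick brown fox jumps over the lazy dog.",
            "expected_summary": "A fox jumps over a dog.",
        }
    ]
    metrics = evaluator.evaluate(
        source_prompt="Summarize the following text.",
        target_prompt="Summarize the following text.",
        test_cases=test_cases,
    )
    print(metrics)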