
added evaluator and formatter and main

Justin Lee, 4 months ago
commit 2570d1642a

+ 61 - 0
recipes/use_cases/prompt-migration/main.py

@@ -0,0 +1,61 @@
+import dspy
+from prompt_migration.engine import PromptMigrationEngine, PromptTemplate
+from prompt_migration.evaluator import PromptEvaluator
+from prompt_migration.eval_dataset import get_evaluation_dataset, get_eval_subset
+
+import os
+import dotenv
+
+dotenv.load_dotenv()
+
+def main():
+    openai_lm = dspy.LM(
+        model="gpt-3.5-turbo",
+        api_key=os.getenv("OPENAI_API_KEY")
+    )
+    
+    # target_lm = dspy.LM(
+    #     model="together_ai/togethercomputer/meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
+    #     api_key=os.getenv("TOGETHER_API_KEY")
+    # )
+    # target_lm = dspy.LM('ollama_chat/llama3.2:3b-instruct-fp16', api_base='http://localhost:11434', api_key='')
+    target_lm = dspy.HFModel(model="gpt2")
+    
+    engine = PromptMigrationEngine(openai_lm, target_lm)
+    
+    source_prompt = PromptTemplate(
+        template="Write a Python function that takes as input a file path to an image, loads the image into memory as a numpy array, then crops the rows and columns around the perimeter if they are darker than a threshold value. Use the mean value of rows and columns to decide if they should be marked for deletion.",
+        input_variables=["text"],
+        model_type="openai"
+    )
+    
+    eval_dataset = get_evaluation_dataset()
+
+
+    # To evaluate on a specific subset, use the following:
+    # summarization_dataset = get_eval_subset(prompt_type="summarization")
+    # simple_tasks = get_eval_subset(complexity="simple")
+    
+    # Migrate prompt
+    print("Migrating prompt...")
+    migrated_prompt = engine.migrate_prompt(source_prompt, eval_dataset)
+    
+    # Evaluate migration
+    print("Evaluating migration...")
+    evaluator = PromptEvaluator(openai_lm, target_lm)
+    metrics = evaluator.evaluate(
+        source_prompt.template,
+        migrated_prompt.template,
+        eval_dataset
+    )
+    
+    print("\nResults:")
+    print(f"Original prompt: {source_prompt.template}")
+    print(f"Migrated prompt: {migrated_prompt.template}")
+    print("Evaluation metrics:")
+    print(f"  Accuracy: {metrics.accuracy:.2f}")
+    print(f"  Similarity: {metrics.similarity:.2f}")
+    print(f"  Consistency: {metrics.consistency:.2f}")
+
+if __name__ == "__main__":
+    main()
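
Usage note for the subset helpers that are commented out above: switching to a subset only requires swapping the dataset argument passed to the evaluator. A minimal sketch, reusing the names already defined in main.py (not part of this commit, untested):

    # Evaluate only the summarization subset instead of the full dataset
    summarization_dataset = get_eval_subset(prompt_type="summarization")
    subset_metrics = evaluator.evaluate(
        source_prompt.template,
        migrated_prompt.template,
        summarization_dataset,
    )
    print(f"  Subset accuracy: {subset_metrics.accuracy:.2f}")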

+ 90 - 0
recipes/use_cases/prompt-migration/prompt_migration/evaluator.py

@@ -0,0 +1,90 @@
+import dspy
+from typing import List, Dict
+from dataclasses import dataclass
+
+@dataclass
+class EvaluationMetrics:
+    accuracy: float
+    similarity: float
+    consistency: float
+
+class PromptEvaluator:
+    def __init__(self, source_lm: dspy.LM, target_lm: dspy.LM):
+        self.source_lm = source_lm
+        self.target_lm = target_lm
+        
+    def _create_judge(self):
+        """Create an LLM judge to evaluate prompt outputs."""
+        class FactJudge(dspy.Signature):
+            """Judge if the migrated prompt produces equivalent outputs."""
+            source_output = dspy.InputField(desc="Output from source model")
+            target_output = dspy.InputField(desc="Output from target model")
+            factually_correct = dspy.OutputField(
+                desc="Is the target output equivalent to the source output in terms of content and intent?",
+                prefix="Factual[Yes/No]:"
+            )
+            reasoning = dspy.OutputField(desc="Explanation for the judgment")
+
+        return dspy.ChainOfThought(FactJudge)
+
+    def _get_model_output(self, model, text: str) -> str:
+        """Helper function to get output from different model types."""
+        try:
+            # Try different methods since DSPy model interfaces can vary
+            if hasattr(model, '__call__'):
+                result = model(text)
+            elif hasattr(model, 'generate'):
+                result = model.generate(text)
+            elif hasattr(model, 'complete'):
+                result = model.complete(text)
+            else:
+                raise AttributeError(f"Model {type(model)} has no supported generation method")
+            # DSPy LMs typically return a list of completions; normalize to a single string
+            return result[0] if isinstance(result, list) else result
+        except Exception as e:
+            print(f"Error generating output with {type(model)}: {str(e)}")
+            return ""
+
+    def _calculate_metrics(self, test_cases):
+        """Calculate evaluation metrics using LLM as judge."""
+        total_similarity = 0.0
+        total_accuracy = 0.0
+        total_consistency = 0.0
+        
+        judge = self._create_judge()
+        
+        for case in test_cases:
+            source_output = self._get_model_output(self.source_lm, case["text"])
+            target_output = self._get_model_output(self.target_lm, case["text"])
+            
+            judgment = judge(
+                source_output=source_output,
+                target_output=target_output
+            )
+            
+            is_equivalent = judgment.factually_correct.strip().lower().startswith("yes")
+            
+            similarity = float(is_equivalent)
+            accuracy = float(target_output.lower() == case["expected_summary"].lower())
+            consistency = float(is_equivalent)
+            
+            total_similarity += similarity
+            total_accuracy += accuracy
+            total_consistency += consistency
+            
+            print(f"\nJudge's reasoning: {judgment.reasoning}")
+        
+        n = len(test_cases)
+        return EvaluationMetrics(
+            accuracy=total_accuracy / n,
+            similarity=total_similarity / n,
+            consistency=total_consistency / n
+        )
+    
+    def evaluate(self, 
+                source_prompt: str, 
+                target_prompt: str, 
+                test_cases: List[Dict]) -> EvaluationMetrics:
+        """Evaluates the quality of prompt migration using LLM as judge."""
+        
+        metrics = self._calculate_metrics(test_cases)
+        
+        return metrics
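
The metrics loop treats each test case as a plain dict: _calculate_metrics reads case["text"] as the model input and case["expected_summary"] as the reference for the exact-match accuracy check, so every dataset entry needs both keys. A minimal sketch of calling the evaluator directly, with an illustrative test case and the models/prompts set up as in main.py (not part of this commit):

    # Each entry must provide "text" (model input) and "expected_summary" (reference output)
    test_cases = [
        {
            "text": "Summarize: The quick brown fox jumps over the lazy dog.",
            "expected_summary": "A fox jumps over a dog.",
        },
    ]

    evaluator = PromptEvaluator(openai_lm, target_lm)   # models as configured in main.py
    metrics = evaluator.evaluate(
        source_prompt.template,                          # prompts as built in main.py
        migrated_prompt.template,
        test_cases,
    )
    print(metrics)  # EvaluationMetrics(accuracy=..., similarity=..., consistency=...)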

+ 17 - 0
recipes/use_cases/prompt-migration/prompt_migration/formatters.py

@@ -0,0 +1,17 @@
+import re
+from typing import List
+
+class PromptFormatter:
+    @staticmethod
+    def openai_to_llama(prompt: str) -> str:
+        """Convert OpenAI-style prompts to Llama format."""
+        # Basic conversion logic
+        converted = prompt.replace("{{", "{").replace("}}", "}")
+        return converted
+    
+    @staticmethod
+    def extract_variables(prompt: str) -> List[str]:
+        """Extract variable names from a prompt template."""
+        pattern = r"\{([^}]+)\}"
+        matches = re.findall(pattern, prompt)
+        return list(set(matches))
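
PromptFormatter is not imported by main.py in this commit; as written, openai_to_llama only collapses doubled braces and extract_variables returns the unique {name} placeholders. A quick illustration of that behavior (example strings are made up):

    from prompt_migration.formatters import PromptFormatter

    prompt = "Summarize the following text: {{document}} in {{style}} style."
    print(PromptFormatter.openai_to_llama(prompt))
    # -> Summarize the following text: {document} in {style} style.

    print(sorted(PromptFormatter.extract_variables("Summarize {document} in {style} style.")))
    # -> ['document', 'style']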