main.py

import os

import dotenv
import dspy

from prompt_migration.engine import PromptMigrationEngine, PromptTemplate
from prompt_migration.eval_dataset import get_evaluation_dataset, get_eval_subset
from prompt_migration.evaluator import PromptEvaluator
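
# Load API keys (OPENAI_API_KEY, etc.) from a local .env file.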
dotenv.load_dotenv()


def main():
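    # Source model: the model the original prompt was written for.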
    openai_lm = dspy.LM(
        model="gpt-3.5-turbo",
        api_key=os.getenv("OPENAI_API_KEY"),
    )
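
    # Target model to migrate to: default is GPT-2 via Hugging Face; the
    # commented-out options target Together AI or a local Ollama server.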
    # target_lm = dspy.LM(
    #     model="together_ai/togethercomputer/meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
    #     api_key=os.getenv("TOGETHER_API_KEY"),
    # )
    # target_lm = dspy.LM('ollama_chat/llama3.2:3b-instruct-fp16', api_base='http://localhost:11434', api_key='')
    target_lm = dspy.HFModel(model="gpt2")

    engine = PromptMigrationEngine(openai_lm, target_lm)
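
    # Prompt to migrate, originally written for the OpenAI source model.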
    source_prompt = PromptTemplate(
        template=(
            "Write a Python function that takes as input a file path to an image, "
            "loads the image into memory as a numpy array, then crops the rows and "
            "columns around the perimeter if they are darker than a threshold value. "
            "Use the mean value of rows and columns to decide if they should be "
            "marked for deletion."
        ),
        input_variables=["text"],
        model_type="openai",
    )

    eval_dataset = get_evaluation_dataset()
    # To evaluate on a specific subset, use one of the following:
    # summarization_dataset = get_eval_subset(prompt_type="summarization")
    # simple_tasks = get_eval_subset(complexity="simple")

    # Migrate prompt
    print("Migrating prompt...")
    migrated_prompt = engine.migrate_prompt(source_prompt, eval_dataset)

    # Evaluate migration
    print("Evaluating migration...")
    evaluator = PromptEvaluator(openai_lm, target_lm)
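    # Score the migrated prompt against the original on the same eval set.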
    metrics = evaluator.evaluate(
        source_prompt.template,
        migrated_prompt.template,
        eval_dataset,
    )

    print("\nResults:")
    print(f"Original prompt: {source_prompt.template}")
    print(f"Migrated prompt: {migrated_prompt.template}")
    print("Evaluation metrics:")
    print(f"  Accuracy: {metrics.accuracy:.2f}")
    print(f"  Similarity: {metrics.similarity:.2f}")
    print(f"  Consistency: {metrics.consistency:.2f}")


if __name__ == "__main__":
    main()