main.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import dspy
  2. from prompt_migration.engine import PromptMigrationEngine, PromptTemplate
  3. from prompt_migration.evaluator import PromptEvaluator
  4. from prompt_migration.eval_dataset import get_evaluation_dataset, get_eval_subset
  5. import os
  6. import dotenv
  7. dotenv.load_dotenv()
  8. def main():
  9. openai_lm = dspy.LM(
  10. model="gpt-3.5-turbo",
  11. api_key=os.getenv("OPENAI_API_KEY")
  12. )
  13. target_lm = dspy.LM(
  14. model="together_ai/meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
  15. api_key=os.getenv("TOGETHER_API_KEY")
  16. )
  17. # To run it with ollama
  18. # target_lm = dspy.LM('ollama_chat/llama3.2:3b-instruct-fp16', api_base='http://localhost:11434', api_key='')
  19. # To run it with huggingface
  20. # target_lm = dspy.HFModel(model="gpt2")
  21. engine = PromptMigrationEngine(openai_lm, target_lm)
  22. source_prompt = PromptTemplate(
  23. template="""You are an advanced Large Language Model tasked with generating Python code snippets in response to user prompts. Your primary objective is to provide accurate, concise, and well-structured Python functions. Follow these guidelines:
  24. Understand the Context: Analyze the input prompt and identify its category (e.g., API Usage, File Handling, Error Handling).
  25. Generate Code:
  26. Write Python code that directly addresses the user's request.
  27. Ensure the code is syntactically correct, functional, and adheres to Python best practices.
  28. Include necessary imports and handle potential edge cases.
  29. Error Handling:
  30. Include appropriate error handling where applicable (e.g., try-except blocks).
  31. If exceptions occur, provide meaningful error messages.
  32. Readability:
  33. Use clear variable names and include comments where necessary for clarity.
  34. Prioritize readability and maintainability in all generated code.
  35. Complexity Alignment:
  36. Tailor the code's complexity based on the indicated difficulty (e.g., simple, medium, complex).
  37. Ensure that the solution is neither overly simplistic nor unnecessarily complicated.
  38. Prompt Type:
  39. Focus on the code_generation type for creating Python functions.
  40. Avoid deviating from the task unless additional clarification is requested.
  41. Testing and Validity:
  42. Assume the function might be run immediately. Provide code that is ready for use or minimal adaptation.
  43. Highlight any dependencies or external libraries required.
  44. """,
  45. input_variables=["text"],
  46. model_type="openai"
  47. )
  48. eval_dataset = get_evaluation_dataset()
  49. # To evaluate on a specific subset, use the following:
  50. code_generation_dataset = get_eval_subset(prompt_type="code_generation")
  51. #simple_tasks = get_eval_subset(complexity="simple")
  52. evaluator = PromptEvaluator(openai_lm, target_lm)
  53. metrics = evaluator.evaluate(
  54. source_prompt.template, # Same prompt for both
  55. source_prompt.template, # Same prompt for both
  56. code_generation_dataset
  57. )
  58. print(f"Evaluation metrics:")
  59. print(f" Accuracy: {metrics.accuracy:.2f}")
  60. print(f" Similarity: {metrics.similarity:.2f}")
  61. print(f" Consistency: {metrics.consistency:.2f}")
  62. # Migrate prompt
  63. print("Migrating prompt...")
  64. migrated_prompt = engine.migrate_prompt(source_prompt, code_generation_dataset)
  65. # Evaluate migration
  66. print("Evaluating migration...")
  67. metrics = evaluator.evaluate(
  68. source_prompt.template,
  69. migrated_prompt.template,
  70. code_generation_dataset
  71. )
  72. print(f"\nResults:")
  73. print(f"Original prompt: {source_prompt.template}")
  74. print(f"Migrated prompt: {migrated_prompt.template}")
  75. print(f"Evaluation metrics:")
  76. print(f" Accuracy: {metrics.accuracy:.2f}")
  77. print(f" Similarity: {metrics.similarity:.2f}")
  78. print(f" Consistency: {metrics.consistency:.2f}")
  79. if __name__ == "__main__":
  80. main()