engine.py

import dspy
from difflib import SequenceMatcher
from typing import List, Dict, Optional
from dataclasses import dataclass


@dataclass
class PromptTemplate:
    template: str
    input_variables: List[str]
    model_type: str  # 'openai' or 'llama'


class PromptMigrationEngine:
    def __init__(self, source_lm: dspy.LM, target_lm: dspy.LM):
        self.source_lm = source_lm
        self.target_lm = target_lm
        # Route DSPy calls through the source model by default.
        dspy.configure(lm=source_lm)

    def _optimize_transformation(self, transformer, eval_dataset):
        """Optimize the transformation using the evaluation dataset."""

        class PromptQualityMetric:
            """Score a migrated prompt by comparing the target model's output
            on it against the source model's output on the original prompt."""

            def __init__(self, source_lm, target_lm):
                self.source_lm = source_lm
                self.target_lm = target_lm

            def __call__(self, example, prediction, trace=None):
                if not hasattr(prediction, 'target'):
                    return 0.0
                try:
                    # Run the original prompt on the source model and the
                    # migrated prompt on the target model.
                    source_output = self.source_lm(example.source)
                    target_output = self.target_lm(prediction.target)
                    # Compare outputs (basic string similarity)
                    similarity = SequenceMatcher(None,
                                                 str(source_output),
                                                 str(target_output)).ratio()
                    return similarity
                except Exception as e:
                    print(f"Error in metric: {e}")
                    return 0.0

        optimizer = dspy.BootstrapFewShotWithRandomSearch(
            metric=PromptQualityMetric(self.source_lm, self.target_lm),
            max_bootstrapped_demos=2,
            max_labeled_demos=2,
            num_threads=1
        )

        # Prepare training data
        train_data = []
        for item in eval_dataset:
            # Create an example with both the prompt and its expected output
            example = dspy.Example(
                source=item["text"],
                expected_output=item["expected_answer"]
            ).with_inputs("source")
            train_data.append(example)

        return optimizer.compile(transformer, trainset=train_data)
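
    # Illustrative shape of the eval_dataset consumed above. The field values
    # here are hypothetical, but each item must carry "text" and
    # "expected_answer" keys as accessed in the loop above:
    #
    #     eval_dataset = [
    #         {"text": "Summarize: <document>", "expected_answer": "<summary>"},
    #     ]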

    def migrate_prompt(self,
                       source_prompt: PromptTemplate,
                       eval_dataset: Optional[List[Dict]] = None) -> PromptTemplate:
        """Migrate a prompt from the source LM format to the target LM format."""

        class PromptTransformation(dspy.Signature):
            """Convert a prompt from one format to another."""
            source = dspy.InputField(desc="Source prompt template")
            target = dspy.OutputField(desc="Transformed prompt template that maintains functionality while adapting to target model format")

        class Transformer(dspy.Module):
            def __init__(self):
                super().__init__()
                self.chain = dspy.ChainOfThought(PromptTransformation)

            def forward(self, source):
                # Add context about the transformation task
                prompt = f"""
                Transform this prompt while:
                1. Maintaining core functionality
                2. Adapting to target model format
                3. Preserving input variables
                4. Keeping essential instructions

                Source prompt:
                {source}
                """
                return self.chain(source=prompt)

        transformer = Transformer()

        if eval_dataset:
            transformer = self._optimize_transformation(transformer, eval_dataset)

        result = transformer(source=source_prompt.template)

        # Wrap in Alpaca-style instruction/response markers when moving from an
        # OpenAI-style prompt to a Llama target.
        if source_prompt.model_type == "openai" and "llama" in str(self.target_lm):
            result.target = f"### Instruction:\n{result.target}\n\n### Response:"

        return PromptTemplate(
            template=result.target,
            input_variables=source_prompt.input_variables,
            model_type='llama'
        )
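

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal end-to-end example, assuming the DSPy 2.5+ dspy.LM client API.
# The model identifiers below are hypothetical placeholders; substitute
# endpoints you actually have access to.
if __name__ == "__main__":
    source_lm = dspy.LM("openai/gpt-4o-mini")    # hypothetical source model
    target_lm = dspy.LM("ollama_chat/llama3")    # hypothetical target model

    engine = PromptMigrationEngine(source_lm, target_lm)

    prompt = PromptTemplate(
        template="Summarize the following text:\n{text}",
        input_variables=["text"],
        model_type="openai",
    )

    # Pass an eval_dataset (see the shape sketched above) to also run the
    # BootstrapFewShotWithRandomSearch optimization step.
    migrated = engine.migrate_prompt(prompt)
    print(migrated.template)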