engine.py

import dspy
from typing import List, Dict, Optional
from dataclasses import dataclass


@dataclass
class PromptTemplate:
    template: str
    input_variables: List[str]
    model_type: str  # 'openai' or 'llama'


class PromptMigrationEngine:
    def __init__(self, source_lm: dspy.OpenAI, target_lm: dspy.LM):
        self.source_lm = source_lm
        self.target_lm = target_lm
        # The source model drives the transformation; target_lm is kept for downstream use.
        dspy.configure(lm=source_lm)

    def _optimize_transformation(self, transformer, eval_dataset):
        """Optimize the transformation using the evaluation dataset."""

        class AccuracyMetric:
            # Exact-match metric: 1.0 when the transformed prompt equals the expected output.
            def __call__(self, example, prediction, trace=None):
                return float(prediction.target == example.expected_output)

        optimizer = dspy.BootstrapFewShotWithRandomSearch(
            metric=AccuracyMetric(),
            max_bootstrapped_demos=4,
            max_labeled_demos=4,
            num_threads=4
        )

        train_data = [
            dspy.Example(
                source=item["text"],
                expected_output=item["expected_summary"]
            ).with_inputs("source")
            for item in eval_dataset
        ]

        return optimizer.compile(transformer, trainset=train_data)

    def migrate_prompt(self,
                       source_prompt: PromptTemplate,
                       eval_dataset: Optional[List[Dict]] = None) -> PromptTemplate:
        """Migrate a prompt from the source LM format to the target LM format."""

        class PromptTransformation(dspy.Signature):
            """Convert a prompt from one format to another."""
            source = dspy.InputField(desc="Source prompt template")
            target = dspy.OutputField(desc="Transformed prompt template")

        class Transformer(dspy.Module):
            def __init__(self):
                super().__init__()
                self.chain = dspy.ChainOfThought(PromptTransformation)

            def forward(self, source):
                return self.chain(source=source)

        transformer = Transformer()

        # Bootstrap few-shot demonstrations when an evaluation dataset is available.
        if eval_dataset:
            transformer = self._optimize_transformation(transformer, eval_dataset)

        result = transformer(source=source_prompt.template)

        return PromptTemplate(
            template=result.target,
            input_variables=source_prompt.input_variables,
            model_type='llama'
        )
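

# --- Usage sketch (illustrative only) ---
# A minimal example of wiring up the engine. The model names and the
# dspy.OllamaLocal client below are assumptions about the local setup and the
# installed DSPy version, not part of the engine itself; swap in whichever
# source and target clients you actually use.
if __name__ == "__main__":
    source_lm = dspy.OpenAI(model="gpt-3.5-turbo")   # assumed source client
    target_lm = dspy.OllamaLocal(model="llama2")     # assumed target client

    engine = PromptMigrationEngine(source_lm, target_lm)

    prompt = PromptTemplate(
        template="Summarize the following text: {text}",
        input_variables=["text"],
        model_type="openai",
    )

    # Optional evaluation set: each item needs "text" and "expected_summary"
    # keys, matching what _optimize_transformation reads.
    eval_dataset = [
        {
            "text": "DSPy separates program logic from prompt wording.",
            "expected_summary": "DSPy separates program logic from prompt wording.",
        },
    ]

    migrated = engine.migrate_prompt(prompt, eval_dataset)
    print(migrated.template)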