helpers.py 708 B

1234567891011121314151617181920212223242526
  1. import typing as t
  2. from .datatypes import TaskDatasets
  3. if t.TYPE_CHECKING:
  4. from datasets import Dataset
  5. import dspy
  6. def train_val_test_split(
  7. dataset: "Dataset",
  8. mapper: t.Callable[[dict], "dspy.Example"],
  9. train_size: float = 0.1,
  10. validation_size: float = 0.2,
  11. ) -> TaskDatasets:
  12. docs = dataset.train_test_split(train_size=train_size)
  13. train_docs = docs["train"]
  14. docs = docs["test"].train_test_split(train_size=validation_size)
  15. validation_docs = docs["train"]
  16. test_docs = docs["test"]
  17. return TaskDatasets(
  18. trainset=list(map(mapper, train_docs)),
  19. valset=list(map(mapper, validation_docs)),
  20. testset=list(map(mapper, test_docs)),
  21. )