helpers.py

import typing as t

from .datatypes import TaskDatasets

if t.TYPE_CHECKING:
    import dspy
    from datasets import Dataset


def train_val_test_split(
    dataset: "Dataset",
    mapper: t.Callable[[dict], "dspy.Example"],
    train_size: float = 0.1,
    validation_size: float = 0.1,
) -> TaskDatasets:
    """Randomly split a dataset into train, validation, and test sets.

    Args:
        dataset: Input dataset
        mapper: Function to map dataset examples to dspy.Example
        train_size: Fraction of the dataset to use for training (default: 0.1)
        validation_size: Fraction of the remaining data to use for
            validation (default: 0.1)

    Returns:
        TaskDatasets containing train, validation and test splits
    """
    docs = dataset.train_test_split(train_size=train_size)
    train_docs = docs["train"]
    docs = docs["test"].train_test_split(train_size=validation_size)
    validation_docs = docs["train"]
    test_docs = docs["test"]
    return TaskDatasets(
        trainset=list(map(mapper, train_docs)),
        valset=list(map(mapper, validation_docs)),
        testset=list(map(mapper, test_docs)),
    )
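
# Hypothetical usage sketch for train_val_test_split; the "imdb" dataset id
# and its "text"/"label" columns are illustrative assumptions, not part of
# this module. With the defaults, 10% of the data becomes the train split,
# 10% of the remaining 90% becomes the validation split (9% overall), and
# the rest (81%) is the test split:
#
#   import dspy
#   from datasets import load_dataset
#
#   dataset = load_dataset("imdb", split="train")
#   mapper = lambda row: dspy.Example(
#       text=row["text"], label=row["label"]
#   ).with_inputs("text")
#   splits = train_val_test_split(dataset, mapper)  # ~10% / 9% / 81%
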

def fixed_split(
    dataset: "Dataset",
    mapper: t.Callable[[dict], "dspy.Example"],
    train_size: int = 1000,
    validation_size: int = 200,
) -> TaskDatasets:
    """Split dataset by taking the first N examples instead of random sampling.

    Args:
        dataset: Input dataset
        mapper: Function to map dataset examples to dspy.Example
        train_size: Number of examples to use for training (default: 1000)
        validation_size: Number of examples to use for validation (default: 200)

    Returns:
        TaskDatasets containing train, validation and test splits
    """
    train_docs = dataset.select(range(train_size))
    validation_docs = dataset.select(range(train_size, train_size + validation_size))
    test_docs = dataset.select(range(train_size + validation_size, len(dataset)))
    return TaskDatasets(
        trainset=list(map(mapper, train_docs)),
        valset=list(map(mapper, validation_docs)),
        testset=list(map(mapper, test_docs)),
    )
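

# A minimal, hedged demo of fixed_split. The "imdb" dataset id and its
# "text"/"label" columns are assumptions for illustration, not part of this
# module. Because this module uses a relative import, run it as a module
# (python -m <package>.helpers) rather than as a script.
if __name__ == "__main__":
    import dspy
    from datasets import load_dataset

    def mapper(row: dict) -> "dspy.Example":
        # Field names here are illustrative; adapt to your dataset's schema.
        return dspy.Example(text=row["text"], label=row["label"]).with_inputs("text")

    dataset = load_dataset("imdb", split="train")
    splits = fixed_split(dataset, mapper, train_size=1000, validation_size=200)
    print(len(splits.trainset), len(splits.valset), len(splits.testset))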