|
@@ -3,15 +3,15 @@ import typing as t
|
|
from .datatypes import TaskDatasets
|
|
from .datatypes import TaskDatasets
|
|
|
|
|
|
if t.TYPE_CHECKING:
|
|
if t.TYPE_CHECKING:
|
|
- from datasets import Dataset
|
|
|
|
import dspy
|
|
import dspy
|
|
|
|
+ from datasets import Dataset
|
|
|
|
|
|
|
|
|
|
def train_val_test_split(
|
|
def train_val_test_split(
|
|
dataset: "Dataset",
|
|
dataset: "Dataset",
|
|
mapper: t.Callable[[dict], "dspy.Example"],
|
|
mapper: t.Callable[[dict], "dspy.Example"],
|
|
train_size: float = 0.1,
|
|
train_size: float = 0.1,
|
|
- validation_size: float = 0.2,
|
|
|
|
|
|
+ validation_size: float = 0.1,
|
|
) -> TaskDatasets:
|
|
) -> TaskDatasets:
|
|
docs = dataset.train_test_split(train_size=train_size)
|
|
docs = dataset.train_test_split(train_size=train_size)
|
|
train_docs = docs["train"]
|
|
train_docs = docs["train"]
|
|
@@ -32,20 +32,20 @@ def fixed_split(
|
|
validation_size: int = 200,
|
|
validation_size: int = 200,
|
|
) -> TaskDatasets:
|
|
) -> TaskDatasets:
|
|
"""Split dataset by taking first N examples instead of random sampling.
|
|
"""Split dataset by taking first N examples instead of random sampling.
|
|
-
|
|
|
|
|
|
+
|
|
Args:
|
|
Args:
|
|
dataset: Input dataset
|
|
dataset: Input dataset
|
|
mapper: Function to map dataset examples to dspy.Example
|
|
mapper: Function to map dataset examples to dspy.Example
|
|
train_size: Number of examples to use for training (default: 1000)
|
|
train_size: Number of examples to use for training (default: 1000)
|
|
validation_size: Number of examples to use for validation (default: 200)
|
|
validation_size: Number of examples to use for validation (default: 200)
|
|
-
|
|
|
|
|
|
+
|
|
Returns:
|
|
Returns:
|
|
TaskDatasets containing train, validation and test splits
|
|
TaskDatasets containing train, validation and test splits
|
|
"""
|
|
"""
|
|
train_docs = dataset.select(range(train_size))
|
|
train_docs = dataset.select(range(train_size))
|
|
validation_docs = dataset.select(range(train_size, train_size + validation_size))
|
|
validation_docs = dataset.select(range(train_size, train_size + validation_size))
|
|
test_docs = dataset.select(range(train_size + validation_size, len(dataset)))
|
|
test_docs = dataset.select(range(train_size + validation_size, len(dataset)))
|
|
-
|
|
|
|
|
|
+
|
|
return TaskDatasets(
|
|
return TaskDatasets(
|
|
trainset=list(map(mapper, train_docs)),
|
|
trainset=list(map(mapper, train_docs)),
|
|
valset=list(map(mapper, validation_docs)),
|
|
valset=list(map(mapper, validation_docs)),
|