|
@@ -5,7 +5,7 @@
|
|
|
|
|
|
import copy
|
|
|
import datasets
|
|
|
-from datasets import Dataset, load_dataset
|
|
|
+from datasets import Dataset, load_dataset, DatasetDict
|
|
|
import itertools
|
|
|
|
|
|
|
|
@@ -27,13 +27,15 @@ def tokenize_dialog(q_a_pair, tokenizer):
|
|
|
return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
|
|
|
|
|
|
|
|
|
-def get_custom_dataset(dataset_config, tokenizer, split):
|
|
|
+def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.8):
|
|
|
dataset = load_dataset('json', data_files=dataset_config.data_path)
|
|
|
- dataset = dataset.map(lambda sample: {
|
|
|
+ dataset = dataset['train'].train_test_split(test_size=1-split_ratio, shuffle=True)
|
|
|
+
|
|
|
+ dataset = dataset[split].map(lambda sample: {
|
|
|
"question": sample["question"],
|
|
|
"answer": sample["answer"],
|
|
|
},
|
|
|
batched=True,
|
|
|
)
|
|
|
dataset = dataset.map(lambda x: tokenize_dialog(x, tokenizer))
|
|
|
- return dataset["train"]
|
|
|
+ return dataset
|