소스 검색

adding splits to dataset

Hamid Shojanazeri 1 년 전
부모
커밋
07257a3f86
1개의 변경된 파일6개의 추가작업 그리고 4개의 파일을 삭제
  1. 6 4
      examples/llama_dataset.py

+ 6 - 4
examples/llama_dataset.py

@@ -5,7 +5,7 @@
 
 import copy
 import datasets
-from datasets import Dataset, load_dataset
+from datasets import Dataset, load_dataset, DatasetDict
 import itertools
 
 
@@ -27,13 +27,15 @@ def tokenize_dialog(q_a_pair, tokenizer):
     return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
 
 
-def get_custom_dataset(dataset_config, tokenizer, split):
+def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.8):
     dataset = load_dataset('json', data_files=dataset_config.data_path)
-    dataset = dataset.map(lambda sample: {
+    dataset = dataset['train'].train_test_split(test_size=1-split_ratio, shuffle=True)
+    
+    dataset = dataset[split].map(lambda sample: {
         "question": sample["question"],
         "answer": sample["answer"],
         },
         batched=True,
     )
     dataset = dataset.map(lambda x: tokenize_dialog(x, tokenizer))
-    return dataset["train"]
+    return dataset