
adding data_path to custom dataset cfg and run_validation check to finetuning

Hamid Shojanazeri, 1 year ago
commit 7a5ca61136

2 changed files with 10 additions and 8 deletions:
  1. src/llama_recipes/configs/datasets.py (+2 −1)
  2. src/llama_recipes/finetuning.py (+8 −7)

src/llama_recipes/configs/datasets.py (+2 −1)

@@ -31,4 +31,5 @@ class custom_dataset:
     dataset: str = "custom_dataset"
     file: str = "examples/custom_dataset.py"
     train_split: str = "train"
-    test_split: str = "validation"
+    test_split: str = "validation"
+    data_path: str = "custom_dataset.json"
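With data_path in the config, the loader defined in `file` can read the dataset location from its config object instead of hard-coding it. A minimal sketch of such a loader follows; the get_custom_dataset(dataset_config, tokenizer, split) entry point is the one llama_recipes expects from a custom dataset module, but the JSON layout and the JsonTextDataset helper are made up for illustration:

# examples/custom_dataset.py -- sketch of a loader that honors data_path
import json
from torch.utils.data import Dataset

class JsonTextDataset(Dataset):  # hypothetical helper, not part of the repo
    def __init__(self, samples, tokenizer):
        self.samples = samples
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # assumed record layout: {"text": "..."}
        return self.tokenizer(self.samples[idx]["text"])

def get_custom_dataset(dataset_config, tokenizer, split):
    # data_path now comes from the custom_dataset config (default: custom_dataset.json)
    with open(dataset_config.data_path) as f:
        data = json.load(f)  # assumed layout: {"train": [...], "validation": [...]}
    key = dataset_config.train_split if split == "train" else dataset_config.test_split
    return JsonTextDataset(data[key], tokenizer)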

src/llama_recipes/finetuning.py (+8 −7)

@@ -184,13 +184,14 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
-    dataset_val = get_preprocessed_dataset(
-        tokenizer,
-        dataset_config,
-        split="test",
-    )
-    if not train_config.enable_fsdp or rank == 0:
-            print(f"--> Validation Set Length = {len(dataset_val)}")
+    if train_config.run_validation:
+        dataset_val = get_preprocessed_dataset(
+            tokenizer,
+            dataset_config,
+            split="test",
+        )
+        if not train_config.enable_fsdp or rank == 0:
+            print(f"--> Validation Set Length = {len(dataset_val)}")
 
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
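
With this guard, a run that disables run_validation never loads the test split, so a data_path containing only training data no longer fails at startup. A hypothetical invocation exercising both changes (flag spellings follow the repo's dotted config overrides; the paths are placeholders):

python -m llama_recipes.finetuning \
    --dataset custom_dataset \
    --custom_dataset.file examples/custom_dataset.py \
    --custom_dataset.data_path data/my_dataset.json \
    --run_validation False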