|
@@ -184,13 +184,14 @@ def main(**kwargs):
|
|
if not train_config.enable_fsdp or rank == 0:
|
|
if not train_config.enable_fsdp or rank == 0:
|
|
print(f"--> Training Set Length = {len(dataset_train)}")
|
|
print(f"--> Training Set Length = {len(dataset_train)}")
|
|
|
|
|
|
- dataset_val = get_preprocessed_dataset(
|
|
|
|
- tokenizer,
|
|
|
|
- dataset_config,
|
|
|
|
- split="test",
|
|
|
|
- )
|
|
|
|
- if not train_config.enable_fsdp or rank == 0:
|
|
|
|
- print(f"--> Validation Set Length = {len(dataset_val)}")
|
|
|
|
|
|
+ if train_config.run_validation:
|
|
|
|
+ dataset_val = get_preprocessed_dataset(
|
|
|
|
+ tokenizer,
|
|
|
|
+ dataset_config,
|
|
|
|
+ split="test",
|
|
|
|
+ )
|
|
|
|
+ if not train_config.enable_fsdp or rank == 0:
|
|
|
|
+ print(f"--> Validation Set Length = {len(dataset_val)}")
|
|
|
|
|
|
if train_config.batching_strategy == "packing":
|
|
if train_config.batching_strategy == "packing":
|
|
dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
|
|
dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
|