@@ -166,8 +166,7 @@ def main(**kwargs):
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
-
-            freeze_transformer_layers(train_config.num_freeze_layers)
+            freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
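
The first hunk fixes the call site: `freeze_transformer_layers` takes the model as its first argument, so the old call presumably raised a `TypeError` whenever `freeze_layers` was enabled without PEFT. A minimal sketch of what the helper does, assuming the signature `freeze_transformer_layers(model, num_layer)` implied by the new call and a Hugging Face `LlamaForCausalLM`-style module layout (`model.model.layers`):

```python
def freeze_transformer_layers(model, num_layer):
    # Freeze the first `num_layer` decoder blocks: their parameters are
    # excluded from gradient updates, so only the remaining layers are
    # trained during fine-tuning.
    # Assumes a LlamaForCausalLM-style layout where the decoder blocks
    # live under model.model.layers.
    for i, layer in enumerate(model.model.layers):
        if i < num_layer:
            for param in layer.parameters():
                param.requires_grad = False
```
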
@@ -217,7 +216,7 @@ def main(**kwargs):
         split="test",
     )
     if not train_config.enable_fsdp or rank == 0:
-            print(f"--> Validation Set Length = {len(dataset_val)}")
+        print(f"--> Validation Set Length = {len(dataset_val)}")
 
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
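
The second hunk is a whitespace-only change (the removed and added lines are textually identical, and the hunk keeps the same 7-line count), reconstructed here as correcting the over-indented validation print to sit directly under the rank-0 guard.

For reference, the `packing` branch shown in that hunk's context concatenates tokenized samples into fixed-length chunks instead of padding each sample. A rough sketch of that idea, assuming samples are dicts of token-id lists; `PackedDataset` is a hypothetical name for illustration, not the repo's exact `ConcatDataset`:

```python
from torch.utils.data import Dataset

class PackedDataset(Dataset):
    # Illustrative packing dataset: concatenates the token streams of all
    # samples and re-slices them into fixed-size chunks, so every batch
    # element holds exactly `chunk_size` tokens.
    def __init__(self, dataset, chunk_size=4096):
        self.samples = []
        buffer = {"input_ids": [], "attention_mask": [], "labels": []}
        for sample in dataset:
            # Append this sample's tokens to the running buffer.
            buffer = {k: v + list(sample[k]) for k, v in buffer.items()}
            # Emit full chunks while the buffer holds enough tokens.
            while len(buffer["input_ids"]) >= chunk_size:
                self.samples.append({k: v[:chunk_size] for k, v in buffer.items()})
                buffer = {k: v[chunk_size:] for k, v in buffer.items()}

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)
```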