@@ -38,8 +38,10 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
+    freeze_LLM_only,
     get_policies,
     print_model_size,
+    print_frozen_model_status,
     setup,
     setup_environ_flags,
     train,
@@ -194,7 +196,7 @@ def main(**kwargs):
         model.resize_token_embeddings(len(tokenizer))

     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
+
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -235,7 +237,14 @@ def main(**kwargs):

         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
-
+            # print model size and frozen layers after freezing layers
+            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
+
+        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
+            freeze_LLM_only(model)
+            # print model size and frozen layers after freezing layers
+            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
+
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
         if is_vision:
@@ -255,6 +264,11 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
+
+        if train_config.freeze_LLM_only:
+            use_orig_params = True
+        else:
+            use_orig_params = False
         model = FSDP(
             model,
             auto_wrap_policy=(
@@ -282,6 +296,7 @@ def main(**kwargs):
                 if train_config.low_cpu_fsdp and rank != 0
                 else None
             ),
+            use_orig_params=use_orig_params,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
@@ -297,7 +312,7 @@ def main(**kwargs):
         dataset_processer = processor
     else:
         dataset_processer = tokenizer
-
+
     # Load and preprocess the dataset for training and validation

     dataset_train = get_preprocessed_dataset(
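
Note on the FSDP change above: with the default use_orig_params=False, FSDP flattens parameters and expects every parameter inside a wrapped unit to share the same requires_grad value, so mixing a frozen language model with trainable vision modules only works with use_orig_params=True; that is why the diff toggles it on whenever freeze_LLM_only is set. Below is a minimal sketch of that idea, not the llama_recipes implementation; the "language_model." prefix is an assumed submodule name used only for illustration.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def freeze_llm_only_sketch(model: torch.nn.Module) -> None:
    # Hypothetical helper: freeze every parameter under the (assumed)
    # `language_model` submodule while leaving vision modules trainable.
    for name, param in model.named_parameters():
        if name.startswith("language_model."):
            param.requires_grad = False


# Usage sketch (assumes torch.distributed is already initialized and the
# model is a vision-language model with a `language_model` submodule):
# freeze_llm_only_sketch(model)
# model = FSDP(model, use_orig_params=True)  # keeps per-parameter requires_grad flags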