@@ -14,7 +14,11 @@ from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
 )
-
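+# FSDP auto-wrap policy helpers; ModuleWrapPolicy is used below when wrapping the model.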
+from torch.distributed.fsdp.wrap import (
+    always_wrap_policy,
+    ModuleWrapPolicy,
+    transformer_auto_wrap_policy,
+)
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
@@ -29,7 +33,7 @@ from transformers import (
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-
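+# CLIPEncoderLayer is the transformer block of the CLIP vision tower inside LLaVA-NeXT; it is used as an FSDP wrapping target below.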
+from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG
@@ -121,11 +125,11 @@ def main(**kwargs):
     bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
     # Load the pre-trained model and setup its configuration
-    #use_cache = False if train_config.enable_fsdp else None
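+    # The KV cache only helps at generation time; disable it while training under FSDP.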
+    use_cache = False if train_config.enable_fsdp else None
     model = LlavaNextForConditionalGeneration.from_pretrained(
         train_config.model_name,
         quantization_config=bnb_config,
         # use_cache=use_cache,
         attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
         torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
@@ -172,16 +176,25 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-
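+        # fsdp_auto_wrap_policy now targets the CLIP vision-tower blocks instead of LlamaDecoderLayer.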
+        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [CLIPEncoderLayer])
+        print("FSDP is enabled", my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
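+        # Debugging helpers, kept commented out: inspect the module tree and the vision-tower LayerNorm parameters.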
+        # print(dir(model))
+        # for layer in model.named_children():
+        #     print(f"Layer: {layer}")
+
+        # layernorm = model.CLIPVisionTransformer.CLIPEncoder.LayerNorm
+        # for name, param in layernorm.named_parameters():
+        #     print(f"Parameter: {name}, Shape: {param.shape}, Dtype: {param.dtype}")
+        # exit()
         model = FSDP(
             model,
-            auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
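+            # ModuleWrapPolicy wraps every CLIPEncoderLayer and every LlamaDecoderLayer in its own FSDP unit, so both the vision tower and the language model are sharded.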
+            auto_wrap_policy=ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer]),
+            # auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
             cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
@@ -192,6 +205,7 @@ def main(**kwargs):
             param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
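+        # Uncomment to inspect how FSDP grouped the submodules after wrapping.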
+        # print(model)
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
             model.gradient_checkpointing_enable()
@@ -205,6 +219,11 @@ def main(**kwargs):
     dataset_config = generate_dataset_config(train_config, kwargs)
 
     # Load and preprocess the dataset for training and validation
+    # dataset_train = get_preprocessed_dataset(
+    #     processor,
+    #     dataset_config,
+    #     split="train",
+    # )
     dataset_train = get_preprocessed_dataset(
         processor,
         dataset_config,
@@ -272,6 +291,7 @@ def main(**kwargs):
     )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
     # Start the training process
+
     results = train(
         model,
         train_dataloader,