@@ -30,7 +30,8 @@ from transformers import (
     MllamaForConditionalGeneration
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
+from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
+
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG
@@ -129,7 +130,6 @@ def main(**kwargs):
         model = MllamaForConditionalGeneration.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
-            #use_cache=use_cache,
             attn_implementation="sdpa" if train_config.use_fast_kernels else None,
             device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
             torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
@@ -146,7 +146,7 @@ def main(**kwargs):
             device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
             torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
         )
-
+    print(model)
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -189,11 +189,7 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)

         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
-        # if is_vision:
-        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
-        # else:
-        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
+        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer])
         print("FSDP is enabled",my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
@@ -222,7 +218,8 @@ def main(**kwargs):
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
-
+    print("-------------------")
+    print("FSDP model", model)
     dataset_config = generate_dataset_config(train_config, kwargs)
     if is_vision:
         dataset_processer = processor
@@ -248,7 +245,10 @@ def main(**kwargs):
             print(f"--> Validation Set Length = {len(dataset_val)}")

     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+        if is_vision:
+            raise ValueError("Packing is not supported for vision datasets")
+        else:
+            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
     print("length of dataset_train", len(dataset_train))
@@ -268,7 +268,10 @@ def main(**kwargs):
     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            if is_vision:
+                raise ValueError("Packing is not supported for vision datasets")
+            else:
+                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

         val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
         if custom_data_collator:
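
Note: a minimal sketch (not part of the diff) of how the wrapping-policy choice above could be branched per model type, mirroring the commented-out if/else this diff removes. It assumes `fsdp_auto_wrap_policy(model, layer_classes)` is importable from `llama_recipes.utils.fsdp_utils` and accepts a list of layer classes, as the call in the hunk does; the `build_wrapping_policy` helper name is illustrative only.

```python
# Sketch under the assumptions stated above.
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaSelfAttentionDecoderLayer,
    MllamaCrossAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)
from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy


def build_wrapping_policy(model, is_vision: bool):
    # Text-only Llama checkpoints only need the decoder layer wrapped;
    # Mllama vision checkpoints also wrap the self/cross-attention decoder
    # layers and the vision encoder layer, matching the import added above.
    layer_classes = [LlamaDecoderLayer]
    if is_vision:
        layer_classes += [
            MllamaSelfAttentionDecoderLayer,
            MllamaCrossAttentionDecoderLayer,
            MllamaVisionEncoderLayer,
        ]
    return fsdp_auto_wrap_policy(model, layer_classes)
```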