@@ -14,16 +14,18 @@ from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
 )
-
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoConfig,
     AutoTokenizer,
     BitsAndBytesConfig,
-    LlamaForCausalLM,
-    LlamaConfig,
+    AutoProcessor,
+    MllamaForConditionalGeneration,
+    AutoModelForCausalLM,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer

 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
@@ -39,7 +41,7 @@ from llama_recipes.utils.config_utils import (
     get_dataloader_kwargs,
     check_fsdp_config,
 )
-from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
+from llama_recipes.utils.dataset_utils import get_preprocessed_dataset, get_custom_data_collator

 from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
@@ -118,19 +120,35 @@ def main(**kwargs):

     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    model = LlamaForCausalLM.from_pretrained(
+    config = AutoConfig.from_pretrained(train_config.model_name)
+    if config.model_type == "mllama":
+        is_vision = True
+        model = MllamaForConditionalGeneration.from_pretrained(
         train_config.model_name,
         quantization_config=bnb_config,
-        use_cache=use_cache,
         attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
         torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
     )
-
+        processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+        processor.tokenizer.padding_side = 'right'
+    elif config.model_type == "llama":
+        is_vision = False
+        model = AutoModelForCausalLM.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+    else:
+        raise ValueError(f"Model type {config.model_type} is not supported. Please use a llama or mllama model.")
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-
+    if not tokenizer.pad_token_id:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
@@ -169,8 +187,12 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)

         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
+        if is_vision:
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer])
+        else:
+            # Create the FSDP wrapper for LlamaDecoderLayer in text models
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
@@ -198,12 +220,16 @@ def main(**kwargs):
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
-
     dataset_config = generate_dataset_config(train_config, kwargs)
+    if is_vision:
+        dataset_processor = processor
+    else:
+        dataset_processor = tokenizer
+
+    # Load and preprocess the dataset for training and validation

-    # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
-        tokenizer,
+        dataset_processor,
         dataset_config,
         split="train",
     )
@@ -211,7 +237,7 @@ def main(**kwargs):
         print(f"--> Training Set Length = {len(dataset_train)}")

     dataset_val = get_preprocessed_dataset(
+        dataset_processor,
         dataset_config,
         split="test",
     )
@@ -219,10 +245,17 @@ def main(**kwargs):
         print(f"--> Validation Set Length = {len(dataset_val)}")

     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
-
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-
+        if is_vision:
+            raise ValueError("Packing is not supported for vision datasets")
+        else:
+            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processor, "train")
+    print(f"--> Num of Training Set examples = {len(dataset_train)}")
+    custom_data_collator = get_custom_data_collator(dataset_processor, dataset_config)
+    if custom_data_collator:
+        print("custom_data_collator is used")
+        train_dl_kwargs["collate_fn"] = custom_data_collator
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -230,13 +263,19 @@ def main(**kwargs):
         pin_memory=True,
         **train_dl_kwargs,
     )
+    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            if is_vision:
+                raise ValueError("Packing is not supported for vision datasets")
+            else:
+                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processor, "val")
+        if custom_data_collator:
+            val_dl_kwargs["collate_fn"] = custom_data_collator

         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
@@ -244,6 +283,7 @@ def main(**kwargs):
             pin_memory=True,
             **val_dl_kwargs,
         )
+        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
             raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
         else:
@@ -266,7 +306,6 @@ def main(**kwargs):
         weight_decay=train_config.weight_decay,
     )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-    # Start the training process
     results = train(
         model,
         train_dataloader,