@@ -1,61 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

-from collections import Counter
+import dataclasses
import os
+import random
+from collections import Counter
+from warnings import warn

-import dataclasses
import fire
-import random
+import numpy as np
import torch
import torch.optim as optim
-import numpy as np
-from peft import get_peft_model, PeftModel
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    ShardingStrategy
-)
-from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
-from torch.optim.lr_scheduler import StepLR
-from transformers import (
-    AutoConfig,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-    AutoProcessor,
-    LlamaForCausalLM,
-    MllamaForConditionalGeneration,
-)
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
+from accelerate.utils import is_xpu_available

-from llama_recipes.configs import fsdp_config as FSDP_CONFIG
-from llama_recipes.configs import train_config as TRAIN_CONFIG
-from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG
+from llama_recipes.configs import (
+    fsdp_config as FSDP_CONFIG,
+    quantization_config as QUANTIZATION_CONFIG,
+    train_config as TRAIN_CONFIG,
+)
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing

from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
-    update_config,
-    generate_peft_config,
+    check_fsdp_config,
    generate_dataset_config,
+    generate_peft_config,
    get_dataloader_kwargs,
-    check_fsdp_config,
+    update_config,
+)
+from llama_recipes.utils.dataset_utils import (
+    get_custom_data_collator,
+    get_preprocessed_dataset,
)
-from llama_recipes.utils.dataset_utils import get_preprocessed_dataset,get_custom_data_collator

from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
from llama_recipes.utils.train_utils import (
-    train,
+    clear_gpu_cache,
    freeze_transformer_layers,
+    get_policies,
+    print_model_size,
    setup,
    setup_environ_flags,
-    clear_gpu_cache,
-    print_model_size,
-    get_policies,
+    train,
+)
+from peft import get_peft_model, PeftModel
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
+from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
+from torch.optim.lr_scheduler import StepLR
+from transformers import (
+    AutoConfig,
+    AutoProcessor,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    LlamaForCausalLM,
+    MllamaForConditionalGeneration,
+)
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import (
+    MllamaCrossAttentionDecoderLayer,
+    MllamaSelfAttentionDecoderLayer,
+    MllamaVisionEncoderLayer,
+)
)
-from accelerate.utils import is_xpu_available
-from warnings import warn
+

def setup_wandb(train_config, fsdp_config, **kwargs):
    try:
@@ -66,6 +73,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
            "Please install it using pip install wandb"
        )
    from llama_recipes.configs import wandb_config as WANDB_CONFIG
+
    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
@@ -74,6 +82,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
        run.config.update(fsdp_config, allow_val_change=True)
    return run

+
def main(**kwargs):
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
@@ -103,18 +112,23 @@ def main(**kwargs):
    wandb_run = None

    if train_config.use_wandb:
-        if not train_config.enable_fsdp or rank==0:
+        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
-
-    #setting quantization configs
+
+    # setting quantization configs
    bnb_config = None
    if train_config.quantization:
        if type(train_config.quantization) == type(True):
-            warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning)
+            warn(
+                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
+                FutureWarning,
+            )
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
-            raise ValueError("8bit quantization is not supported with FSDP, please use 4bit quantization")
+            raise ValueError(
+                "8bit quantization is not supported with FSDP, please use 4bit quantization"
+            )

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
@@ -126,14 +140,22 @@ def main(**kwargs):
    if config.model_type == "mllama":
        is_vision = True
        model = MllamaForConditionalGeneration.from_pretrained(
-            train_config.model_name,
-            quantization_config=bnb_config,
-            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
-            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
-        )
-        processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
-        processor.tokenizer.padding_side='right'
+            train_config.model_name,
+            quantization_config=bnb_config,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map=(
+                "auto"
+                if train_config.quantization and not train_config.enable_fsdp
+                else None
+            ),
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+        processor = AutoProcessor.from_pretrained(
+            train_config.model_name
+            if train_config.tokenizer_name is None
+            else train_config.tokenizer_name
+        )
+        processor.tokenizer.padding_side = "right"
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    elif config.model_type == "llama":
@@ -143,32 +165,50 @@ def main(**kwargs):
            quantization_config=bnb_config,
            use_cache=use_cache,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+            device_map=(
+                "auto"
+                if train_config.quantization and not train_config.enable_fsdp
+                else None
+            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
        )
    else:
-        raise ValueError(f"Model type {config.model_type} is not supported. Please use llama or mllama model.")
+        raise ValueError(
+            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
+        )
    # Load the tokenizer and add special tokens
-    tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
-    if not tokenizer.pad_token_id:
+    tokenizer = AutoTokenizer.from_pretrained(
+        train_config.model_name
+        if train_config.tokenizer_name is None
+        else train_config.tokenizer_name
+    )
+    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id
-
+
    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        print(
+            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
+        )
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
-    if train_config.enable_fsdp and fsdp_config.pure_bf16 and not train_config.quantization:
+    if (
+        train_config.enable_fsdp
+        and fsdp_config.pure_bf16
+        and not train_config.quantization
+    ):
        model.to(torch.bfloat16)
-
+
    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and setup its configuration
        if train_config.from_peft_checkpoint:
-            model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
+            model = PeftModel.from_pretrained(
+                model, train_config.from_peft_checkpoint, is_trainable=True
+            )
            peft_config = model.peft_config
        # Generate the peft config and start fine-tuning from original model
        else:
@@ -179,23 +219,36 @@ def main(**kwargs):
        model.print_trainable_parameters()

    hsdp_device_mesh_plan = None
-    if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
-        hsdp_device_mesh_plan = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
+    if (
+        fsdp_config.hsdp
+        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
+    ):
+        hsdp_device_mesh_plan = hsdp_device_mesh(
+            replica_group_size=fsdp_config.replica_group_size,
+            sharding_group_size=fsdp_config.sharding_group_size,
+        )
        print("HSDP device mesh is ready")

-    #setting up FSDP if enable_fsdp is enabled
+    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)
-
+
        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
        if is_vision:
-            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer])
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
+                model,
+                [
+                    MllamaSelfAttentionDecoderLayer,
+                    MllamaCrossAttentionDecoderLayer,
+                    MllamaVisionEncoderLayer,
+                ],
+            )
        else:
-        # Create the FSDP wrapper for LlamaDecoderLayer in text models
+            # Create the FSDP wrapper for LlamaDecoderLayer in text models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
        device_id = 0
        if is_xpu_available():
@@ -204,21 +257,36 @@ def main(**kwargs):
            device_id = torch.cuda.current_device()
        model = FSDP(
            model,
-            auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
-            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
-            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
+            auto_wrap_policy=(
+                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
+            ),
+            cpu_offload=(
+                CPUOffload(offload_params=True)
+                if fsdp_config.fsdp_cpu_offload
+                else None
+            ),
+            mixed_precision=(
+                mixed_precision_policy if not fsdp_config.pure_bf16 else None
+            ),
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
-            param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
-            if train_config.low_cpu_fsdp and rank != 0 else None,
+            param_init_fn=(
+                (
+                    lambda module: module.to_empty(
+                        device=torch.device("cuda"), recurse=False
+                    )
+                )
+                if train_config.low_cpu_fsdp and rank != 0
+                else None
+            ),
        )
-        if fsdp_config.fsdp_activation_checkpointing:
+        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
-            apply_fsdp_checkpointing(model)
+            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
@@ -252,11 +320,15 @@ def main(**kwargs):
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        else:
-            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+            dataset_train = ConcatDataset(
+                dataset_train, chunk_size=train_config.context_length
+            )

-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
+    train_dl_kwargs = get_dataloader_kwargs(
+        train_config, dataset_train, dataset_processer, "train"
+    )
    print("length of dataset_train", len(dataset_train))
-    custom_data_collator = get_custom_data_collator(dataset_processer,dataset_config)
+    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator
@@ -275,9 +347,13 @@ def main(**kwargs):
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        else:
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            dataset_val = ConcatDataset(
+                dataset_val, chunk_size=train_config.context_length
+            )

-    val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
+    val_dl_kwargs = get_dataloader_kwargs(
+        train_config, dataset_val, dataset_processer, "val"
+    )
    if custom_data_collator:
        val_dl_kwargs["collate_fn"] = custom_data_collator

@@ -289,7 +365,9 @@ def main(**kwargs):
        )
        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
-            raise ValueError(f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})")
+            raise ValueError(
+                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
+            )
        else:
            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")

@@ -324,11 +402,12 @@ def main(**kwargs):
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )
-    if not train_config.enable_fsdp or rank==0:
-        [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
+    if not train_config.enable_fsdp or rank == 0:
+        [print(f"Key: {k}, Value: {v}") for k, v in results.items()]
        if train_config.use_wandb:
-            for k,v in results.items():
+            for k, v in results.items():
                wandb_run.summary[k] = v

+
if __name__ == "__main__":
    fire.Fire(main)