# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
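"""Fine-tuning entry point for Llama (text) and Mllama (vision) models.

Supports full fine-tuning and parameter-efficient fine-tuning (PEFT), optional
4-bit/8-bit quantization, and distributed training with FSDP/HSDP. Configuration
comes from the train/FSDP/quantization config dataclasses, whose fields can be
overridden from the command line via fire.
"""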
import dataclasses
import os
import random
from collections import Counter
from warnings import warn

import fire
import numpy as np
import torch
import torch.optim as optim
from accelerate.utils import is_xpu_available
from llama_cookbook.configs import (
    fsdp_config as FSDP_CONFIG,
    quantization_config as QUANTIZATION_CONFIG,
    train_config as TRAIN_CONFIG,
)
from llama_cookbook.data.concatenator import ConcatDataset
from llama_cookbook.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_cookbook.utils import fsdp_auto_wrap_policy
from llama_cookbook.utils.config_utils import (
    check_fsdp_config,
    generate_dataset_config,
    generate_peft_config,
    get_dataloader_kwargs,
    update_config,
)
from llama_cookbook.utils.dataset_utils import (
    get_custom_data_collator,
    get_preprocessed_dataset,
)
from llama_cookbook.utils.fsdp_utils import hsdp_device_mesh, get_policies
from llama_cookbook.utils.train_utils import (
    clear_gpu_cache,
    freeze_transformer_layers,
    freeze_LLM_only,
    print_model_size,
    print_frozen_model_status,
    setup,
    setup_environ_flags,
    train,
)
from peft import get_peft_model, PeftModel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    MllamaForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaCrossAttentionDecoderLayer,
    MllamaSelfAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)
 
def setup_wandb(train_config, fsdp_config, **kwargs):
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        )
    from llama_cookbook.configs import wandb_config as WANDB_CONFIG

    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    run.config.update(fsdp_config, allow_val_change=True)
    return run
 
def main(**kwargs):
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)
    np.random.seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None
    if train_config.use_wandb:
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # Set up the quantization config
    bnb_config = None
    if train_config.quantization:
        if isinstance(train_config.quantization, bool):
            warn(
                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
                FutureWarning,
            )
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError(
                "8bit quantization is not supported with FSDP, please use 4bit quantization"
            )

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)

        if train_config.enable_fsdp:
            if train_config.quantization == "4bit":
                bnb_config.bnb_4bit_quant_storage = bnb_config.bnb_4bit_compute_dtype
                from logging import getLogger

                logger = getLogger()
                logger.warning(
                    "FSDP and 4-bit QLoRA enabled. Setting `bnb_4bit_quant_storage` "
                    f"to {bnb_config.bnb_4bit_compute_dtype} for compatibility."
                )

    # Load the pre-trained model and setup its configuration
 
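    # The KV cache is only needed for generation, so it is disabled when training with FSDP.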
    use_cache = False if train_config.enable_fsdp else None
    config = AutoConfig.from_pretrained(train_config.model_name)
    if config.model_type == "mllama":
        is_vision = True
        model = MllamaForConditionalGeneration.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
        )
        processor = AutoProcessor.from_pretrained(
            train_config.model_name
            if train_config.tokenizer_name is None
            else train_config.tokenizer_name
        )
        processor.tokenizer.padding_side = "right"
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    elif config.model_type == "llama":
        is_vision = False
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            use_cache=use_cache,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
        )
    else:
        raise ValueError(
            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
        )

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name
        if train_config.tokenizer_name is None
        else train_config.tokenizer_name
    )
 
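    # Llama tokenizers ship without a pad token; fall back to the EOS token so batches can be padded.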
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if (
        train_config.enable_fsdp
        and fsdp_config.pure_bf16
        and not train_config.quantization
    ):
        model.to(torch.bfloat16)

    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and setup its configuration
        if train_config.from_peft_checkpoint:
            model = PeftModel.from_pretrained(
                model, train_config.from_peft_checkpoint, is_trainable=True
            )
            peft_config = model.peft_config
        # Generate the peft config and start fine-tuning from original model
        else:
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()
 
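    # For HSDP (HYBRID_SHARD), build a 2D device mesh: parameters are sharded within
    # each sharding group and replicated across replica groups.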
    hsdp_device_mesh_plan = None
    if (
        fsdp_config.hsdp
        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
    ):
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size,
        )
        print("HSDP device mesh is ready")
 
    # Set up FSDP if enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)

        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
            freeze_LLM_only(model)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)

        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer
        # and MllamaVisionEncoderLayer in vision models
        if is_vision:
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                model,
                [
                    MllamaSelfAttentionDecoderLayer,
                    MllamaCrossAttentionDecoderLayer,
                    MllamaVisionEncoderLayer,
                ],
            )
        else:
            # Create the FSDP wrapper for LlamaDecoderLayer in text models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()
 
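        # use_orig_params=True lets FSDP handle flat parameters that mix frozen and
        # trainable tensors, which is required for PEFT and freeze_LLM_only.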
        use_orig_params = train_config.freeze_LLM_only or train_config.use_peft
 
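        # Wrap the model with FSDP. When low_cpu_fsdp is set, non-zero ranks materialize
        # parameters as empty tensors (param_init_fn) and rank 0's weights are broadcast
        # at wrap time via sync_module_states.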
        model = FSDP(
            model,
            auto_wrap_policy=(
                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
            ),
            cpu_offload=(
                CPUOffload(offload_params=True)
                if fsdp_config.fsdp_cpu_offload
                else None
            ),
            mixed_precision=(
                mixed_precision_policy if not fsdp_config.pure_bf16 else None
            ),
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(
                (
                    lambda module: module.to_empty(
                        device=torch.device("cuda"), recurse=False
                    )
                )
                if train_config.low_cpu_fsdp and rank != 0
                else None
            ),
            use_orig_params=use_orig_params,
        )
 
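        # Activation checkpointing: enable_input_require_grads keeps a gradient path through
        # checkpointed blocks when the embedding layer is frozen (e.g. with PEFT).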
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")
 
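    # Build the dataset config; vision (mllama) datasets are preprocessed with the
    # AutoProcessor, text (llama) datasets with the tokenizer.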
    dataset_config = generate_dataset_config(train_config, kwargs)
    if is_vision:
        dataset_processor = processor
    else:
        dataset_processor = tokenizer

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        dataset_processor,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        dataset_processor,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    if train_config.batching_strategy == "packing":
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        else:
            dataset_train = ConcatDataset(
                dataset_train, chunk_size=train_config.context_length
            )

    train_dl_kwargs = get_dataloader_kwargs(
        train_config, dataset_train, dataset_processor, "train"
    )
    print("length of dataset_train", len(dataset_train))

    custom_data_collator = get_custom_data_collator(dataset_processor, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
 
    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            if is_vision:
                raise ValueError("Packing is not supported for vision datasets")
            else:
                dataset_val = ConcatDataset(
                    dataset_val, chunk_size=train_config.context_length
                )

        val_dl_kwargs = get_dataloader_kwargs(
            train_config, dataset_val, dataset_processor, "val"
        )
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        if len(eval_dataloader) == 0:
            raise ValueError(
                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
            )
        else:
            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
 
    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
 
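    # StepLR with step_size=1 multiplies the learning rate by gamma on every scheduler step.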
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
 
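    # Run the training loop; train() returns a dict of metrics that is printed
    # and logged to wandb below.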
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )

    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")
        if train_config.use_wandb:
            for k, v in results.items():
                wandb_run.summary[k] = v


if __name__ == "__main__":
    fire.Fire(main)
 
 