
[WIP] enabled fsdpv2 checkpointing, peft+full-vision enabled for fsdpv2, refactoring

Matthias Reso, 6 months ago
commit 6d5b221f44

+ 67 - 118
src/llama_recipes/finetuning.py

@@ -14,14 +14,14 @@ import torch.optim as optim
 from accelerate.utils import is_xpu_available
 
 from llama_recipes.configs import (
-    fsdp_config as FSDP_CONFIG,
-    quantization_config as QUANTIZATION_CONFIG,
-    train_config as TRAIN_CONFIG,
+    fsdp_config as FsdpConfig,
+    quantization_config as QuantizationConfig,
+    train_config as TrainConfig,
 )
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
-from llama_recipes.utils import fsdp_auto_wrap_policy
+from llama_recipes.utils import get_model_and_data_processor
 from llama_recipes.utils.config_utils import (
     check_fsdp_config,
     generate_dataset_config,
@@ -38,8 +38,6 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
-    get_policies,
-    print_model_size,
     setup,
     setup_environ_flags,
     train,
@@ -53,8 +51,6 @@ from transformers import (
     AutoProcessor,
     AutoTokenizer,
     BitsAndBytesConfig,
-    LlamaForCausalLM,
-    MllamaForConditionalGeneration,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.models.mllama.modeling_mllama import (
@@ -72,9 +68,9 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
             "You are trying to use wandb which is not currently installed. "
             "Please install it using pip install wandb"
         )
-    from llama_recipes.configs import wandb_config as WANDB_CONFIG
+    from llama_recipes.configs import wandb_config as WandBConfig
 
-    wandb_config = WANDB_CONFIG()
+    wandb_config = WandBConfig()
     update_config(wandb_config, **kwargs)
     init_dict = dataclasses.asdict(wandb_config)
     run = wandb.init(**init_dict)
@@ -85,7 +81,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
-    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
+    train_config, fsdp_config = TrainConfig(), FsdpConfig()
     update_config((train_config, fsdp_config), **kwargs)
     # Set the seeds for reproducibility
     if is_xpu_available():
@@ -116,7 +112,7 @@ def main(**kwargs):
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
     # setting quantization configs
-    bnb_config = None
+    quant_config = None
     if train_config.quantization:
         if type(train_config.quantization) == type(True):
             warn(
@@ -130,70 +126,15 @@ def main(**kwargs):
                 "8bit quantization is not supported with FSDP, please use 4bit quantization"
             )
 
-        quant_config = QUANTIZATION_CONFIG()
+        quant_config = QuantizationConfig()
         update_config(quant_config, **kwargs)
-        bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
     # Load the pre-trained model and setup its configuration
-    use_cache = False if train_config.enable_fsdp else None
-    config = AutoConfig.from_pretrained(train_config.model_name)
-    if config.model_type == "mllama":
-        is_vision = True
-        model = MllamaForConditionalGeneration.from_pretrained(
-            train_config.model_name,
-            quantization_config=bnb_config,
-            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-            device_map=(
-                "auto"
-                if train_config.quantization and not train_config.enable_fsdp
-                else None
-            ),
-            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
-        )
-        processor = AutoProcessor.from_pretrained(
-            train_config.model_name
-            if train_config.tokenizer_name is None
-            else train_config.tokenizer_name
-        )
-        processor.tokenizer.padding_side = "right"
-        model.supports_gradient_checkpointing = True
-        model.language_model.supports_gradient_checkpointing = True
-    elif config.model_type == "llama":
-        is_vision = False
-        model = LlamaForCausalLM.from_pretrained(
-            train_config.model_name,
-            quantization_config=bnb_config,
-            use_cache=use_cache,
-            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-            device_map=(
-                "auto"
-                if train_config.quantization and not train_config.enable_fsdp
-                else None
-            ),
-            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
-        )
+    model, dataset_processer, is_vision = get_model_and_data_processor(train_config, quant_config)
+    if is_vision:
+        tokenizer = dataset_processer.tokenizer
     else:
-        raise ValueError(
-            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
-        )
-    # Load the tokenizer and add special tokens
-    tokenizer = AutoTokenizer.from_pretrained(
-        train_config.model_name
-        if train_config.tokenizer_name is None
-        else train_config.tokenizer_name
-    )
-    if not tokenizer.pad_token_id:
-        tokenizer.pad_token_id = tokenizer.eos_token_id
-
-    # If there is a mismatch between tokenizer vocab size and embedding matrix,
-    # throw a warning and then expand the embedding matrix
-    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print(
-            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
-        )
-        model.resize_token_embeddings(len(tokenizer))
-
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
+        tokenizer = dataset_processer
 
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
@@ -235,71 +176,79 @@ def main(**kwargs):
 
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
+            
+        device_id = 0
+        if is_xpu_available():
+            device_id = torch.xpu.current_device()
+        elif torch.cuda.is_available():
+            device_id = torch.cuda.current_device()
+        from llama_recipes.utils.fsdp_utils import parallelize_model
+
+        # model = FSDP(
+        #     
+        #     cpu_offload=(
+        #         CPUOffload(offload_params=True)
+        #         if fsdp_config.fsdp_cpu_offload
+        #         else None
+        #     ),
+        #     mixed_precision=(
+        #         mixed_precision_policy if not fsdp_config.pure_bf16 else None
+        #     ),
+        #     sharding_strategy=fsdp_config.sharding_strategy,
+        #     device_mesh=hsdp_device_mesh_plan,
+        #     device_id=device_id,
+        #     limit_all_gathers=True,
+        #     sync_module_states=train_config.low_cpu_fsdp,
+        #     param_init_fn=(
+        #         (
+        #             lambda module: module.to_empty(
+        #                 device=torch.device("cuda"), recurse=False
+        #             )
+        #         )
+        #         if train_config.low_cpu_fsdp and rank != 0
+        #         else None
+        #     ),
+        # )
 
-        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
         if is_vision:
-            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
-                model,
-                [
+            MODS = (
                     MllamaSelfAttentionDecoderLayer,
                     MllamaSelfAttentionDecoderLayer,
                     MllamaVisionEncoderLayer,
-                ],
             )
+            sharding_conditions = [
+                lambda m: any(isinstance(m,n) for n in MODS),
+            ]
         else:
-            # Create the FSDP wrapper for LlamaDecoderLayer in text models
-            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
-        device_id = 0
-        if is_xpu_available():
-            device_id = torch.xpu.current_device()
-        elif torch.cuda.is_available():
-            device_id = torch.cuda.current_device()
-        model = FSDP(
-            model,
-            auto_wrap_policy=(
-                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
-            ),
-            cpu_offload=(
-                CPUOffload(offload_params=True)
-                if fsdp_config.fsdp_cpu_offload
-                else None
-            ),
-            mixed_precision=(
-                mixed_precision_policy if not fsdp_config.pure_bf16 else None
-            ),
-            sharding_strategy=fsdp_config.sharding_strategy,
-            device_mesh=hsdp_device_mesh_plan,
-            device_id=device_id,
-            limit_all_gathers=True,
-            sync_module_states=train_config.low_cpu_fsdp,
-            param_init_fn=(
-                (
-                    lambda module: module.to_empty(
-                        device=torch.device("cuda"), recurse=False
-                    )
+            sharding_conditions = [lambda m: isinstance(m, LlamaDecoderLayer)]
+
+        if train_config.use_peft:
+            sharding_conditions += [
+                lambda m: (
+                    len(list(m.named_children())) == 0
+                    and getattr(m, "weight", None) is not None
+                    and m.weight.requires_grad
                 )
-                if train_config.low_cpu_fsdp and rank != 0
-                else None
-            ),
+            ]
+
+        parallelize_model(
+            model,
+            fsdp_config,
+            device_mesh = hsdp_device_mesh_plan,
+            sharding_conditions = sharding_conditions,
         )
+        
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
             model.gradient_checkpointing_enable()
-            apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
         if is_xpu_available():
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
     dataset_config = generate_dataset_config(train_config, kwargs)
-    if is_vision:
-        dataset_processer = processor
-    else:
-        dataset_processer = tokenizer
-
+    
     # Load and preprocess the dataset for training and validation
-
     dataset_train = get_preprocessed_dataset(
         dataset_processer,
         dataset_config,
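
Note on the refactor above: instead of building an FSDP1 `auto_wrap_policy`, the driver now collects a list of `sharding_conditions` predicates and hands the model to `parallelize_model`, which applies FSDP2's `fully_shard` module by module. A minimal sketch of that selection pattern, assuming PyTorch 2.4+ with `torch.distributed._composable.fsdp` available and an initialized process group; the helper name `shard_by_conditions` is illustrative and not part of the repo:

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

def shard_by_conditions(model: nn.Module, sharding_conditions) -> None:
    # Walk modules bottom-up so leaves (e.g. trainable LoRA adapters) and
    # decoder layers get their own FSDP2 group before the root wrap below.
    for module in reversed(list(model.modules())):
        if any(cond(module) for cond in sharding_conditions):
            fully_shard(module, reshard_after_forward=True)
    # Root wrap picks up whatever is left (embeddings, lm_head, norms).
    fully_shard(model)

# Predicates mirroring the diff: whole decoder layers for text models, plus
# trainable leaf modules when PEFT is enabled.
conditions = [
    lambda m: isinstance(m, LlamaDecoderLayer),
    lambda m: len(list(m.named_children())) == 0
    and getattr(m, "weight", None) is not None
    and m.weight.requires_grad,
]
```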

+ 4 - 4
src/llama_recipes/model_checkpointing/__init__.py

@@ -3,12 +3,12 @@
 
 from llama_recipes.model_checkpointing.checkpoint_handler import (
     load_model_checkpoint,
-    save_fsdp_model_checkpoint_full,
+    save_fsdp_checkpoint_full,
+    save_fsdp_checkpoint_sharded,
     save_peft_checkpoint,
     save_model_checkpoint,
+    save_checkpoint,
     load_optimizer_checkpoint,
-    save_optimizer_checkpoint,
-    save_model_and_optimizer_sharded,
-    load_model_sharded,
+    load_fsdp_checkpoint_sharded,
     load_sharded_model_single_gpu
 )
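
For orientation, the renamed surface of `llama_recipes.model_checkpointing` after this change; the names come from the diff, and the comments summarize how the rest of this commit uses them:

```python
from llama_recipes.model_checkpointing import (
    save_checkpoint,               # single entry point called by train()
    save_fsdp_checkpoint_full,     # rank-0 full_state_dict -> torch.save(*.pt)
    save_fsdp_checkpoint_sharded,  # per-rank sharded save via torch.distributed.checkpoint
    load_fsdp_checkpoint_sharded,  # resume model (and optionally optimizer)
    save_peft_checkpoint,
    save_model_checkpoint,
)
```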

+ 128 - 103
src/llama_recipes/model_checkpointing/checkpoint_handler.py

@@ -8,6 +8,9 @@ from pathlib import Path
 import torch
 import torch.distributed as dist
 
+from torch.distributed.checkpoint.state_dict import get_state_dict, StateDictOptions
+from torch.distributed.checkpoint.state_dict_saver import save
+from torch.distributed.checkpoint.state_dict_loader import load
 from torch.distributed.checkpoint import (
     FileSystemReader,
     FileSystemWriter,
@@ -47,118 +50,128 @@ def get_date_of_run():
 fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
 
 
-def load_model_sharded(model, rank, cfg):
-    # torch.manual_seed(103)
-    folder_name = (
-        cfg.dist_checkpoint_root_folder
-        + "/"
-        + cfg.dist_checkpoint_folder
-        + "-"
-        + cfg.model_name
-    )
+def load_fsdp_checkpoint_sharded(model, cfg, epoch=1, optimizer=None):
+    rank = dist.get_rank()
+    folder_name = "-".join((cfg.dist_checkpoint_folder, cfg.model_name, str(epoch)))
 
-    load_dir = Path.cwd() / folder_name
+    load_dir = Path.cwd() / cfg.dist_checkpoint_root_folder / folder_name
 
     if not load_dir.exists():
         if rank == 0:
-            print(f"No sharded_state_dict checkpoint directory found...skipping")
+            print(f"No sharded_state_dict checkpoint directory at {load_dir.as_posix()} found...skipping")
         return
     if rank == 0:
-        print(f"loading model from model path: {load_dir} ")
+        print(f"loading model from model path: {load_dir.as_posix()} ")
     reader = FileSystemReader(load_dir)
 
-    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
-        checkpoint = {"model": model.state_dict()}
-        if rank == 0:
-            ck = checkpoint.keys()
-            print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
+    checkpoint = {"model": model}
+    if optimizer is not None:
+        checkpoint["optimizer"] = optimizer
+    if rank == 0:
+        ck = checkpoint.keys()
+        print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
+
+    load(
+        state_dict=checkpoint,
+        storage_reader=reader,
+    )
+    if rank == 0:
+        print(f"checkpoint after load_state_dict()")
+        ck = checkpoint.keys()
+        print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
+
+    model.load_state_dict(checkpoint["model"])
+    if optimizer is not None:
+        optimizer.load_state_dict(checkpoint["optimizer"])
 
-        load_state_dict(
-            state_dict=checkpoint,
-            storage_reader=reader,
-        )
-        if rank == 0:
-            print(f"checkpoint after load_state_dict()")
-            ck = checkpoint.keys()
-            print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
-        model.load_state_dict(checkpoint["model"])
     if rank == 0:
         print(f"Sharded state checkpoint loaded from {load_dir}")
 
 
-def save_model_and_optimizer_sharded(model, rank, cfg, optim=None):
+def save_fsdp_checkpoint_sharded(model, optimizer, train_config, epoch=1):
     """save model and optimizer via sharded_state_dict to save_dir"""
 
-    folder_name = (
-        cfg.dist_checkpoint_root_folder
-        + "/"
-        + cfg.dist_checkpoint_folder
-        + "-"
-        + cfg.model_name
-    )
+    folder_name = "-".join((train_config.dist_checkpoint_folder, train_config.model_name, str(epoch)))
+
+    save_dir = Path.cwd() / train_config.dist_checkpoint_root_folder / folder_name
+
+    rank = dist.get_rank()
 
-    save_dir = Path.cwd() / folder_name
     if rank == 0:
-        print(f"Saving model to {save_dir}")
+        print(f"Saving model to {save_dir.as_posix()}")
 
     distributed_writer = FileSystemWriter(
         save_dir,
+        overwrite=True,
     )
     t0 = time.perf_counter()
 
-    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+    options = StateDictOptions(
+        full_state_dict=False,
+    )
 
-        state_dict = {"model": model.state_dict()}
-        if optim is not None:
-            state_dict["optim"] = FSDP.optim_state_dict(model, optim)
+    optim = optimizer if train_config.save_optimizer else []
 
-        save_state_dict(
-            state_dict=state_dict,
-            storage_writer=distributed_writer,
-            planner=DefaultSavePlanner(),
-        )
+    state_dict = {"model": model}
+    if train_config.save_optimizer:
+        state_dict["optimizer"] = optimizer
+
+    save(
+        state_dict=state_dict,
+        storage_writer=distributed_writer,
+        planner=DefaultSavePlanner(),
+    )
     dist.barrier()
     t1 = time.perf_counter()
     if rank == 0:
-        print(f"Sharded state checkpoint saved to {save_dir}")
+        print(f"Sharded state checkpoint saved to {save_dir.as_posix()}")
         print(f"Checkpoint Time = {t1-t0:.4f}\n")
 
 
-def save_fsdp_model_checkpoint_full(
+def save_fsdp_checkpoint_full(
     model,
     optimizer,
-    rank,
-    cfg,
+    train_config,
     epoch=1,
 ):
     """saving model via rank0 cpu streaming and full_state_dict"""
 
-    with FSDP.state_dict_type(
-        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
-    ):
-        cpu_state = model.state_dict()
+    options = StateDictOptions(
+        full_state_dict=True,
+    )
 
-        print(f"saving process: rank {rank}  done w model state_dict\n")
+    optim = optimizer if train_config.save_optimizer else []
+
+    model_state, optim_state = get_state_dict(model, optim, options=options)
+
+    rank = dist.get_rank()
 
     if rank == 0:
         print(f"--> saving model ...")
         # create save path
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-        save_dir = Path.cwd() / folder_name
+        folder_name = "-".join((train_config.dist_checkpoint_folder, train_config.model_name))
+        save_dir = Path.cwd() / train_config.dist_checkpoint_root_folder / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
-        save_name = cfg.model_name.replace("/", "--") + "-" + str(epoch) + ".pt"
-        save_full_path = str(save_dir) + "/" + save_name
+
+        save_name = train_config.model_name.replace("/", "--") + "-" + str(epoch) + ".pt"
+        save_full_path = save_dir / save_name
 
         # save model
-        torch.save(cpu_state, save_full_path)
+        torch.save(model_state, save_full_path)
+
+        print(f"model checkpoint saved for epoch {epoch} at {save_full_path.as_posix()}\n")
+
+        if not train_config.save_optimizer:
+            return
+
+        opt_save_name = "optimizer" + "-" + train_config.model_name.replace("/", "--") + "-" + str(epoch) + ".pt"
+        opt_save_full_path = save_dir / opt_save_name
+
+        print(f"--> saving optimizer state...")
 
-        print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
+        torch.save(optim_state, opt_save_full_path)
+
+        print(f"--> saved {opt_save_full_path.as_posix()} to disk")
 
 
 def load_model_checkpoint(model, rank, cfg):
@@ -186,38 +199,6 @@ def load_model_checkpoint(model, rank, cfg):
     print(f"model checkpoint loaded to rank0 cpu")
 
 
-def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
-    """save optimizer state via full state dict"""
-
-    print(f"--> optim state call on rank {rank}\n")
-
-    # pull all sharded optimizer states to rank0 cpu...
-
-    optim_state = FSDP.full_optim_state_dict(model, optimizer)
-
-    print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
-
-    if rank == 0:
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-        save_dir = Path.cwd() / folder_name
-        save_dir.mkdir(parents=True, exist_ok=True)
-
-        opt_save_name = "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
-        opt_save_full_path = save_dir / opt_save_name
-
-        print(f"--> saving optimizer state...")
-
-        torch.save(optim_state, opt_save_full_path)
-
-        print(f"--> saved {opt_save_full_path} to disk")
-
-
 def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
     """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
@@ -258,14 +239,20 @@ def load_sharded_model_single_gpu(model, model_path):
     return model
 
 
-def save_peft_checkpoint(model, model_path):
+def save_peft_checkpoint(model, train_config):
     """save_pretrained peft model"""
+    if train_config.enable_fsdp:
+        options = StateDictOptions(
+            full_state_dict=True,
+            cpu_offload=True,
+        )
 
-    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+        model_state, _ = get_state_dict(model, [], options=options)
 
-    if isinstance(model, FSDP):
-        state_dict = get_model_state_dict(model, options=options)
-        model.save_pretrained(model_path, state_dict=state_dict)
+        rank = dist.get_rank()
+        if rank == 0:
+            model_path = train_config.output_dir
+            model.save_pretrained(model_path, state_dict=model_state)
     else:
         model.save_pretrained(model_path)
 
@@ -278,3 +265,41 @@ def save_model_checkpoint(model, output_dir):
     state_dict = model.state_dict()
 
     torch.save(state_dict, output_file)
+
+
+def save_checkpoint(model, optimizer, train_config, fsdp_config, epoch):
+    """save model and optimizer"""
+    rank = dist.get_rank() if train_config.enable_fsdp else 0
+
+    if train_config.enable_fsdp:
+        dist.barrier()
+    if train_config.use_peft:
+        if rank == 0:
+            print(f"we are about to save the PEFT modules")
+        save_peft_checkpoint(model, train_config)
+        
+        if rank == 0:
+            print(f"PEFT modules are saved in {train_config.output_dir} directory")
+
+    else:
+        if not train_config.enable_fsdp:
+            save_model_checkpoint(model, train_config.output_dir)
+
+        elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+            if rank == 0:
+                print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                print("=====================================================")
+            save_fsdp_checkpoint_full(
+                model, optimizer, train_config, epoch=epoch
+            )
+
+        elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
+            if rank == 0:
+                print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                print("=====================================================")
+            save_fsdp_checkpoint_sharded(
+                model, optimizer, train_config, epoch=epoch
+            )
+
+    if train_config.enable_fsdp:
+        dist.barrier()
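
The handlers above move from FSDP1's `FSDP.state_dict_type` context managers to the `torch.distributed.checkpoint` (DCP) APIs. A minimal sketch of the equivalent documented flow, assuming PyTorch 2.3+ and an initialized process group; it uses explicit `get_state_dict`/`set_state_dict` rather than passing the module object straight to `save`/`load` as the diff does, which is a slightly more conservative variant of the same idea:

```python
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_state_dict,
    set_state_dict,
)

def save_sharded(model, optimizer, ckpt_dir: str) -> None:
    # Every rank contributes its shards; DCP writes one file set per rank.
    model_state, optim_state = get_state_dict(model, optimizer)
    dcp.save({"model": model_state, "optim": optim_state}, checkpoint_id=ckpt_dir)

def load_sharded(model, optimizer, ckpt_dir: str) -> None:
    model_state, optim_state = get_state_dict(model, optimizer)
    dcp.load({"model": model_state, "optim": optim_state}, checkpoint_id=ckpt_dir)
    set_state_dict(
        model,
        optimizer,
        model_state_dict=model_state,
        optim_state_dict=optim_state,
    )

def save_full_rank0(model, optimizer, path: str) -> None:
    # Gather an unsharded copy, offloaded to CPU, then write it from rank 0 only.
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    model_state, optim_state = get_state_dict(model, optimizer, options=options)
    if dist.get_rank() == 0:
        torch.save({"model": model_state, "optim": optim_state}, path)
```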

+ 41 - 14
src/llama_recipes/policies/mixed_precision.py

@@ -2,37 +2,64 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import torch
-
-from torch.distributed.fsdp import (
-    MixedPrecision,
-)
+import torch.cuda.nccl as nccl
+import torch.distributed as dist
+from torch.distributed._composable.fsdp import MixedPrecisionPolicy
+from accelerate.utils import is_xpu_available
 
 # requires grad scaler in main loop
-fpSixteen = MixedPrecision(
+fpSixteen = MixedPrecisionPolicy(
     param_dtype=torch.float16,
     # Gradient communication precision.
     reduce_dtype=torch.float16,
-    # Buffer precision.
-    buffer_dtype=torch.float16,
 )
 
-bfSixteen = MixedPrecision(
+bfSixteen = MixedPrecisionPolicy(
     param_dtype=torch.bfloat16,
     # Gradient communication precision.
     reduce_dtype=torch.bfloat16,
-    # Buffer precision.
-    buffer_dtype=torch.bfloat16,
     cast_forward_inputs=True,
 )
 
-bfSixteen_mixed = MixedPrecision(
+bfSixteen_mixed = MixedPrecisionPolicy(
     param_dtype=torch.float32,
     reduce_dtype=torch.bfloat16,
-    buffer_dtype=torch.bfloat16,
 )
 
-fp32_policy = MixedPrecision(
+fp32_policy = MixedPrecisionPolicy(
     param_dtype=torch.float32,
     reduce_dtype=torch.float32,
-    buffer_dtype=torch.float32,
 )
+
+
+def get_mixed_precision_policies(cfg):
+    """Get the policies for mixed precision and fsdp wrapping"""
+
+    rank = dist.get_rank()
+
+    verify_bfloat_support = (
+        torch.version.cuda
+        and torch.cuda.is_bf16_supported()
+        and torch.version.cuda >= "11.0"
+        and dist.is_nccl_available()
+        and nccl.version() >= (2, 10)
+    ) or (is_xpu_available())
+
+    mixed_precision_policy = None
+
+    # Mixed precision
+    if cfg.mixed_precision:
+        bf16_ready = verify_bfloat_support
+
+        if bf16_ready and not cfg.use_fp16:
+            mixed_precision_policy = bfSixteen
+            if rank == 0:
+                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
+        elif cfg.use_fp16:
+            mixed_precision_policy = fpSixteen
+            if rank == 0:
+                print(f"FP16 enabled")
+        else:
+            if rank == 0:
+                print(f"bFloat16 support not present. Using FP32, and not mixed precision")
+    return mixed_precision_policy
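
Compared with FSDP1's `MixedPrecision`, the FSDP2 `MixedPrecisionPolicy` has no `buffer_dtype` field and is attached per wrapped module through `fully_shard` rather than the `FSDP(...)` constructor. A hedged sketch of how such a policy is typically applied, assuming PyTorch 2.4+; the wrapper function is illustrative:

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

# bf16 compute with bf16 gradient reduction, mirroring bfSixteen above.
mp = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # dtype parameters are cast to for forward/backward
    reduce_dtype=torch.bfloat16,  # dtype used for gradient reduce-scatter
    cast_forward_inputs=True,     # cast floating-point inputs to param_dtype
)

def wrap_block(block: nn.Module) -> nn.Module:
    # Needs an initialized process group (and optionally a device mesh) at call time.
    fully_shard(block, mp_policy=mp)
    return block
```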

+ 3 - 2
src/llama_recipes/utils/__init__.py

@@ -1,7 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from llama_recipes.utils.model_utils import get_model_and_data_processor
 from llama_recipes.utils.memory_utils import MemoryTrace
 from llama_recipes.utils.dataset_utils import *
-from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy, hsdp_device_mesh
-from llama_recipes.utils.train_utils import *
+from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
+from llama_recipes.utils.train_utils import *

+ 78 - 31
src/llama_recipes/utils/fsdp_utils.py

@@ -1,30 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-from torch.distributed._tensor.device_mesh import init_device_mesh
-import os 
+import os
 
-def fsdp_auto_wrap_policy(model, transformer_layer_names):
-    import functools
-
-    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
-
-    def lambda_policy_fn(module):
-        if (
-            len(list(module.named_children())) == 0
-            and getattr(module, "weight", None) is not None
-            and module.weight.requires_grad
-        ):
-            return True
-        return False
-
-    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
-    transformer_wrap_policy = functools.partial(
-        transformer_auto_wrap_policy,
-        transformer_layer_cls=set(transformer_layer_names)
-    )
-
-    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
-    return auto_wrap_policy
+import torch
+import torch.nn as nn
+from llama_recipes.configs.fsdp import fsdp_config as FSDP_CONFIG
+from llama_recipes.policies import get_mixed_precision_policies
+from torch.distributed._composable.fsdp import fully_shard, CPUOffloadPolicy
+from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
+from typing import List, Callable
 
 
 def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
@@ -33,11 +17,11 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
 
     This function requires explicit sizes for replica and sharding groups to accommodate models
     whose GPU fit is unknown, providing flexibility in distributed training setups.
-    
+
     Args:
         replica_group_size (int): The size of each replica group. Must be provided to ensure
             the model fits within the available resources.
-        sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to 
+        sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to
             ensure the correct distribution of model parameters.
         device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda"
             with the local rank as the device index.
@@ -59,7 +43,9 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
     """
 
     if replica_group_size is None or sharding_group_size is None:
-        raise ValueError("Both replica_group_size and sharding_group_size must be provided.")
+        raise ValueError(
+            "Both replica_group_size and sharding_group_size must be provided."
+        )
 
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -67,15 +53,76 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
     device = device or f"cuda"
 
     if world_size % sharding_group_size != 0:
-        raise ValueError(f"World size {world_size} is not evenly divisible by "
-                         f"sharding group size {sharding_group_size}.")
+        raise ValueError(
+            f"World size {world_size} is not evenly divisible by "
+            f"sharding group size {sharding_group_size}."
+        )
 
     if (world_size // sharding_group_size) % replica_group_size != 0:
-        raise ValueError(f"The calculated number of replica groups is not evenly divisible by "
-                         f"replica_group_size {replica_group_size}.")
+        raise ValueError(
+            f"The calculated number of replica groups is not evenly divisible by "
+            f"replica_group_size {replica_group_size}."
+        )
 
     device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size))
     if device_mesh is None:
         raise RuntimeError("Failed to create a valid device mesh.")
 
     return device_mesh
+
+
+def parallelize_model(
+    model: nn.Module,
+    fsdp_config: FSDP_CONFIG,
+    device_mesh: DeviceMesh = None,
+    sharding_conditions: List[Callable] = None,
+) -> None:
+    """
+    Parallelizes a Llama model using FSDP.
+
+    Args:
+        model (nn.Module): The Llama model to parallelize.
+        fsdp_config (FSDP_CONFIG): The FSDP configuration.
+        device_mesh (torch.device_mesh): The device mesh to use for parallelization.
+
+    Returns:
+        None
+    """
+
+    mp_policy = get_mixed_precision_policies(fsdp_config)
+    fsdp_config = {
+        "mesh": device_mesh,
+        "mp_policy": None if fsdp_config.pure_bf16 else mp_policy,
+        "offload_policy": CPUOffloadPolicy() if fsdp_config.fsdp_cpu_offload else None
+        }
+
+    # Following torchtune's approach to wrap Lora first as dtype is different from base
+    for m in reversed(list(model.modules())):
+        if any(c(m) for c in sharding_conditions):
+            fully_shard(m, reshard_after_forward=True)
+
+    # 
+    # if hasattr(model, "base_model") and hasattr(model.base_model, "model"):
+    #     for n, m in reversed(list(model.named_modules())):
+    #         if any(c(m) for c in sharding_conditions):
+    #         # if (
+    #         #     len(list(m.named_children())) == 0
+    #         #     and getattr(m, "weight", None) is not None
+    #         #     and m.weight.requires_grad
+    #         # ):
+    #             fully_shard(m, reshard_after_forward=True)
+    #     layers = model.base_model.model.model.layers
+    # else:
+    #     layers = model.model.layers
+
+    # for idx, layer in enumerate(layers):
+    #     # Following torch titan we will not reshard the last layer
+    #     # https://github.com/pytorch/torchtitan/blob/7310abea8782bbe459b662bc6d8411fe8d55f62c/torchtitan/parallelisms/parallelize_llama.py#L347
+    #     reshard_after_forward = idx < len(layers) - 1
+    #     fully_shard(
+    #         layer,
+    #         reshard_after_forward=reshard_after_forward,
+    #     )
+
+    # Shard remaining modules like embeddings
+    fully_shard(model, **fsdp_config)
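
A hedged usage sketch for the two helpers in this file, assuming a `torchrun` launch with `WORLD_SIZE=8`, an initialized process group, and the repo's default `fsdp_config` fields; the 2x4 replica/shard split and the function name are illustrative:

```python
from llama_recipes.configs.fsdp import fsdp_config as FSDP_CONFIG
from llama_recipes.utils.fsdp_utils import hsdp_device_mesh, parallelize_model
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

def shard_llama_hsdp(model):
    # 2 replica groups x 4 shards per group -> expects WORLD_SIZE == 8.
    mesh = hsdp_device_mesh(replica_group_size=2, sharding_group_size=4)
    parallelize_model(
        model,
        FSDP_CONFIG(),
        device_mesh=mesh,
        sharding_conditions=[lambda m: isinstance(m, LlamaDecoderLayer)],
    )
    return model
```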

+ 108 - 0
src/llama_recipes/utils/model_utils.py

@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from llama_recipes.configs import (
+    quantization_config as QuantizationConfig,
+    train_config as TrainConfig
+)
+from transformers import (
+    AutoConfig,
+    AutoProcessor,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    MllamaForConditionalGeneration,
+)
+
+
+def print_model_size(model: nn.Module, config: TrainConfig, rank: int = 0) -> None:
+    """
+    Print the model name and the number of trainable parameters.
+
+    Args:
+        model: The PyTorch model.
+        config (TrainConfig): Training configuration; config.model_name is used in the printout.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        print(f"--> Model {config.model_name}")
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
+
+
+def get_model_and_data_processor(
+    train_config: TrainConfig, quant_config: QuantizationConfig
+):
+    bnb_config = None
+    if quant_config:
+        bnb_config = quant_config.create_bnb_config(train_config.quantization)
+
+    use_cache = False if train_config.enable_fsdp else None
+    config = AutoConfig.from_pretrained(train_config.model_name)
+    if config.model_type == "mllama":
+        is_vision = True
+        model = MllamaForConditionalGeneration.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map=(
+                "auto"
+                if train_config.quantization and not train_config.enable_fsdp
+                else None
+            ),
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+        processor = AutoProcessor.from_pretrained(
+            train_config.model_name
+            if train_config.tokenizer_name is None
+            else train_config.tokenizer_name
+        )
+        processor.tokenizer.padding_side = "right"
+        model.supports_gradient_checkpointing = True
+        model.language_model.supports_gradient_checkpointing = True
+    elif config.model_type == "llama":
+        is_vision = False
+        model = LlamaForCausalLM.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map=(
+                "auto"
+                if train_config.quantization and not train_config.enable_fsdp
+                else None
+            ),
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+
+        # Load the tokenizer and add special tokens
+        processor = AutoTokenizer.from_pretrained(
+            train_config.model_name
+            if train_config.tokenizer_name is None
+            else train_config.tokenizer_name
+        )
+        if not processor.pad_token_id:
+            processor.pad_token_id = processor.eos_token_id
+
+        # If there is a mismatch between tokenizer vocab size and embedding matrix,
+        # throw a warning and then expand the embedding matrix
+        if len(processor) > model.get_input_embeddings().weight.shape[0]:
+            print(
+                "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
+            )
+            model.resize_token_embeddings(len(processor))
+
+    else:
+        raise ValueError(
+            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
+        )
+
+    print_model_size(
+        model, train_config, dist.get_rank() if train_config.enable_fsdp else 0
+    )
+
+    return model, processor, is_vision
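
A hedged usage sketch of `get_model_and_data_processor`; the model name shown is illustrative, and passing `None` as the quantization config skips the bitsandbytes path entirely:

```python
from llama_recipes.configs import train_config as TrainConfig
from llama_recipes.utils import get_model_and_data_processor

train_config = TrainConfig()
train_config.model_name = "meta-llama/Llama-3.1-8B-Instruct"  # illustrative checkpoint

model, processor, is_vision = get_model_and_data_processor(train_config, None)

# Text models return the tokenizer directly; mllama returns an AutoProcessor
# whose .tokenizer attribute the finetuning driver unpacks.
tokenizer = processor.tokenizer if is_vision else processor
```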

+ 239 - 207
src/llama_recipes/utils/train_utils.py

@@ -1,34 +1,34 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import contextlib
+import json
 import os
 import time
-import yaml
 from contextlib import nullcontext
-from pathlib import Path
 from datetime import datetime
-import contextlib
-
+from pathlib import Path
 
 import torch
-import torch.cuda.nccl as nccl
 import torch.distributed as dist
+import yaml
+from accelerate.utils import is_ccl_available, is_xpu_available
+
+from llama_recipes.model_checkpointing import save_checkpoint
+from llama_recipes.policies import bfSixteen, fpSixteen, get_llama_wrapper
+from llama_recipes.utils.flop_utils import FlopMeasure
+from llama_recipes.utils.memory_utils import MemoryTrace
 from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from tqdm import tqdm
 from transformers import LlamaTokenizer
-import json
 
 
-from llama_recipes.model_checkpointing import save_fsdp_model_checkpoint_full, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint, save_model_checkpoint
-from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
-from llama_recipes.utils.memory_utils import MemoryTrace
-from accelerate.utils import is_xpu_available, is_ccl_available
-from llama_recipes.utils.flop_utils import FlopMeasure
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
 
+
 @contextlib.contextmanager
 def profile(cfg, local_rank=None):
     use_profiler: bool = cfg.use_profiler
@@ -40,17 +40,21 @@ def profile(cfg, local_rank=None):
         wait_step, warmup_step, active_step = 1, 2, 3
         min_step = wait_step + warmup_step + active_step + 1
         if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
-            raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
-        print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
+            raise ValueError(
+                f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}"
+            )
+        print(
+            f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}"
+        )
         with torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 torch.profiler.ProfilerActivity.CUDA,
             ],
-            schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
-            on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                cfg.profiler_dir
+            schedule=torch.profiler.schedule(
+                wait=wait_step, warmup=warmup_step, active=active_step, repeat=1
             ),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(cfg.profiler_dir),
             profile_memory=True,
             with_stack=False,
             with_flops=True,
@@ -59,15 +63,32 @@ def profile(cfg, local_rank=None):
             yield torch_profiler
     elif use_flop_counter:
         if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start:
-            raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
-        with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter:
+            raise ValueError(
+                f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}"
+            )
+        with FlopMeasure(
+            rank=local_rank, warmup_step=cfg.flop_counter_start
+        ) as flop_counter:
             yield flop_counter
     else:
         torch_profiler = contextlib.nullcontext()
         yield None
 
 
-def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
+def train(
+    model,
+    train_dataloader,
+    eval_dataloader,
+    tokenizer,
+    optimizer,
+    lr_scheduler,
+    gradient_accumulation_steps,
+    train_config,
+    fsdp_config=None,
+    local_rank=None,
+    rank=None,
+    wandb_run=None,
+):
     """
     Trains the model on the given dataloader
 
@@ -93,13 +114,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"])
 
-
-
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
     train_prep = []
     train_loss = []
     val_prep = []
-    val_loss =[]
+    val_loss = []
 
     if train_config.save_metrics:
         if not os.path.exists(train_config.output_dir):
@@ -127,45 +146,70 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
-            total_length = len(train_dataloader)//gradient_accumulation_steps
-            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-            with profile(train_config,local_rank) as profile_context:
+            total_length = len(train_dataloader) // gradient_accumulation_steps
+            pbar = tqdm(
+                colour="blue",
+                desc=f"Training Epoch: {epoch+1}",
+                total=total_length,
+                dynamic_ncols=True,
+            )
+            with profile(train_config, local_rank) as profile_context:
                 for step, batch in enumerate(train_dataloader):
                     total_train_steps += 1
                     # stop when the maximum number of training steps is reached
-                    if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+                    if (
+                        train_config.max_train_step > 0
+                        and total_train_steps > train_config.max_train_step
+                    ):
                         max_steps_reached = True
-                        if not train_config.enable_fsdp or local_rank==0:
-                            print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
+                        if not train_config.enable_fsdp or local_rank == 0:
+                            print(
+                                "max training steps reached, stopping training, total train steps finished: ",
+                                total_train_steps - 1,
+                            )
                         break
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             if is_xpu_available():
-                                batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+                                batch[key] = batch[key].to(
+                                    torch.device(f"xpu:{local_rank}")
+                                )
                             else:
                                 batch[key] = batch[key].to(local_rank)
                         else:
                             if is_xpu_available():
-                                batch[key] = batch[key].to('xpu:0')
+                                batch[key] = batch[key].to("xpu:0")
                             elif torch.cuda.is_available():
-                                batch[key] = batch[key].to('cuda:0')
+                                batch[key] = batch[key].to("cuda:0")
                     with autocast():
                         loss = model(**batch).loss
                     total_loss += loss.detach().float()
                     loss = loss / gradient_accumulation_steps
                     if train_config.save_metrics:
                         train_step_loss.append(loss.detach().float().item())
-                        train_step_perplexity.append(float(torch.exp(loss.detach().float())))
+                        train_step_perplexity.append(
+                            float(torch.exp(loss.detach().float()))
+                        )
                     if train_config.use_fp16:
                         # if fp16 is enabled, use gradient scaler to handle gradient update
                         scaler.scale(loss).backward()
-                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(
+                            train_dataloader
+                        ) - 1:
+                            if (
+                                train_config.gradient_clipping
+                                and train_config.gradient_clipping_threshold > 0.0
+                            ):
                                 scaler.unscale_(optimizer)
                                 if train_config.enable_fsdp:
-                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                    model.clip_grad_norm_(
+                                        train_config.gradient_clipping_threshold
+                                    )
                                 else:
-                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                                    torch.nn.utils.clip_grad_norm_(
+                                        model.parameters(),
+                                        train_config.gradient_clipping_threshold,
+                                    )
                             scaler.step(optimizer)
                             scaler.update()
                             optimizer.zero_grad()
@@ -173,12 +217,22 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     else:
                         # regular backpropagation when fp16 is not used
                         loss.backward()
-                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(
+                            train_dataloader
+                        ) - 1:
+                            if (
+                                train_config.gradient_clipping
+                                and train_config.gradient_clipping_threshold > 0.0
+                            ):
                                 if train_config.enable_fsdp:
-                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                    model.clip_grad_norm_(
+                                        train_config.gradient_clipping_threshold
+                                    )
                                 else:
-                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                                    torch.nn.utils.clip_grad_norm_(
+                                        model.parameters(),
+                                        train_config.gradient_clipping_threshold,
+                                    )
                             optimizer.step()
                             optimizer.zero_grad()
                             pbar.update(1)
@@ -187,96 +241,71 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.flop_counter and profile_context.is_done():
                         TFlops = profile_context.get_flops_per_sec() / 1e12
                     if wandb_run:
-                        if not train_config.enable_fsdp or rank==0:
-                            wandb_run.log({
-                                'train/epoch': epoch + 1,
-                                'train/step': epoch * len(train_dataloader) + step,
-                                'train/loss': loss.detach().float(),
-                            })
-
-                    pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+                        if not train_config.enable_fsdp or rank == 0:
+                            wandb_run.log(
+                                {
+                                    "train/epoch": epoch + 1,
+                                    "train/step": epoch * len(train_dataloader) + step,
+                                    "train/loss": loss.detach().float(),
+                                }
+                            )
+
+                    pbar.set_description(
+                        f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})"
+                    )
 
                     if train_config.save_metrics:
-                        save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+                        save_to_json(
+                            metrics_filename,
+                            train_step_loss,
+                            train_loss,
+                            train_step_perplexity,
+                            train_prep,
+                            val_step_loss,
+                            val_loss,
+                            val_step_perplexity,
+                            val_prep,
+                        )
                 pbar.close()
 
-        epoch_end_time = time.perf_counter()-epoch_start_time
+        epoch_end_time = time.perf_counter() - epoch_start_time
         epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
-        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+        if is_xpu_available() and (
+            torch.xpu.device_count() > 1 and train_config.enable_fsdp
+        ):
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
         if train_config.enable_fsdp:
-            train_epoch_loss = train_epoch_loss/world_size
+            train_epoch_loss = train_epoch_loss / world_size
         train_perplexity = torch.exp(train_epoch_loss)
 
         train_prep.append(float(train_perplexity))
         train_loss.append(float(train_epoch_loss))
 
-        if not train_config.enable_fsdp or rank==0:
+        if not train_config.enable_fsdp or rank == 0:
             memtrace.print_stats()
 
         # Update the learning rate as needed
         lr_scheduler.step()
         should_save_model = train_config.save_model
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
+                model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run
+            )
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
-            should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss
-        
+            should_save_model = (
+                train_config.save_model and eval_epoch_loss < best_val_loss
+            )
+
         checkpoint_start_time = time.perf_counter()
         if should_save_model:
-            if train_config.enable_fsdp:
-                dist.barrier()
-            if train_config.use_peft:
-                if train_config.enable_fsdp:
-                    if rank==0:
-                        print(f"we are about to save the PEFT modules")
-                else:
-                    print(f"we are about to save the PEFT modules")
-                save_peft_checkpoint(model, train_config.output_dir)
-                if train_config.enable_fsdp:
-                    if rank==0:
-                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                else:
-                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
-
-            else:
-                if not train_config.enable_fsdp:
-                    save_model_checkpoint(model, train_config.output_dir)
-                    
-                elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                    print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
-                    print("=====================================================")
-                    save_fsdp_model_checkpoint_full(
-                        model, optimizer, rank, train_config, epoch=epoch
-                    )
-                    
-                    if train_config.save_optimizer:
-                        print(" Saving the FSDP optimizer using FULL_STATE_DICT")
-                        print("=====================================================")
-                        save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=epoch
-                        )
-                    
-                elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
+            save_checkpoint(model, optimizer, train_config, fsdp_config, epoch)
 
-                    if train_config.save_optimizer:
-                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
-                        print("=====================================================")
-                        save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
-                    else:
-                        print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
-                        print("=====================================================")
-                        save_model_and_optimizer_sharded(model, rank, train_config)
-
-                    
-            if train_config.enable_fsdp:
-                dist.barrier()
         checkpoint_end_time = time.perf_counter() - checkpoint_start_time
         checkpoint_times.append(checkpoint_end_time)
 
@@ -284,48 +313,67 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
-                    if rank==0:
+                    if rank == 0:
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
-                        print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
+                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
             val_loss.append(float(eval_epoch_loss))
             val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
-            if rank==0:
-                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+            if rank == 0:
+                print(
+                    f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+                )
         else:
-            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+            print(
+                f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+            )
 
         # Saving the results every epoch to plot later
         if train_config.save_metrics:
-            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+            save_to_json(
+                metrics_filename,
+                train_step_loss,
+                train_loss,
+                train_step_perplexity,
+                train_prep,
+                val_step_loss,
+                val_loss,
+                val_step_perplexity,
+                val_prep,
+            )
 
-    avg_epoch_time = sum(epoch_times)/ len(epoch_times)
-    avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
-    avg_train_prep = sum(train_prep)/len(train_prep)
-    avg_train_loss = sum(train_loss)/len(train_loss)
+    avg_epoch_time = sum(epoch_times) / len(epoch_times)
+    avg_checkpoint_time = (
+        sum(checkpoint_times) / len(checkpoint_times)
+        if len(checkpoint_times) > 0
+        else 0
+    )
+    avg_train_prep = sum(train_prep) / len(train_prep)
+    avg_train_loss = sum(train_loss) / len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep)/len(val_prep)
-        avg_eval_loss = sum(val_loss)/len(val_loss)
+        avg_eval_prep = sum(val_prep) / len(val_prep)
+        avg_eval_loss = sum(val_loss) / len(val_loss)
 
-    results['avg_train_prep'] = avg_train_prep
-    results['avg_train_loss'] = avg_train_loss
+    results["avg_train_prep"] = avg_train_prep
+    results["avg_train_loss"] = avg_train_loss
     if train_config.run_validation:
-        results['avg_eval_prep'] = avg_eval_prep
-        results['avg_eval_loss'] = avg_eval_loss
+        results["avg_eval_prep"] = avg_eval_prep
+        results["avg_eval_loss"] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
     if train_config.save_metrics:
         results["metrics_filename"] = metrics_filename
     if train_config.flop_counter:
-        results["model_tflops"]= TFlops
-    #saving the training params including fsdp setting for reference.
-    if train_config.enable_fsdp and not train_config.use_peft and rank==0:
+        results["model_tflops"] = TFlops
+    # saving the training params including fsdp setting for reference.
+    if train_config.enable_fsdp and not train_config.use_peft and rank == 0:
         save_train_params(train_config, fsdp_config, rank)
 
     return results
 
-def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
+
+def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
     """
     Evaluates the model on the given dataloader
 
@@ -346,21 +394,34 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     eval_loss = 0.0  # Initialize evaluation loss
     total_eval_steps = 0
     with MemoryTrace() as memtrace:
-        for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+        for step, batch in enumerate(
+            tqdm(
+                eval_dataloader,
+                colour="green",
+                desc="evaluating Epoch",
+                dynamic_ncols=True,
+            )
+        ):
             total_eval_steps += 1
             # stop when the maximum number of eval steps is reached
-            if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
-                if not train_config.enable_fsdp or local_rank==0:
-                    print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1)
+            if (
+                train_config.max_eval_step > 0
+                and total_eval_steps > train_config.max_eval_step
+            ):
+                if not train_config.enable_fsdp or local_rank == 0:
+                    print(
+                        "max eval steps reached, stopping evaluation, total_eval_steps: ",
+                        total_eval_steps - 1,
+                    )
                 break
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
                     if is_xpu_available():
-                        batch[key] = batch[key].to('xpu:0')
+                        batch[key] = batch[key].to("xpu:0")
                     else:
-                        batch[key] = batch[key].to('cuda:0')
+                        batch[key] = batch[key].to("cuda:0")
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
@@ -374,11 +435,15 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
             eval_preds.extend(
-                tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
+                tokenizer.batch_decode(
+                    preds.detach().cpu().numpy(), skip_special_tokens=True
+                )
             )
 
     # If there's more than one CUDA device, reduce evaluation loss across all devices
-    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+    if is_xpu_available() and (
+        torch.xpu.device_count() > 1 and train_config.enable_fsdp
+    ):
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
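Because each FSDP rank accumulates loss only over its own shard of the eval set, the scalar is summed across ranks here and then divided by world_size a few lines below to recover the global mean. A standalone illustration of that sum-then-average pattern (the helper name is illustrative, not from the repo):

import torch
import torch.distributed as dist


def global_mean(local_value: torch.Tensor) -> torch.Tensor:
    """Average a per-rank scalar across all ranks (illustrative only)."""
    if dist.is_available() and dist.is_initialized():
        value = local_value.clone()
        dist.all_reduce(value, op=dist.ReduceOp.SUM)  # sum the per-rank values
        return value / dist.get_world_size()          # then divide by the rank count
    return local_value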
@@ -386,35 +451,39 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     if train_config.enable_fsdp:
-        eval_epoch_loss = eval_epoch_loss/world_size
+        eval_epoch_loss = eval_epoch_loss / world_size
     eval_ppl = torch.exp(eval_epoch_loss)
 
     # Print evaluation metrics
     if train_config.enable_fsdp:
-        if local_rank==0:
+        if local_rank == 0:
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
 
     if wandb_run:
-        wandb_run.log({
-                        'eval/perplexity': eval_ppl,
-                        'eval/loss': eval_epoch_loss,
-                    }, commit=False)
+        wandb_run.log(
+            {
+                "eval/perplexity": eval_ppl,
+                "eval/loss": eval_epoch_loss,
+            },
+            commit=False,
+        )
 
     return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
+
 def freeze_transformer_layers(model, num_layer):
-   for i, layer in enumerate(model.model.layers):
-            if i < num_layer:
-                for param in layer.parameters():
-                    param.requires_grad = False
+    for i, layer in enumerate(model.model.layers):
+        if i < num_layer:
+            for param in layer.parameters():
+                param.requires_grad = False
 
 
 def check_frozen_layers_peft_model(model):
-     for i, layer in enumerate(model.base_model.model.model.layers):
-            for name, param in layer.named_parameters():
-                print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
+    for i, layer in enumerate(model.base_model.model.model.layers):
+        for name, param in layer.named_parameters():
+            print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
 
 
 def setup():
@@ -460,58 +529,6 @@ def get_parameter_dtypes(model):
         parameter_dtypes[name] = parameter.dtype
     return parameter_dtypes
 
-def print_model_size(model, config, rank: int = 0) -> None:
-    """
-    Print model name, the number of trainable parameters and initialization time.
-
-    Args:
-        model: The PyTorch model.
-        model_name (str): Name of the model.
-        init_time_start (float): Initialization start time.
-        init_time_end (float): Initialization end time.
-        rank (int, optional): Current process's rank. Defaults to 0.
-    """
-    if rank == 0:
-        print(f"--> Model {config.model_name}")
-        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
-
-
-
-
-def get_policies(cfg, rank):
-    """Get the policies for mixed precision and fsdp wrapping"""
-
-
-    verify_bfloat_support = ((
-    torch.version.cuda
-    and torch.cuda.is_bf16_supported()
-    and torch.version.cuda >= "11.0"
-    and dist.is_nccl_available()
-    and nccl.version() >= (2, 10)
-    ) or
-    (is_xpu_available()))
-
-
-    mixed_precision_policy = None
-    wrapping_policy = None
-
-    # Mixed precision
-    if cfg.mixed_precision:
-        bf16_ready = verify_bfloat_support
-
-        if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen
-            if rank == 0:
-                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
-        elif cfg.use_fp16:
-            mixed_precision_policy = fpSixteen
-            if rank == 0:
-                print(f"FP16 enabled")
-        else:
-            print(f"bFloat16 support not present. Using FP32, and not mixed precision")
-    wrapping_policy = get_llama_wrapper()
-    return mixed_precision_policy, wrapping_policy
 
 def save_train_params(train_config, fsdp_config, rank):
     """
@@ -521,17 +538,21 @@ def save_train_params(train_config, fsdp_config, rank):
     """
     # Convert the train_config and fsdp_config objects to dictionaries,
     # converting all values to strings to ensure they can be serialized into a YAML file
-    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
-    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
+    train_config_dict = {
+        k: str(v) for k, v in vars(train_config).items() if not k.startswith("__")
+    }
+    fsdp_config_dict = {
+        k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith("__")
+    }
     # Merge the two dictionaries into one
     train_params_dict = {**train_config_dict, **fsdp_config_dict}
     # Construct the folder name (following FSDP checkpointing style) using properties of the train_config object
     folder_name = (
-    train_config.dist_checkpoint_root_folder
-    + "/"
-    + train_config.dist_checkpoint_folder
-    + "-"
-    + train_config.model_name
+        train_config.dist_checkpoint_root_folder
+        + "/"
+        + train_config.dist_checkpoint_folder
+        + "-"
+        + train_config.model_name
     )
 
     save_dir = Path.cwd() / folder_name
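The string concatenation above nests the output folder under the current working directory and, because model names contain a slash, under a vendor subdirectory as well. With purely illustrative values (not the configured defaults):

# Illustrative values only -- not the repo defaults.
dist_checkpoint_root_folder = "model_checkpoints"
dist_checkpoint_folder = "fine-tuned"
model_name = "meta-llama/Llama-3.2-11B-Vision"

folder_name = (
    dist_checkpoint_root_folder + "/" + dist_checkpoint_folder + "-" + model_name
)
# -> "model_checkpoints/fine-tuned-meta-llama/Llama-3.2-11B-Vision"
# train_params.yaml is then written inside that directory, next to the FSDP checkpoints.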
@@ -540,19 +561,30 @@ def save_train_params(train_config, fsdp_config, rank):
         os.makedirs(save_dir)
     # Convert the dictionary to a YAML string
     config_yaml = yaml.dump(train_params_dict, indent=4)
-    file_name = os.path.join(save_dir,'train_params.yaml')
+    file_name = os.path.join(save_dir, "train_params.yaml")
 
     # Check if there's a directory with the same name as the file
     if os.path.isdir(file_name):
         print(f"Error: {file_name} is a directory, not a file.")
     else:
         # Write the YAML string to the file
-        with open(file_name, 'w') as f:
+        with open(file_name, "w") as f:
             f.write(config_yaml)
-        if rank==0:
+        if rank == 0:
             print(f"training params are saved in {file_name}")
 
-def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
+
+def save_to_json(
+    output_filename,
+    train_step_loss,
+    train_epoch_loss,
+    train_step_ppl,
+    train_epoch_ppl,
+    val_step_loss,
+    val_epoch_loss,
+    val_step_ppl,
+    val_epoch_ppl,
+):
     metrics_data = {
         "train_step_loss": train_step_loss,
         "train_epoch_loss": train_epoch_loss,
@@ -561,7 +593,7 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
         "val_step_loss": val_step_loss,
         "val_epoch_loss": val_epoch_loss,
         "val_step_perplexity": val_step_ppl,
-        "val_epoch_perplexity": val_epoch_ppl
+        "val_epoch_perplexity": val_epoch_ppl,
     }
     with open(output_filename, "w") as f:
         json.dump(metrics_data, f)
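For reference, the metrics file written here is plain JSON keyed exactly as in metrics_data, so it can be consumed later for plotting; a minimal sketch (matplotlib and the filename are illustrative assumptions, not part of this commit):

import json

import matplotlib.pyplot as plt  # plotting library choice is an assumption

with open("metrics.json") as f:  # hypothetical filename passed as output_filename
    metrics = json.load(f)

plt.plot(metrics["train_epoch_loss"], label="train loss")
if metrics["val_epoch_loss"]:
    plt.plot(metrics["val_epoch_loss"], label="val loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.savefig("loss_curves.png")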