
[WIP] enabled fsdpv2 checkpointing, peft+full-vision enabled for fsdpv2, refactoring

Matthias Reso 6 months ago
parent
commit
6d5b221f44

+ 67 - 118
src/llama_recipes/finetuning.py

@@ -14,14 +14,14 @@ import torch.optim as optim
 from accelerate.utils import is_xpu_available
 
 from llama_recipes.configs import (
-    fsdp_config as FSDP_CONFIG,
-    quantization_config as QUANTIZATION_CONFIG,
-    train_config as TRAIN_CONFIG,
+    fsdp_config as FsdpConfig,
+    quantization_config as QuantizationConfig,
+    train_config as TrainConfig,
 )
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
-from llama_recipes.utils import fsdp_auto_wrap_policy
+from llama_recipes.utils import get_model_and_data_processor
 from llama_recipes.utils.config_utils import (
     check_fsdp_config,
     generate_dataset_config,
@@ -38,8 +38,6 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
-    get_policies,
-    print_model_size,
     setup,
     setup_environ_flags,
     train,
@@ -53,8 +51,6 @@ from transformers import (
     AutoProcessor,
     AutoTokenizer,
     BitsAndBytesConfig,
-    LlamaForCausalLM,
-    MllamaForConditionalGeneration,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.models.mllama.modeling_mllama import (
@@ -72,9 +68,9 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
             "You are trying to use wandb which is not currently installed. "
             "Please install it using pip install wandb"
         )
-    from llama_recipes.configs import wandb_config as WANDB_CONFIG
+    from llama_recipes.configs import wandb_config as WandBConfig
 
-    wandb_config = WANDB_CONFIG()
+    wandb_config = WandBConfig()
     update_config(wandb_config, **kwargs)
     init_dict = dataclasses.asdict(wandb_config)
     run = wandb.init(**init_dict)
@@ -85,7 +81,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
-    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
+    train_config, fsdp_config = TrainConfig(), FsdpConfig()
     update_config((train_config, fsdp_config), **kwargs)
     # Set the seeds for reproducibility
     if is_xpu_available():
@@ -116,7 +112,7 @@ def main(**kwargs):
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
     # setting quantization configs
-    bnb_config = None
+    quant_config = None
     if train_config.quantization:
         if type(train_config.quantization) == type(True):
             warn(
@@ -130,70 +126,15 @@ def main(**kwargs):
                 "8bit quantization is not supported with FSDP, please use 4bit quantization"
             )
 
-        quant_config = QUANTIZATION_CONFIG()
+        quant_config = QuantizationConfig()
         update_config(quant_config, **kwargs)
-        bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
     # Load the pre-trained model and setup its configuration
-    use_cache = False if train_config.enable_fsdp else None
-    config = AutoConfig.from_pretrained(train_config.model_name)
-    if config.model_type == "mllama":
-        is_vision = True
-        model = MllamaForConditionalGeneration.from_pretrained(
-            train_config.model_name,
-            quantization_config=bnb_config,
-            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-            device_map=(
-                "auto"
-                if train_config.quantization and not train_config.enable_fsdp
-                else None
-            ),
-            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
-        )
-        processor = AutoProcessor.from_pretrained(
-            train_config.model_name
-            if train_config.tokenizer_name is None
-            else train_config.tokenizer_name
-        )
-        processor.tokenizer.padding_side = "right"
-        model.supports_gradient_checkpointing = True
-        model.language_model.supports_gradient_checkpointing = True
-    elif config.model_type == "llama":
-        is_vision = False
-        model = LlamaForCausalLM.from_pretrained(
-            train_config.model_name,
-            quantization_config=bnb_config,
-            use_cache=use_cache,
-            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-            device_map=(
-                "auto"
-                if train_config.quantization and not train_config.enable_fsdp
-                else None
-            ),
-            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
-        )
+    model, dataset_processer, is_vision = get_model_and_data_processor(train_config, quant_config)
+    if is_vision:
+        tokenizer = dataset_processer.tokenizer
     else:
-        raise ValueError(
-            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
-        )
-    # Load the tokenizer and add special tokens
-    tokenizer = AutoTokenizer.from_pretrained(
-        train_config.model_name
-        if train_config.tokenizer_name is None
-        else train_config.tokenizer_name
-    )
-    if not tokenizer.pad_token_id:
-        tokenizer.pad_token_id = tokenizer.eos_token_id
-
-    # If there is a mismatch between tokenizer vocab size and embedding matrix,
-    # throw a warning and then expand the embedding matrix
-    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print(
-            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
-        )
-        model.resize_token_embeddings(len(tokenizer))
-
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
+        tokenizer = dataset_processer
 
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
@@ -235,71 +176,79 @@ def main(**kwargs):
 
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
+            
+        device_id = 0
+        if is_xpu_available():
+            device_id = torch.xpu.current_device()
+        elif torch.cuda.is_available():
+            device_id = torch.cuda.current_device()
+        from llama_recipes.utils.fsdp_utils import parallelize_model
+
+        # model = FSDP(
+        #     
+        #     cpu_offload=(
+        #         CPUOffload(offload_params=True)
+        #         if fsdp_config.fsdp_cpu_offload
+        #         else None
+        #     ),
+        #     mixed_precision=(
+        #         mixed_precision_policy if not fsdp_config.pure_bf16 else None
+        #     ),
+        #     sharding_strategy=fsdp_config.sharding_strategy,
+        #     device_mesh=hsdp_device_mesh_plan,
+        #     device_id=device_id,
+        #     limit_all_gathers=True,
+        #     sync_module_states=train_config.low_cpu_fsdp,
+        #     param_init_fn=(
+        #         (
+        #             lambda module: module.to_empty(
+        #                 device=torch.device("cuda"), recurse=False
+        #             )
+        #         )
+        #         if train_config.low_cpu_fsdp and rank != 0
+        #         else None
+        #     ),
+        # )
 
-        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
         if is_vision:
-            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
-                model,
-                [
+            MODS = (
                     MllamaSelfAttentionDecoderLayer,
                     MllamaSelfAttentionDecoderLayer,
                     MllamaVisionEncoderLayer,
-                ],
             )
+            sharding_conditions = [
+                lambda m: any(isinstance(m,n) for n in MODS),
+            ]
         else:
-            # Create the FSDP wrapper for LlamaDecoderLayer in text models
-            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
-        device_id = 0
-        if is_xpu_available():
-            device_id = torch.xpu.current_device()
-        elif torch.cuda.is_available():
-            device_id = torch.cuda.current_device()
-        model = FSDP(
-            model,
-            auto_wrap_policy=(
-                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
-            ),
-            cpu_offload=(
-                CPUOffload(offload_params=True)
-                if fsdp_config.fsdp_cpu_offload
-                else None
-            ),
-            mixed_precision=(
-                mixed_precision_policy if not fsdp_config.pure_bf16 else None
-            ),
-            sharding_strategy=fsdp_config.sharding_strategy,
-            device_mesh=hsdp_device_mesh_plan,
-            device_id=device_id,
-            limit_all_gathers=True,
-            sync_module_states=train_config.low_cpu_fsdp,
-            param_init_fn=(
-                (
-                    lambda module: module.to_empty(
-                        device=torch.device("cuda"), recurse=False
-                    )
+            sharding_conditions = [lambda m: isinstance(m, LlamaDecoderLayer)]
+
+        if train_config.use_peft:
+            sharding_conditions += [
+                lambda m: (
+                    len(list(m.named_children())) == 0
+                    and getattr(m, "weight", None) is not None
+                    and m.weight.requires_grad
                 )
-                if train_config.low_cpu_fsdp and rank != 0
-                else None
-            ),
+            ]
+
+        parallelize_model(
+            model,
+            fsdp_config,
+            device_mesh = hsdp_device_mesh_plan,
+            sharding_conditions = sharding_conditions,
         )
+        
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
             model.gradient_checkpointing_enable()
-            apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
         if is_xpu_available():
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
     dataset_config = generate_dataset_config(train_config, kwargs)
-    if is_vision:
-        dataset_processer = processor
-    else:
-        dataset_processer = tokenizer
-
+    
     # Load and preprocess the dataset for training and validation
-
     dataset_train = get_preprocessed_dataset(
         dataset_processer,
         dataset_config,

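For orientation, a minimal sketch of the refactored flow in main() after this change. It assumes a distributed launch (e.g. torchrun) with the process group already initialized; the text-only (non-vision) path and the default configs here are illustrative assumptions, not part of the diff.

    # Sketch only: model loading and FSDP2 sharding as wired up in the new main().
    from llama_recipes.configs import fsdp_config as FsdpConfig, train_config as TrainConfig
    from llama_recipes.utils import get_model_and_data_processor
    from llama_recipes.utils.fsdp_utils import parallelize_model
    from transformers.models.llama.modeling_llama import LlamaDecoderLayer

    train_config, fsdp_config = TrainConfig(), FsdpConfig()
    model, dataset_processer, is_vision = get_model_and_data_processor(train_config, None)
    tokenizer = dataset_processer.tokenizer if is_vision else dataset_processer

    # Shard each decoder layer; with use_peft the diff also adds a condition that
    # matches trainable leaf modules so LoRA weights get their own shard group.
    sharding_conditions = [lambda m: isinstance(m, LlamaDecoderLayer)]
    parallelize_model(model, fsdp_config, device_mesh=None, sharding_conditions=sharding_conditions)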
+ 4 - 4
src/llama_recipes/model_checkpointing/__init__.py

@@ -3,12 +3,12 @@
 
 from llama_recipes.model_checkpointing.checkpoint_handler import (
     load_model_checkpoint,
-    save_fsdp_model_checkpoint_full,
+    save_fsdp_checkpoint_full,
+    save_fsdp_checkpoint_sharded,
     save_peft_checkpoint,
     save_model_checkpoint,
+    save_checkpoint,
     load_optimizer_checkpoint,
-    save_optimizer_checkpoint,
-    save_model_and_optimizer_sharded,
-    load_model_sharded,
+    load_fsdp_checkpoint_sharded,
     load_sharded_model_single_gpu
 )

+ 128 - 103
src/llama_recipes/model_checkpointing/checkpoint_handler.py

@@ -8,6 +8,9 @@ from pathlib import Path
 import torch
 import torch.distributed as dist
 
+from torch.distributed.checkpoint.state_dict import get_state_dict, StateDictOptions
+from torch.distributed.checkpoint.state_dict_saver import save
+from torch.distributed.checkpoint.state_dict_loader import load
 from torch.distributed.checkpoint import (
     FileSystemReader,
     FileSystemWriter,
@@ -47,118 +50,128 @@ def get_date_of_run():
 fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
 
 
-def load_model_sharded(model, rank, cfg):
-    # torch.manual_seed(103)
-    folder_name = (
-        cfg.dist_checkpoint_root_folder
-        + "/"
-        + cfg.dist_checkpoint_folder
-        + "-"
-        + cfg.model_name
-    )
+def load_fsdp_checkpoint_sharded(model, cfg, epoch=1, optimizer=None):
+    rank = dist.get_rank()
+    folder_name = "-".join((cfg.dist_checkpoint_folder, cfg.model_name, str(epoch)))
 
-    load_dir = Path.cwd() / folder_name
+    load_dir = Path.cwd() / cfg.dist_checkpoint_root_folder / folder_name
 
     if not load_dir.exists():
         if rank == 0:
-            print(f"No sharded_state_dict checkpoint directory found...skipping")
+            print(f"No sharded_state_dict checkpoint directory at {load_dir.as_posix()} found...skipping")
         return
     if rank == 0:
-        print(f"loading model from model path: {load_dir} ")
+        print(f"loading model from model path: {load_dir.as_posix()} ")
     reader = FileSystemReader(load_dir)
 
-    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
-        checkpoint = {"model": model.state_dict()}
-        if rank == 0:
-            ck = checkpoint.keys()
-            print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
+    checkpoint = {"model": model}
+    if optimizer is not None:
+        checkpoint["optimizer"] = optimizer
+    if rank == 0:
+        ck = checkpoint.keys()
+        print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
+
+    load(
+        state_dict=checkpoint,
+        storage_reader=reader,
+    )
+    if rank == 0:
+        print(f"checkpoint after load_state_dict()")
+        ck = checkpoint.keys()
+        print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
+
+    model.load_state_dict(checkpoint["model"])
+    if optimizer is not None:
+        optimizer.load_state_dict(checkpoint["optimizer"])
 
-        load_state_dict(
-            state_dict=checkpoint,
-            storage_reader=reader,
-        )
-        if rank == 0:
-            print(f"checkpoint after load_state_dict()")
-            ck = checkpoint.keys()
-            print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
-        model.load_state_dict(checkpoint["model"])
     if rank == 0:
         print(f"Sharded state checkpoint loaded from {load_dir}")
 
 
-def save_model_and_optimizer_sharded(model, rank, cfg, optim=None):
+def save_fsdp_checkpoint_sharded(model, optimizer, train_config, epoch=1):
     """save model and optimizer via sharded_state_dict to save_dir"""
 
-    folder_name = (
-        cfg.dist_checkpoint_root_folder
-        + "/"
-        + cfg.dist_checkpoint_folder
-        + "-"
-        + cfg.model_name
-    )
+    folder_name = "-".join((train_config.dist_checkpoint_folder, train_config.model_name, str(epoch)))
+
+    save_dir = Path.cwd() / train_config.dist_checkpoint_root_folder / folder_name
+
+    rank = dist.get_rank()
 
-    save_dir = Path.cwd() / folder_name
     if rank == 0:
-        print(f"Saving model to {save_dir}")
+        print(f"Saving model to {save_dir.as_posix()}")
 
     distributed_writer = FileSystemWriter(
         save_dir,
+        overwrite=True,
     )
     t0 = time.perf_counter()
 
-    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+    options = StateDictOptions(
+        full_state_dict=False,
+    )
 
 
-        state_dict = {"model": model.state_dict()}
-        if optim is not None:
-            state_dict["optim"] = FSDP.optim_state_dict(model, optim)
+    optim = optimizer if train_config.save_optimizer else []
 
 
-        save_state_dict(
-            state_dict=state_dict,
-            storage_writer=distributed_writer,
-            planner=DefaultSavePlanner(),
-        )
+    state_dict = {"model": model}
+    if train_config.save_optimizer:
+        state_dict["optimizer"] = optimizer
+
+    save(
+        state_dict=state_dict,
+        storage_writer=distributed_writer,
+        planner=DefaultSavePlanner(),
+    )
     dist.barrier()
     t1 = time.perf_counter()
     if rank == 0:
-        print(f"Sharded state checkpoint saved to {save_dir}")
+        print(f"Sharded state checkpoint saved to {save_dir.as_posix()}")
         print(f"Checkpoint Time = {t1-t0:.4f}\n")
 
 
-def save_fsdp_model_checkpoint_full(
+def save_fsdp_checkpoint_full(
     model,
     optimizer,
-    rank,
-    cfg,
+    train_config,
     epoch=1,
 ):
     """saving model via rank0 cpu streaming and full_state_dict"""
 
-    with FSDP.state_dict_type(
-        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
-    ):
-        cpu_state = model.state_dict()
+    options = StateDictOptions(
+        full_state_dict=True,
+    )
 
-        print(f"saving process: rank {rank}  done w model state_dict\n")
+    optim = optimizer if train_config.save_optimizer else []
+
+    model_state, optim_state = get_state_dict(model, optim, options=options)
+
+    rank = dist.get_rank()
 
     if rank == 0:
         print(f"--> saving model ...")
         # create save path
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-        save_dir = Path.cwd() / folder_name
+        folder_name = "-".join((train_config.dist_checkpoint_folder, train_config.model_name))
+        save_dir = Path.cwd() / train_config.dist_checkpoint_root_folder / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
-        save_name = cfg.model_name.replace("/", "--") + "-" + str(epoch) + ".pt"
-        save_full_path = str(save_dir) + "/" + save_name
+
+        save_name = train_config.model_name.replace("/", "--") + "-" + str(epoch) + ".pt"
+        save_full_path = save_dir / save_name
 
         # save model
-        torch.save(cpu_state, save_full_path)
+        torch.save(model_state, save_full_path)
+
+        print(f"model checkpoint saved for epoch {epoch} at {save_full_path.as_posix()}\n")
+
+        if not train_config.save_optimizer:
+            return
+
+        opt_save_name = "optimizer" + "-" + train_config.model_name.replace("/", "--") + "-" + str(epoch) + ".pt"
+        opt_save_full_path = save_dir / opt_save_name
+
+        print(f"--> saving optimizer state...")
 
 
-        print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
+        torch.save(optim_state, opt_save_full_path)
+
+        print(f"--> saved {opt_save_full_path.as_posix()} to disk")
 
 
 def load_model_checkpoint(model, rank, cfg):
@@ -186,38 +199,6 @@ def load_model_checkpoint(model, rank, cfg):
     print(f"model checkpoint loaded to rank0 cpu")
 
 
-def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
-    """save optimizer state via full state dict"""
-
-    print(f"--> optim state call on rank {rank}\n")
-
-    # pull all sharded optimizer states to rank0 cpu...
-
-    optim_state = FSDP.full_optim_state_dict(model, optimizer)
-
-    print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
-
-    if rank == 0:
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-        save_dir = Path.cwd() / folder_name
-        save_dir.mkdir(parents=True, exist_ok=True)
-
-        opt_save_name = "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
-        opt_save_full_path = save_dir / opt_save_name
-
-        print(f"--> saving optimizer state...")
-
-        torch.save(optim_state, opt_save_full_path)
-
-        print(f"--> saved {opt_save_full_path} to disk")
-
-
 def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
     """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
@@ -258,14 +239,20 @@ def load_sharded_model_single_gpu(model, model_path):
     return model
 
 
-def save_peft_checkpoint(model, model_path):
+def save_peft_checkpoint(model, train_config):
     """save_pretrained peft model"""
+    if train_config.enable_fsdp:
+        options = StateDictOptions(
+            full_state_dict=True,
+            cpu_offload=True,
+        )
 
-    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+        model_state, _ = get_state_dict(model, [], options=options)
 
-    if isinstance(model, FSDP):
-        state_dict = get_model_state_dict(model, options=options)
-        model.save_pretrained(model_path, state_dict=state_dict)
+        rank = dist.get_rank()
+        if rank == 0:
+            model_path = train_config.output_dir
+            model.save_pretrained(model_path, state_dict=model_state)
     else:
         model.save_pretrained(model_path)
 
@@ -278,3 +265,41 @@ def save_model_checkpoint(model, output_dir):
     state_dict = model.state_dict()
 
     torch.save(state_dict, output_file)
+
+
+def save_checkpoint(model, optimizer, train_config, fsdp_config, epoch):
+    """save model and optimizer"""
+    rank = dist.get_rank() if train_config.enable_fsdp else 0
+
+    if train_config.enable_fsdp:
+        dist.barrier()
+    if train_config.use_peft:
+        if rank == 0:
+            print(f"we are about to save the PEFT modules")
+        save_peft_checkpoint(model, train_config)
+        
+        if rank == 0:
+            print(f"PEFT modules are saved in {train_config.output_dir} directory")
+
+    else:
+        if not train_config.enable_fsdp:
+            save_model_checkpoint(model, train_config.output_dir)
+
+        elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+            if rank == 0:
+                print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                print("=====================================================")
+            save_fsdp_checkpoint_full(
+                model, optimizer, train_config, epoch=epoch
+            )
+
+        elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
+            if rank == 0:
+                print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                print("=====================================================")
+            save_fsdp_checkpoint_sharded(
+                model, optimizer, train_config, epoch=epoch
+            )
+
+    if train_config.enable_fsdp:
+        dist.barrier()

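The rewritten handlers above lean on torch.distributed.checkpoint's state-dict utilities instead of the FSDP.state_dict_type context manager. A minimal sketch of that underlying API, separate from the recipes code; the checkpoint path and the presence of an optimizer are assumptions for illustration.

    # Sketch: save/reload a distributed checkpoint for an FSDP-wrapped model with DCP.
    # Assumes torch.distributed is already initialized (e.g. launched via torchrun).
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

    def save_sharded(model, optimizer, path="checkpoints/step-0"):
        # Gather FSDP-aware state dicts and write them with the default planner.
        model_state, optim_state = get_state_dict(model, optimizer)
        dcp.save({"model": model_state, "optimizer": optim_state}, checkpoint_id=path)

    def load_sharded(model, optimizer, path="checkpoints/step-0"):
        # Load into freshly gathered state dicts, then push them back onto the objects.
        model_state, optim_state = get_state_dict(model, optimizer)
        dcp.load({"model": model_state, "optimizer": optim_state}, checkpoint_id=path)
        set_state_dict(
            model,
            optimizer,
            model_state_dict=model_state,
            optim_state_dict=optim_state,
        )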
+ 41 - 14
src/llama_recipes/policies/mixed_precision.py

@@ -2,37 +2,64 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import torch
-
-from torch.distributed.fsdp import (
-    MixedPrecision,
-)
+import torch.cuda.nccl as nccl
+import torch.distributed as dist
+from torch.distributed._composable.fsdp import MixedPrecisionPolicy
+ 
 
 # requires grad scaler in main loop
-fpSixteen = MixedPrecision(
+fpSixteen = MixedPrecisionPolicy(
     param_dtype=torch.float16,
     # Gradient communication precision.
     reduce_dtype=torch.float16,
-    # Buffer precision.
-    buffer_dtype=torch.float16,
 )
 
-bfSixteen = MixedPrecision(
+bfSixteen = MixedPrecisionPolicy(
     param_dtype=torch.bfloat16,
     # Gradient communication precision.
     reduce_dtype=torch.bfloat16,
-    # Buffer precision.
-    buffer_dtype=torch.bfloat16,
     cast_forward_inputs=True,
 )
 
-bfSixteen_mixed = MixedPrecision(
+bfSixteen_mixed = MixedPrecisionPolicy(
     param_dtype=torch.float32,
     reduce_dtype=torch.bfloat16,
-    buffer_dtype=torch.bfloat16,
 )
 
-fp32_policy = MixedPrecision(
+fp32_policy = MixedPrecisionPolicy(
     param_dtype=torch.float32,
     reduce_dtype=torch.float32,
-    buffer_dtype=torch.float32,
 )
+
+
+def get_mixed_precision_policies(cfg):
+    """Get the policies for mixed precision and fsdp wrapping"""
+
+    rank = dist.get_rank()
+
+    verify_bfloat_support = (
+        torch.version.cuda
+        and torch.cuda.is_bf16_supported()
+        and torch.version.cuda >= "11.0"
+        and dist.is_nccl_available()
+        and nccl.version() >= (2, 10)
+    ) or (is_xpu_available())
+
+    mixed_precision_policy = None
+
+    # Mixed precision
+    if cfg.mixed_precision:
+        bf16_ready = verify_bfloat_support
+
+        if bf16_ready and not cfg.use_fp16:
+            mixed_precision_policy = bfSixteen
+            if rank == 0:
+                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
+        elif cfg.use_fp16:
+            mixed_precision_policy = fpSixteen
+            if rank == 0:
+                print(f"FP16 enabled")
+        else:
+            if rank == 0:
+                print(f"bFloat16 support not present. Using FP32, and not mixed precision")
+    return mixed_precision_policy

+ 3 - 2
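These policies now target the composable FSDP2 API, so they are handed to fully_shard (as parallelize_model does) rather than to the FSDP wrapper class. A small usage sketch under the assumption that a process group is already initialized; the toy model is a placeholder. Note that get_mixed_precision_policies above also relies on is_xpu_available from accelerate.utils.

    # Sketch: applying an FSDP2 MixedPrecisionPolicy via fully_shard.
    # `model` is a toy placeholder; real usage shards transformer blocks.
    import torch
    import torch.nn as nn
    from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)

    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
    for block in model:
        fully_shard(block, mp_policy=mp_policy)  # shard inner blocks first
    fully_shard(model, mp_policy=mp_policy)      # then the root module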
src/llama_recipes/utils/__init__.py

@@ -1,7 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from llama_recipes.utils.model_utils import get_model_and_data_processor
 from llama_recipes.utils.memory_utils import MemoryTrace
 from llama_recipes.utils.dataset_utils import *
-from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy, hsdp_device_mesh
-from llama_recipes.utils.train_utils import *
+from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
+from llama_recipes.utils.train_utils import *

+ 78 - 31
src/llama_recipes/utils/fsdp_utils.py

@@ -1,30 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-from torch.distributed._tensor.device_mesh import init_device_mesh
-import os 
+import os
 
-def fsdp_auto_wrap_policy(model, transformer_layer_names):
-    import functools
-
-    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
-
-    def lambda_policy_fn(module):
-        if (
-            len(list(module.named_children())) == 0
-            and getattr(module, "weight", None) is not None
-            and module.weight.requires_grad
-        ):
-            return True
-        return False
-
-    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
-    transformer_wrap_policy = functools.partial(
-        transformer_auto_wrap_policy,
-        transformer_layer_cls=set(transformer_layer_names)
-    )
-
-    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
-    return auto_wrap_policy
+import torch
+import torch.nn as nn
+from llama_recipes.configs.fsdp import fsdp_config as FSDP_CONFIG
+from llama_recipes.policies import get_mixed_precision_policies
+from torch.distributed._composable.fsdp import fully_shard, CPUOffloadPolicy
+from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
+from typing import List, Callable
 
 
 def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
@@ -33,11 +17,11 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
 
     This function requires explicit sizes for replica and sharding groups to accommodate models
     whose GPU fit is unknown, providing flexibility in distributed training setups.
-    
+
     Args:
         replica_group_size (int): The size of each replica group. Must be provided to ensure
             the model fits within the available resources.
-        sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to 
+        sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to
             ensure the correct distribution of model parameters.
         device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda"
             with the local rank as the device index.
@@ -59,7 +43,9 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
     """
 
     if replica_group_size is None or sharding_group_size is None:
-        raise ValueError("Both replica_group_size and sharding_group_size must be provided.")
+        raise ValueError(
+            "Both replica_group_size and sharding_group_size must be provided."
+        )
 
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -67,15 +53,76 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
     device = device or f"cuda"
 
     if world_size % sharding_group_size != 0:
-        raise ValueError(f"World size {world_size} is not evenly divisible by "
-                         f"sharding group size {sharding_group_size}.")
+        raise ValueError(
+            f"World size {world_size} is not evenly divisible by "
+            f"sharding group size {sharding_group_size}."
+        )
 
     if (world_size // sharding_group_size) % replica_group_size != 0:
-        raise ValueError(f"The calculated number of replica groups is not evenly divisible by "
-                         f"replica_group_size {replica_group_size}.")
+        raise ValueError(
+            f"The calculated number of replica groups is not evenly divisible by "
+            f"replica_group_size {replica_group_size}."
+        )
 
     device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size))
     if device_mesh is None:
         raise RuntimeError("Failed to create a valid device mesh.")
 
     return device_mesh
+
+
+def parallelize_model(
+    model: nn.Module,
+    fsdp_config: FSDP_CONFIG,
+    device_mesh: DeviceMesh = None,
+    sharding_conditions: List[Callable] = None,
+) -> nn.Module:
+    """
+    Parallelizes a Llama model using FSDP.
+
+    Args:
+        model (nn.Module): The Llama model to parallelize.
+        fsdp_config (FSDP_CONFIG): The FSDP configuration.
+        device_mesh (torch.device_mesh): The device mesh to use for parallelization.
+
+    Returns:
+        None
+    """
+
+    mp_policy = get_mixed_precision_policies(fsdp_config)
+    fsdp_config = {
+        "mesh": device_mesh,
+        "mp_policy": None if fsdp_config.pure_bf16 else mp_policy,
+        "offload_policy": CPUOffloadPolicy() if fsdp_config.fsdp_cpu_offload else None
+        }
+
+    # Following torchtune's approach to wrap Lora first as dtype is different from base
+    for m in reversed(list(model.modules())):
+        if any(c(m) for c in sharding_conditions):
+            fully_shard(m, reshard_after_forward=True)
+
+    # 
+    # if hasattr(model, "base_model") and hasattr(model.base_model, "model"):
+    #     for n, m in reversed(list(model.named_modules())):
+    #         if any(c(m) for c in sharding_conditions):
+    #         # if (
+    #         #     len(list(m.named_children())) == 0
+    #         #     and getattr(m, "weight", None) is not None
+    #         #     and m.weight.requires_grad
+    #         # ):
+    #             fully_shard(m, reshard_after_forward=True)
+    #     layers = model.base_model.model.model.layers
+    # else:
+    #     layers = model.model.layers
+
+    # for idx, layer in enumerate(layers):
+    #     # Following torch titan we will not reshard the last layer
+    #     # https://github.com/pytorch/torchtitan/blob/7310abea8782bbe459b662bc6d8411fe8d55f62c/torchtitan/parallelisms/parallelize_llama.py#L347
+    #     reshard_after_forward = idx < len(layers) - 1
+    #     fully_shard(
+    #         layer,
+    #         reshard_after_forward=reshard_after_forward,
+    #     )
+
+    # Shard remaining modules like embeddings
+    fully_shard(model, **fsdp_config)

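Together, hsdp_device_mesh and parallelize_model replace the old fsdp_auto_wrap_policy. A hedged sketch of how they can be combined; the 8-GPU layout and the pre-loaded `model` are assumptions, not part of the diff.

    # Sketch: 2 replica groups x 4 shard groups = 8 ranks, then per-layer FSDP2 sharding.
    from llama_recipes.configs import fsdp_config as FsdpConfig
    from llama_recipes.utils.fsdp_utils import hsdp_device_mesh, parallelize_model
    from transformers.models.llama.modeling_llama import LlamaDecoderLayer

    fsdp_config = FsdpConfig()
    mesh = hsdp_device_mesh(replica_group_size=2, sharding_group_size=4)

    parallelize_model(
        model,  # placeholder: a Llama model loaded elsewhere (e.g. via get_model_and_data_processor)
        fsdp_config,
        device_mesh=mesh,
        sharding_conditions=[lambda m: isinstance(m, LlamaDecoderLayer)],
    )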
+ 108 - 0
src/llama_recipes/utils/model_utils.py

@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from llama_recipes.configs import (
+    quantization_config as QuantizationConfig,
+    train_config as TrainConfig
+)
+from transformers import (
+    AutoConfig,
+    AutoProcessor,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    MllamaForConditionalGeneration,
+)
+
+
+def print_model_size(model: nn.Module, config: TrainConfig, rank: int = 0) -> None:
+    """
+    Print model name, the number of trainable parameters and initialization time.
+
+    Args:
+        model: The PyTorch model.
+        model_name (str): Name of the model.
+        init_time_start (float): Initialization start time.
+        init_time_end (float): Initialization end time.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        print(f"--> Model {config.model_name}")
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
+
+
+def get_model_and_data_processor(
+    train_config: TrainConfig, quant_config: QuantizationConfig
+):
+    bnb_config = None
+    if quant_config:
+        bnb_config = quant_config.create_bnb_config(train_config.quantization)
+
+    use_cache = False if train_config.enable_fsdp else None
+    config = AutoConfig.from_pretrained(train_config.model_name)
+    if config.model_type == "mllama":
+        is_vision = True
+        model = MllamaForConditionalGeneration.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map=(
+                "auto"
+                if train_config.quantization and not train_config.enable_fsdp
+                else None
+            ),
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+        processor = AutoProcessor.from_pretrained(
+            train_config.model_name
+            if train_config.tokenizer_name is None
+            else train_config.tokenizer_name
+        )
+        processor.tokenizer.padding_side = "right"
+        model.supports_gradient_checkpointing = True
+        model.language_model.supports_gradient_checkpointing = True
+    elif config.model_type == "llama":
+        is_vision = False
+        model = LlamaForCausalLM.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map=(
+                "auto"
+                if train_config.quantization and not train_config.enable_fsdp
+                else None
+            ),
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+
+        # Load the tokenizer and add special tokens
+        processor = AutoTokenizer.from_pretrained(
+            train_config.model_name
+            if train_config.tokenizer_name is None
+            else train_config.tokenizer_name
+        )
+        if not processor.pad_token_id:
+            processor.pad_token_id = processor.eos_token_id
+
+        # If there is a mismatch between tokenizer vocab size and embedding matrix,
+        # throw a warning and then expand the embedding matrix
+        if len(processor) > model.get_input_embeddings().weight.shape[0]:
+            print(
+                "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
+            )
+            model.resize_token_embeddings(len(processor))
+
+    else:
+        raise ValueError(
+            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
+        )
+
+    print_model_size(
+        model, train_config, dist.get_rank() if train_config.enable_fsdp else 0
+    )
+
+    return model, processor, is_vision

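A short usage sketch for the new helper; the checkpoint id and the bare TrainConfig defaults are assumptions for illustration.

    # Sketch: load a model plus its tokenizer/processor through the shared helper.
    from llama_recipes.configs import train_config as TrainConfig
    from llama_recipes.utils.model_utils import get_model_and_data_processor

    train_config = TrainConfig()
    train_config.model_name = "meta-llama/Llama-3.1-8B"  # assumption: any llama/mllama checkpoint id

    model, processor, is_vision = get_model_and_data_processor(train_config, quant_config=None)
    tokenizer = processor.tokenizer if is_vision else processor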
+ 239 - 207
src/llama_recipes/utils/train_utils.py

@@ -1,34 +1,34 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import contextlib
+import json
 import os
 import time
-import yaml
 from contextlib import nullcontext
-from pathlib import Path
 from datetime import datetime
-import contextlib
-
+from pathlib import Path
 
 import torch
-import torch.cuda.nccl as nccl
 import torch.distributed as dist
+import yaml
+from accelerate.utils import is_ccl_available, is_xpu_available
+
+from llama_recipes.model_checkpointing import save_checkpoint
+from llama_recipes.policies import bfSixteen, fpSixteen, get_llama_wrapper
+from llama_recipes.utils.flop_utils import FlopMeasure
+from llama_recipes.utils.memory_utils import MemoryTrace
 from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from tqdm import tqdm
 from transformers import LlamaTokenizer
-import json
 
 
-from llama_recipes.model_checkpointing import save_fsdp_model_checkpoint_full, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint, save_model_checkpoint
-from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
-from llama_recipes.utils.memory_utils import MemoryTrace
-from accelerate.utils import is_xpu_available, is_ccl_available
-from llama_recipes.utils.flop_utils import FlopMeasure
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
 
+
 @contextlib.contextmanager
 def profile(cfg, local_rank=None):
     use_profiler: bool = cfg.use_profiler
@@ -40,17 +40,21 @@ def profile(cfg, local_rank=None):
         wait_step, warmup_step, active_step = 1, 2, 3
         min_step = wait_step + warmup_step + active_step + 1
         if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
-            raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
-        print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
+            raise ValueError(
+                f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}"
+            )
+        print(
+            f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}"
+        )
         with torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 torch.profiler.ProfilerActivity.CUDA,
             ],
-            schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
-            on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                cfg.profiler_dir
+            schedule=torch.profiler.schedule(
+                wait=wait_step, warmup=warmup_step, active=active_step, repeat=1
             ),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(cfg.profiler_dir),
             profile_memory=True,
             with_stack=False,
             with_flops=True,
@@ -59,15 +63,32 @@ def profile(cfg, local_rank=None):
             yield torch_profiler
     elif use_flop_counter:
         if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start:
-            raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
-        with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter:
+            raise ValueError(
+                f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}"
+            )
+        with FlopMeasure(
+            rank=local_rank, warmup_step=cfg.flop_counter_start
+        ) as flop_counter:
             yield flop_counter
     else:
         torch_profiler = contextlib.nullcontext()
         yield None
 
 
-def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
+def train(
+    model,
+    train_dataloader,
+    eval_dataloader,
+    tokenizer,
+    optimizer,
+    lr_scheduler,
+    gradient_accumulation_steps,
+    train_config,
+    fsdp_config=None,
+    local_rank=None,
+    rank=None,
+    wandb_run=None,
+):
     """
     Trains the model on the given dataloader
 
@@ -93,13 +114,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"])
 
-
-
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
     train_prep = []
     train_loss = []
     val_prep = []
-    val_loss =[]
+    val_loss = []
 
     if train_config.save_metrics:
         if not os.path.exists(train_config.output_dir):
@@ -127,45 +146,70 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
-            total_length = len(train_dataloader)//gradient_accumulation_steps
-            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-            with profile(train_config,local_rank) as profile_context:
+            total_length = len(train_dataloader) // gradient_accumulation_steps
+            pbar = tqdm(
+                colour="blue",
+                desc=f"Training Epoch: {epoch+1}",
+                total=total_length,
+                dynamic_ncols=True,
+            )
+            with profile(train_config, local_rank) as profile_context:
                 for step, batch in enumerate(train_dataloader):
                     total_train_steps += 1
                     # stop when the maximum number of training steps is reached
-                    if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+                    if (
+                        train_config.max_train_step > 0
+                        and total_train_steps > train_config.max_train_step
+                    ):
                         max_steps_reached = True
-                        if not train_config.enable_fsdp or local_rank==0:
-                            print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
+                        if not train_config.enable_fsdp or local_rank == 0:
+                            print(
+                                "max training steps reached, stopping training, total train steps finished: ",
+                                total_train_steps - 1,
+                            )
                         break
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             if is_xpu_available():
-                                batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+                                batch[key] = batch[key].to(
+                                    torch.device(f"xpu:{local_rank}")
+                                )
                             else:
                                 batch[key] = batch[key].to(local_rank)
                         else:
                             if is_xpu_available():
-                                batch[key] = batch[key].to('xpu:0')
+                                batch[key] = batch[key].to("xpu:0")
                             elif torch.cuda.is_available():
-                                batch[key] = batch[key].to('cuda:0')
+                                batch[key] = batch[key].to("cuda:0")
                     with autocast():
                         loss = model(**batch).loss
                     total_loss += loss.detach().float()
                     loss = loss / gradient_accumulation_steps
                     if train_config.save_metrics:
                         train_step_loss.append(loss.detach().float().item())
-                        train_step_perplexity.append(float(torch.exp(loss.detach().float())))
+                        train_step_perplexity.append(
+                            float(torch.exp(loss.detach().float()))
+                        )
                     if train_config.use_fp16:
                         # if fp16 is enabled, use gradient scaler to handle gradient update
                         scaler.scale(loss).backward()
-                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(
+                            train_dataloader
+                        ) - 1:
+                            if (
+                                train_config.gradient_clipping
+                                and train_config.gradient_clipping_threshold > 0.0
+                            ):
                                 scaler.unscale_(optimizer)
                                 if train_config.enable_fsdp:
-                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                    model.clip_grad_norm_(
+                                        train_config.gradient_clipping_threshold
+                                    )
                                 else:
-                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                                    torch.nn.utils.clip_grad_norm_(
+                                        model.parameters(),
+                                        train_config.gradient_clipping_threshold,
+                                    )
                             scaler.step(optimizer)
                             scaler.step(optimizer)
                             scaler.update()
                             optimizer.zero_grad()
@@ -173,12 +217,22 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     else:
                         # regular backpropagation when fp16 is not used
                         loss.backward()
-                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(
+                            train_dataloader
+                        ) - 1:
+                            if (
+                                train_config.gradient_clipping
+                                and train_config.gradient_clipping_threshold > 0.0
+                            ):
                                 if train_config.enable_fsdp:
-                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                    model.clip_grad_norm_(
+                                        train_config.gradient_clipping_threshold
+                                    )
                                 else:
-                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                                    torch.nn.utils.clip_grad_norm_(
+                                        model.parameters(),
+                                        train_config.gradient_clipping_threshold,
+                                    )
                             optimizer.step()
                             optimizer.zero_grad()
                             pbar.update(1)
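The block above implements the standard gradient-accumulation pattern: the loss is divided by gradient_accumulation_steps, and the optimizer only steps (with optional gradient clipping, and a GradScaler when fp16 is enabled) every gradient_accumulation_steps micro-batches or on the final batch. A minimal standalone sketch of the same pattern, using illustrative names that are not taken from this repository:

import torch

def accumulation_step(model, batch, optimizer, step, total_steps,
                      accumulation_steps=4, clip_threshold=1.0):
    # Scale the loss so gradients summed over several micro-batches
    # average out to one effective batch.
    loss = model(**batch).loss / accumulation_steps
    loss.backward()
    # Step only on accumulation boundaries or on the final batch.
    if (step + 1) % accumulation_steps == 0 or step == total_steps - 1:
        if clip_threshold > 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_threshold)
        optimizer.step()
        optimizer.zero_grad()
    return loss.detach()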
@@ -187,96 +241,71 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.flop_counter and profile_context.is_done():
                         TFlops = profile_context.get_flops_per_sec() / 1e12
                     if wandb_run:
-                        if not train_config.enable_fsdp or rank==0:
-                            wandb_run.log({
-                                'train/epoch': epoch + 1,
-                                'train/step': epoch * len(train_dataloader) + step,
-                                'train/loss': loss.detach().float(),
-                            })
-
-                    pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+                        if not train_config.enable_fsdp or rank == 0:
+                            wandb_run.log(
+                                {
+                                    "train/epoch": epoch + 1,
+                                    "train/step": epoch * len(train_dataloader) + step,
+                                    "train/loss": loss.detach().float(),
+                                }
+                            )
+
+                    pbar.set_description(
+                        f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})"
+                    )
 
                     if train_config.save_metrics:
-                        save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+                        save_to_json(
+                            metrics_filename,
+                            train_step_loss,
+                            train_loss,
+                            train_step_perplexity,
+                            train_prep,
+                            val_step_loss,
+                            val_loss,
+                            val_step_perplexity,
+                            val_prep,
+                        )
                 pbar.close()
 
-        epoch_end_time = time.perf_counter()-epoch_start_time
+        epoch_end_time = time.perf_counter() - epoch_start_time
         epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
-        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+        if is_xpu_available() and (
+            torch.xpu.device_count() > 1 and train_config.enable_fsdp
+        ):
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
         if train_config.enable_fsdp:
-            train_epoch_loss = train_epoch_loss/world_size
+            train_epoch_loss = train_epoch_loss / world_size
         train_perplexity = torch.exp(train_epoch_loss)
 
         train_prep.append(float(train_perplexity))
         train_loss.append(float(train_epoch_loss))
 
-        if not train_config.enable_fsdp or rank==0:
+        if not train_config.enable_fsdp or rank == 0:
             memtrace.print_stats()
 
         # Update the learning rate as needed
         lr_scheduler.step()
         should_save_model = train_config.save_model
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
+                model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run
+            )
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
-            should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss
-        
+            should_save_model = (
+                train_config.save_model and eval_epoch_loss < best_val_loss
+            )
+
         checkpoint_start_time = time.perf_counter()
         if should_save_model:
-            if train_config.enable_fsdp:
-                dist.barrier()
-            if train_config.use_peft:
-                if train_config.enable_fsdp:
-                    if rank==0:
-                        print(f"we are about to save the PEFT modules")
-                else:
-                    print(f"we are about to save the PEFT modules")
-                save_peft_checkpoint(model, train_config.output_dir)
-                if train_config.enable_fsdp:
-                    if rank==0:
-                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                else:
-                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
-
-            else:
-                if not train_config.enable_fsdp:
-                    save_model_checkpoint(model, train_config.output_dir)
-                    
-                elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                    print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
-                    print("=====================================================")
-                    save_fsdp_model_checkpoint_full(
-                        model, optimizer, rank, train_config, epoch=epoch
-                    )
-                    
-                    if train_config.save_optimizer:
-                        print(" Saving the FSDP optimizer using FULL_STATE_DICT")
-                        print("=====================================================")
-                        save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=epoch
-                        )
-                    
-                elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
+            save_checkpoint(model, optimizer, train_config, fsdp_config, epoch)
 
-                    if train_config.save_optimizer:
-                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
-                        print("=====================================================")
-                        save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
-                    else:
-                        print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
-                        print("=====================================================")
-                        save_model_and_optimizer_sharded(model, rank, train_config)
-
-                    
-            if train_config.enable_fsdp:
-                dist.barrier()
         checkpoint_end_time = time.perf_counter() - checkpoint_start_time
         checkpoint_times.append(checkpoint_end_time)
 
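The removed branches above (PEFT adapter save, full-state-dict save, sharded save, optional optimizer save, with dist.barrier() around the whole block) are collapsed into the single save_checkpoint call introduced in this hunk. The actual helper is defined elsewhere in this commit; assuming it simply relocates the removed logic, a hypothetical sketch could look like this (not the repository's implementation):

def save_checkpoint(model, optimizer, train_config, fsdp_config, epoch):
    # Hypothetical reconstruction based on the inline logic removed above.
    rank = dist.get_rank() if train_config.enable_fsdp else 0
    if train_config.enable_fsdp:
        dist.barrier()
    if train_config.use_peft:
        save_peft_checkpoint(model, train_config.output_dir)
    elif not train_config.enable_fsdp:
        save_model_checkpoint(model, train_config.output_dir)
    elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
        save_fsdp_model_checkpoint_full(model, optimizer, rank, train_config, epoch=epoch)
        if train_config.save_optimizer:
            save_optimizer_checkpoint(model, optimizer, rank, train_config, epoch=epoch)
    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
        if train_config.save_optimizer:
            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
        else:
            save_model_and_optimizer_sharded(model, rank, train_config)
    if train_config.enable_fsdp:
        dist.barrier()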
@@ -284,48 +313,67 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
-                    if rank==0:
+                    if rank == 0:
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
-                        print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
+                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
             val_loss.append(float(eval_epoch_loss))
             val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
-            if rank==0:
-                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+            if rank == 0:
+                print(
+                    f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+                )
         else:
-            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+            print(
+                f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+            )
 
         # Saving the results every epoch to plot later
         if train_config.save_metrics:
-            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+            save_to_json(
+                metrics_filename,
+                train_step_loss,
+                train_loss,
+                train_step_perplexity,
+                train_prep,
+                val_step_loss,
+                val_loss,
+                val_step_perplexity,
+                val_prep,
+            )
 
-    avg_epoch_time = sum(epoch_times)/ len(epoch_times)
-    avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
-    avg_train_prep = sum(train_prep)/len(train_prep)
-    avg_train_loss = sum(train_loss)/len(train_loss)
+    avg_epoch_time = sum(epoch_times) / len(epoch_times)
+    avg_checkpoint_time = (
+        sum(checkpoint_times) / len(checkpoint_times)
+        if len(checkpoint_times) > 0
+        else 0
+    )
+    avg_train_prep = sum(train_prep) / len(train_prep)
+    avg_train_loss = sum(train_loss) / len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep)/len(val_prep)
-        avg_eval_loss = sum(val_loss)/len(val_loss)
+        avg_eval_prep = sum(val_prep) / len(val_prep)
+        avg_eval_loss = sum(val_loss) / len(val_loss)
 
-    results['avg_train_prep'] = avg_train_prep
-    results['avg_train_loss'] = avg_train_loss
+    results["avg_train_prep"] = avg_train_prep
+    results["avg_train_loss"] = avg_train_loss
     if train_config.run_validation:
-        results['avg_eval_prep'] = avg_eval_prep
-        results['avg_eval_loss'] = avg_eval_loss
+        results["avg_eval_prep"] = avg_eval_prep
+        results["avg_eval_loss"] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
     if train_config.save_metrics:
         results["metrics_filename"] = metrics_filename
     if train_config.flop_counter:
-        results["model_tflops"]= TFlops
-    #saving the training params including fsdp setting for reference.
-    if train_config.enable_fsdp and not train_config.use_peft and rank==0:
+        results["model_tflops"] = TFlops
+    # saving the training params including fsdp setting for reference.
+    if train_config.enable_fsdp and not train_config.use_peft and rank == 0:
         save_train_params(train_config, fsdp_config, rank)
 
     return results
 
-def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
+
+def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
     """
     """
     Evaluates the model on the given dataloader
 
@@ -346,21 +394,34 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     eval_loss = 0.0  # Initialize evaluation loss
     total_eval_steps = 0
     with MemoryTrace() as memtrace:
-        for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+        for step, batch in enumerate(
+            tqdm(
+                eval_dataloader,
+                colour="green",
+                desc="evaluating Epoch",
+                dynamic_ncols=True,
+            )
+        ):
             total_eval_steps += 1
             # stop when the maximum number of eval steps is reached
-            if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
-                if not train_config.enable_fsdp or local_rank==0:
-                    print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1)
+            if (
+                train_config.max_eval_step > 0
+                and total_eval_steps > train_config.max_eval_step
+            ):
+                if not train_config.enable_fsdp or local_rank == 0:
+                    print(
+                        "max eval steps reached, stopping evaluation, total_eval_steps: ",
+                        total_eval_steps - 1,
+                    )
                 break
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
                     if is_xpu_available():
-                        batch[key] = batch[key].to('xpu:0')
+                        batch[key] = batch[key].to("xpu:0")
                     else:
-                        batch[key] = batch[key].to('cuda:0')
+                        batch[key] = batch[key].to("cuda:0")
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
@@ -374,11 +435,15 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
             eval_preds.extend(
-                tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
+                tokenizer.batch_decode(
+                    preds.detach().cpu().numpy(), skip_special_tokens=True
+                )
             )
 
     # If there's more than one CUDA device, reduce evaluation loss across all devices
-    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+    if is_xpu_available() and (
+        torch.xpu.device_count() > 1 and train_config.enable_fsdp
+    ):
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
@@ -386,35 +451,39 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     if train_config.enable_fsdp:
-        eval_epoch_loss = eval_epoch_loss/world_size
+        eval_epoch_loss = eval_epoch_loss / world_size
     eval_ppl = torch.exp(eval_epoch_loss)
 
     # Print evaluation metrics
     if train_config.enable_fsdp:
-        if local_rank==0:
+        if local_rank == 0:
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
 
     if wandb_run:
-        wandb_run.log({
-                        'eval/perplexity': eval_ppl,
-                        'eval/loss': eval_epoch_loss,
-                    }, commit=False)
+        wandb_run.log(
+            {
+                "eval/perplexity": eval_ppl,
+                "eval/loss": eval_epoch_loss,
+            },
+            commit=False,
+        )
 
     return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
+
 def freeze_transformer_layers(model, num_layer):
-   for i, layer in enumerate(model.model.layers):
-            if i < num_layer:
-                for param in layer.parameters():
-                    param.requires_grad = False
+    for i, layer in enumerate(model.model.layers):
+        if i < num_layer:
+            for param in layer.parameters():
+                param.requires_grad = False
 
 
 def check_frozen_layers_peft_model(model):
-     for i, layer in enumerate(model.base_model.model.model.layers):
-            for name, param in layer.named_parameters():
-                print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
+    for i, layer in enumerate(model.base_model.model.model.layers):
+        for name, param in layer.named_parameters():
+            print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
 
 
 def setup():
@@ -460,58 +529,6 @@ def get_parameter_dtypes(model):
         parameter_dtypes[name] = parameter.dtype
     return parameter_dtypes
 
-def print_model_size(model, config, rank: int = 0) -> None:
-    """
-    Print model name, the number of trainable parameters and initialization time.
-
-    Args:
-        model: The PyTorch model.
-        model_name (str): Name of the model.
-        init_time_start (float): Initialization start time.
-        init_time_end (float): Initialization end time.
-        rank (int, optional): Current process's rank. Defaults to 0.
-    """
-    if rank == 0:
-        print(f"--> Model {config.model_name}")
-        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
-
-
-
-
-def get_policies(cfg, rank):
-    """Get the policies for mixed precision and fsdp wrapping"""
-
-
-    verify_bfloat_support = ((
-    torch.version.cuda
-    and torch.cuda.is_bf16_supported()
-    and torch.version.cuda >= "11.0"
-    and dist.is_nccl_available()
-    and nccl.version() >= (2, 10)
-    ) or
-    (is_xpu_available()))
-
-
-    mixed_precision_policy = None
-    wrapping_policy = None
-
-    # Mixed precision
-    if cfg.mixed_precision:
-        bf16_ready = verify_bfloat_support
-
-        if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen
-            if rank == 0:
-                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
-        elif cfg.use_fp16:
-            mixed_precision_policy = fpSixteen
-            if rank == 0:
-                print(f"FP16 enabled")
-        else:
-            print(f"bFloat16 support not present. Using FP32, and not mixed precision")
-    wrapping_policy = get_llama_wrapper()
-    return mixed_precision_policy, wrapping_policy
 
 def save_train_params(train_config, fsdp_config, rank):
     """
     """
@@ -521,17 +538,21 @@ def save_train_params(train_config, fsdp_config, rank):
     """
     """
     # Convert the train_config and fsdp_config objects to dictionaries,
     # converting all values to strings to ensure they can be serialized into a YAML file
-    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
-    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
+    train_config_dict = {
+        k: str(v) for k, v in vars(train_config).items() if not k.startswith("__")
+    }
+    fsdp_config_dict = {
+        k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith("__")
+    }
     # Merge the two dictionaries into one
     train_params_dict = {**train_config_dict, **fsdp_config_dict}
     # Construct the folder name (following FSDP checkpointing style) using properties of the train_config object
     folder_name = (
-    train_config.dist_checkpoint_root_folder
-    + "/"
-    + train_config.dist_checkpoint_folder
-    + "-"
-    + train_config.model_name
+        train_config.dist_checkpoint_root_folder
+        + "/"
+        + train_config.dist_checkpoint_folder
+        + "-"
+        + train_config.model_name
     )
 
     save_dir = Path.cwd() / folder_name
@@ -540,19 +561,30 @@ def save_train_params(train_config, fsdp_config, rank):
         os.makedirs(save_dir)
     # Convert the dictionary to a YAML string
     config_yaml = yaml.dump(train_params_dict, indent=4)
-    file_name = os.path.join(save_dir,'train_params.yaml')
+    file_name = os.path.join(save_dir, "train_params.yaml")
 
     # Check if there's a directory with the same name as the file
     if os.path.isdir(file_name):
         print(f"Error: {file_name} is a directory, not a file.")
     else:
         # Write the YAML string to the file
-        with open(file_name, 'w') as f:
+        with open(file_name, "w") as f:
             f.write(config_yaml)
-        if rank==0:
+        if rank == 0:
             print(f"training params are saved in {file_name}")
 
-def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
+
+def save_to_json(
+    output_filename,
+    train_step_loss,
+    train_epoch_loss,
+    train_step_ppl,
+    train_epoch_ppl,
+    val_step_loss,
+    val_epoch_loss,
+    val_step_ppl,
+    val_epoch_ppl,
+):
     metrics_data = {
         "train_step_loss": train_step_loss,
         "train_step_loss": train_step_loss,
         "train_epoch_loss": train_epoch_loss,
         "train_epoch_loss": train_epoch_loss,
@@ -561,7 +593,7 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
         "val_step_loss": val_step_loss,
         "val_step_loss": val_step_loss,
         "val_epoch_loss": val_epoch_loss,
         "val_epoch_loss": val_epoch_loss,
         "val_step_perplexity": val_step_ppl,
         "val_step_perplexity": val_step_ppl,
-        "val_epoch_perplexity": val_epoch_ppl
+        "val_epoch_perplexity": val_epoch_ppl,
     }
     with open(output_filename, "w") as f:
         json.dump(metrics_data, f)
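save_to_json serializes the per-step and per-epoch loss/perplexity series into a single JSON file at the metrics_filename chosen in train(). Reading it back for plotting is straightforward; a small illustrative snippet (the file name below is assumed, not produced by this code):

import json

# Open whatever metrics_filename was written during training.
with open("metrics.json") as f:
    metrics = json.load(f)

print(len(metrics["train_step_loss"]), "training steps recorded")
print("last epoch val perplexity:", metrics["val_epoch_perplexity"][-1])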