Browse Source

Merge branch 'main' of https://github.com/facebookresearch/llama-recipes into optimizer_overlap

Hamid Shojanazeri 1 year ago
parent
commit
7f23559bdd

+ 1 - 2
README.md

@@ -39,7 +39,7 @@ Llama 2 is a new technology that carries potential risks with use. Testing condu
 To run the examples, make sure to install the requirements using
 
 ```bash
-
+# python 3.9 or higher recommended
 pip install -r requirements.txt
 
 ```
@@ -55,7 +55,6 @@ Given that the original checkpoint resides under models/7B you can install all r
 ## Install HuggingFace Transformers from source
 pip freeze | grep transformers ## verify it is version 4.31.0 or higher
 
-```bash
 git clone git@github.com:huggingface/transformers.git
 cd transformers
 pip install protobuf

File diff suppressed because it is too large
+ 13 - 0
docs/FAQ.md


+ 25 - 0
docs/inference.md

@@ -34,6 +34,31 @@ The inference folder also includes a chat completion example, that adds built-in
 python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json  --quantization --use_auditnlg
 
 ```
+## Loading back FSDP checkpoints
+
+In case you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown [here](../configs/fsdp.py), you can use this converter script to convert the FSDP sharded checkpoints into HuggingFace checkpoints. This enables you to use the inference script normally as mentioned above.
+**To convert the checkpoint, use the commands below.**
+
+This applies if you have fine-tuned your model using FSDP only, as follows:
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /path_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16
+```
+Then convert your FSDP checkpoints to HuggingFace checkpoints using:
+```bash
+python inference/checkpoint_converter_fsdp_hf.py --fsdp_checkpoint_path PATH/to/FSDP/Checkpoints --consolidated_model_path PATH/to/save/checkpoints --HF_model_path_or_name PATH/or/HF/model_name
+
+# --HF_model_path_or_name specifies the HF Llama model name or path that contains config.json and tokenizer.json
+```
+By default, the training parameters are saved in `train_params.yaml` in the path where the FSDP checkpoints are saved. The converter script first tries to read the HuggingFace model name used during fine-tuning from that file and load the model config from it; if it is not found, the user needs to provide it.
+
+Then run inference using:
+
+```bash
+python inference/inference.py --model_name <training_config.output_dir> --prompt_file <test_prompt_file>
+
+```
+
 
 ## Other Inference Options
 

+ 1 - 1
inference/README.md

@@ -2,7 +2,7 @@
 
 This folder contains inference examples for Llama 2. So far, we have provided support for three methods of inference:
 
-1. [inference script](inference.py) script provides support for Hugging Face accelerate and PEFT fine tuned models.
+1. The [inference script](inference.py) supports Hugging Face accelerate, PEFT and FSDP fine-tuned models.
 
 2. [vLLM_inference.py](vLLM_inference.py) script takes advantage of vLLM's paged attention concept for low latency.
 

+ 63 - 0
inference/checkpoint_converter_fsdp_hf.py

@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+
+import fire
+import torch
+import os
+import sys
+import yaml
+from transformers import LlamaTokenizer
+from model_utils import  load_llama_from_config
+# Get the current file's directory
+current_directory = os.path.dirname(os.path.abspath(__file__))
+
+# Get the parent directory
+parent_directory = os.path.dirname(current_directory)
+
+# Append the parent directory to sys.path
+sys.path.append(parent_directory)
+from model_checkpointing import load_sharded_model_single_gpu
+
+def main(
+    fsdp_checkpoint_path="", # Path to FSDP Sharded model checkpoints
+    consolidated_model_path="", # Path to save the HF converted model checkpoints
+    HF_model_path_or_name="" # Path or name of the HF model that includes config.json and tokenizer_config.json (e.g. meta-llama/Llama-2-7b-chat-hf)
+    ):
+    
+    try:
+        file_name = 'train_params.yaml'
+        # Combine the directory and file name to create the full path
+        train_params_path = os.path.join(fsdp_checkpoint_path, file_name)
+        # Open the file
+        with open(train_params_path, 'r') as file:
+            # Load the YAML data
+            data = yaml.safe_load(file)
+
+            # Access the 'model_name' field
+            HF_model_path_or_name = data.get('model_name')
+
+            print(f"Model name: {HF_model_path_or_name}")
+    except FileNotFoundError:
+        print(f"The file {train_params_path} does not exist.")
+        HF_model_path_or_name = input("Please enter the model name: ")
+        print(f"Model name: {HF_model_path_or_name}")
+    except Exception as e:
+        print(f"An error occurred: {e}")
+        
+        
+    # load the HF model definition from its config
+    model_def = load_llama_from_config(HF_model_path_or_name)
+    print("model is loaded from config")
+    # load the FSDP sharded checkpoints into the model
+    model = load_sharded_model_single_gpu(model_def, fsdp_checkpoint_path)
+    print("model is loaded from FSDP checkpoints")
+    # load the tokenizer from the HF model path or name
+    tokenizer = LlamaTokenizer.from_pretrained(HF_model_path_or_name)
+    tokenizer.save_pretrained(consolidated_model_path)
+    # save the FSDP sharded checkpoints in HF format
+    model.save_pretrained(consolidated_model_path)
+    print(f"HuggingFace model checkpoints have been saved in {consolidated_model_path}")
+if __name__ == "__main__":
+    fire.Fire(main)

+ 1 - 2
inference/inference.py

@@ -12,8 +12,7 @@ from typing import List
 
 from transformers import LlamaTokenizer
 from safety_utils import get_safety_checker
-from model_utils import load_model, load_peft_model
-
+from model_utils import load_model, load_peft_model, load_llama_from_config
 
 def main(
     model_name,

+ 10 - 2
inference/model_utils.py

@@ -2,7 +2,7 @@
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
 from peft import PeftModel
-from transformers import LlamaForCausalLM
+from transformers import LlamaForCausalLM, LlamaConfig
 
 # Function to load the main model for text generation
 def load_model(model_name, quantization):
@@ -19,4 +19,12 @@ def load_model(model_name, quantization):
 # Function to load the PeftModel for performance optimization
 def load_peft_model(model, peft_model):
     peft_model = PeftModel.from_pretrained(model, peft_model)
-    return peft_model
+    return peft_model
+
+# Function to build the model from a config so that FSDP checkpoints can be loaded into it
+def load_llama_from_config(config_path):
+    model_config = LlamaConfig.from_pretrained(config_path)
+    model = LlamaForCausalLM(config=model_config)
+    return model
+
+
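The new `load_llama_from_config` helper builds a model skeleton from a HF config (weights are randomly initialized) so that FSDP sharded weights can then be loaded into it. A hedged sketch of the intended composition, mirroring `inference/checkpoint_converter_fsdp_hf.py`; the model name and paths are placeholders:

```python
# Sketch of how load_llama_from_config is meant to be combined with the
# sharded-checkpoint loader (placeholder model name and paths).
from model_utils import load_llama_from_config
from model_checkpointing import load_sharded_model_single_gpu

model_def = load_llama_from_config("meta-llama/Llama-2-7b-hf")  # config only, random weights
model = load_sharded_model_single_gpu(model_def, "PATH/to/FSDP/Checkpoints")
model.save_pretrained("PATH/to/save/checkpoints")               # standard HF checkpoint layout
```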

+ 1 - 2
model_checkpointing/__init__.py

@@ -4,10 +4,9 @@
 from .checkpoint_handler import (
     load_model_checkpoint,
     save_model_checkpoint,
-    save_distributed_model_checkpoint,
-    load_distributed_model_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,
     load_model_sharded,
+    load_sharded_model_single_gpu
 )

+ 49 - 88
model_checkpointing/checkpoint_handler.py

@@ -44,7 +44,7 @@ def get_date_of_run():
 fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
 
 
-def load_model_sharded(model, rank, cfg, verbose=True):
+def load_model_sharded(model, rank, cfg):
     # torch.manual_seed(103)
     folder_name = (
         cfg.dist_checkpoint_root_folder
@@ -83,7 +83,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
         print(f"Sharded state checkpoint loaded from {load_dir}")
 
 
-def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
+def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
     """save model and optimizer via sharded_state_dict to save_dir"""
     
     folder_name = (
@@ -142,7 +142,14 @@ def save_model_checkpoint(
     if rank == 0:
         print(f"--> saving model ...")
         # create save path
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+        cfg.dist_checkpoint_root_folder
+        + "/"
+        + cfg.dist_checkpoint_folder
+        + "-"
+        + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
         save_name = cfg.model_name + "-" + str(epoch) + ".pt"
         save_full_path = str(save_dir) + "/" + save_name
@@ -150,12 +157,12 @@ def save_model_checkpoint(
         # save model
         torch.save(cpu_state, save_full_path)
 
-        if cfg.verbose:
-            print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
+        
+        print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
       
 
 
-def load_model_checkpoint(model, rank, cfg, verbose=True):
+def load_model_checkpoint(model, rank, cfg):
     """load local checkpoint to rank0 cpu
     must be called * before * passing to FSDP"""
 
@@ -178,8 +185,8 @@ def load_model_checkpoint(model, rank, cfg, verbose=True):
     # integrate into loaded model
     model.load_state_dict(model_checkpoint)
 
-    if cfg.verbose:
-        print(f"model checkpoint loaded to rank0 cpu")
+    
+    print(f"model checkpoint loaded to rank0 cpu")
 
 
 def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
@@ -192,15 +199,22 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
 
     optim_state = FSDP.full_optim_state_dict(model, optimizer)
 
-    if cfg.verbose:
-        print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
+    
+    print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
 
     if rank == 0:
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+        cfg.dist_checkpoint_root_folder
+        + "/"
+        + cfg.dist_checkpoint_folder
+        + "-"
+        + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
 
         opt_save_name = (
-            cfg.optimizer_name + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
+            "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
         )
         opt_save_full_path = save_dir / opt_save_name
 
@@ -211,96 +225,43 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
         print(f"--> saved {opt_save_full_path} to disk")
 
 
-def load_optimizer_checkpoint(model, optimizer, rank, cfg):
+def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
     """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
     """
 
-    opt_file_path = Path.cwd() / cfg.checkpoint_folder / cfg.optimizer_checkpoint_file
 
-    if not opt_file_path.is_file():
+    if not optimizer_checkpoint_path.is_file():
         print(
-            f"warning - optimizer checkpoint not present {opt_file_path}. Returning. "
+            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
         )
         return
 
     full_osd = None
 
     if rank == 0:
-        full_osd = torch.load(opt_file_path)
-
-        if cfg.verbose:
-            print(f"loaded full osd on rank 0")
+        full_osd = torch.load(optimizer_checkpoint_path)
 
     # called from all ranks, though only rank0 has a valid param for full_osd
     sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
 
-    if cfg.verbose:
-        print(f"optimizer shard loaded on rank {rank}")
-
-
-
-def load_distributed_model_checkpoint(model, rank, cfg):
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        print(f"loading distributed checkpoint, rank {rank}...")
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
+    print(f"optimizer shard loaded on rank {rank}")
 
-        checkdir = Path.cwd() / folder_name
-
-        if not checkdir.exists():
-            if rank == 0:
-                print(f"No checkpoint directory found...skipping")
-            return
-
-
-        reader = FileSystemReader(checkdir)
-
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-            load_state_dict(state_dict, reader)
-            model.load_state_dict(state_dict)
-
-        print(f"--> local state loaded on rank {rank}")
-
-        return
-
-
-def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
-    # distributed checkpoint saving
-
-    # confirm type of checkpoint and save
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        # create writer to current path
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-        save_dir = Path.cwd() / folder_name
-
-        writer = FileSystemWriter(
-            save_dir,
-        )
-
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-       
-
-        # write out distributed checkpoint
-        save_state_dict(state_dict, writer)
-
-        return
+def load_sharded_model_single_gpu(model, model_path):
+    
+    reader = FileSystemReader(model_path)
+    
+    state_dict = {
+        "model": model.state_dict()
+    }
+    
+    dist_cp.load_state_dict(
+        state_dict=state_dict,
+        storage_reader=reader,
+        no_dist=True,
+    )
+    
+    model.load_state_dict(state_dict["model"])
+    
+    print(f"Sharded state checkpoint loaded from {model_path}")
+    return model

+ 23 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1067,4 +1067,26 @@ chatGPT
 Llama
 PEFT
 LORA
-FSDP
+FSDP
+AuditNLG
+finetune
+fsdp
+ineference
+lora
+peft
+samsum
+vLLM
+TGI
+vLLM
+vLLM's
+OOM
+RTX
+SKU
+TPUs
+checkpointing
+enviroment
+fragmentations
+intra
+nightlies
+recenly
+uncomment

+ 1 - 0
utils/memory_utils.py

@@ -52,6 +52,7 @@ class MemoryTrace:
         cuda_info = torch.cuda.memory_stats()
         self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
         self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
         self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
         self.used = byte2gb(self.end - self.begin)
         self.peaked = byte2gb(self.peak - self.begin)
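For context, `MemoryTrace` is a context manager whose fields are read back after the block exits, as in the `utils/train_utils.py` prints further down. A hedged usage sketch; the import path and the work inside the block are assumptions:

```python
# Hedged sketch of consuming MemoryTrace (import path assumed).
from utils.memory_utils import MemoryTrace

def run_one_training_epoch():
    pass  # placeholder for the actual training work

with MemoryTrace() as memtrace:  # records CUDA/CPU memory stats for the enclosed block
    run_one_training_epoch()

print(f"Max CUDA memory allocated was {memtrace.peak} GB")
print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
print(f"Cuda Malloc retries: {memtrace.cuda_malloc_retires}")
```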

+ 106 - 33
utils/train_utils.py

@@ -4,6 +4,7 @@
 import os
 import sys
 from typing import List
+import yaml
 
 import fire
 import torch
@@ -67,7 +68,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler() 
-        
+    if train_config.enable_fsdp:
+        world_size = int(os.environ["WORLD_SIZE"]) 
     train_prep = []
     train_loss = []
     val_prep = []
@@ -80,7 +82,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
-            data_set_len = 0
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -90,8 +91,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
-                first_key = next(iter(batch))
-                data_set_len += len(batch[first_key])
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
@@ -122,12 +121,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
-        
-        print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-        print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-        print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-        print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-        print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        if train_config.enable_fsdp:
+            if rank==0:
+                print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
         
         # Update the learning rate as needed
         lr_scheduler.step()
@@ -135,35 +141,53 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
             if train_config.save_model and eval_epoch_loss < best_val_loss:
-                
-                if  train_config.use_peft:
-                    
-                    print(f"we are in the saving the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)   
-                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                    
+                if train_config.enable_fsdp:
+                    dist.barrier()
+                if train_config.use_peft:
+                    if train_config.enable_fsdp:
+                        if rank==0:
+                            print(f"we are about to save the PEFT modules")
+                    else:
+                        print(f"we are about to save the PEFT modules")
+                    model.save_pretrained(train_config.output_dir)
+                    if train_config.enable_fsdp:
+                        if rank==0:
+                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
+                    else:
+                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
+                        
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                         
                         model_checkpointing.save_model_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
+                            model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-                        print(" we are about to save the models *******")
+                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                        print("=====================================================")
                        
                        model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
                        if train_config.save_optimizer:
                            model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
+                            print("=====================================================")

                    if not train_config.use_peft and  train_config.save_optimizer:
                        model_checkpointing.save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
-                        )   
-                                
+                            model, optimizer, rank, train_config, epoch=epoch
+                        )
+                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
+                        print("=====================================================")
+                if train_config.enable_fsdp:
+                    dist.barrier()
             
-            if local_rank == 0 and eval_epoch_loss < best_val_loss:
+            if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
-                print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                if train_config.enable_fsdp:
+                    if rank==0:
+                        print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                else:
+                    print(f"best eval loss on epoch {epoch} is {best_val_loss}")
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
         
@@ -171,7 +195,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
         lr_scheduler.step()
     avg_epoch_time = sum(epoch_times)/len(epoch_times)
-    print("avg epoch time is {avg_epoch_time}")
+    print(f"avg epoch time is {avg_epoch_time}")
     print("==========================================")
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
@@ -185,7 +209,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
         
-
+    # save the training params, including the FSDP settings, for reference
+    if train_config.enable_fsdp and not train_config.use_peft:
+        save_train_params(train_config, fsdp_config, rank)
+        
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
@@ -200,10 +227,11 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     
     Returns: eval_ppl, eval_epoch_loss
     """
+    if train_config.enable_fsdp:
+        world_size = int(os.environ["WORLD_SIZE"]) 
     model.eval()
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
-    eval_dataset_len = 0
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
             for key in batch.keys():
@@ -217,9 +245,6 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 outputs = model(**batch)
                 loss = outputs.loss
                 eval_loss += loss.detach().float()
-                first_key = next(iter(batch))
-                eval_dataset_len+= len(batch[first_key])
-                
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
             eval_preds.extend(
@@ -233,11 +258,17 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
-    eval_epoch_loss = eval_epoch_loss/world_size
+    if train_config.enable_fsdp:
+        eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
     
     # Print evaluation metrics
-    print(f" {eval_ppl=} {eval_epoch_loss=}")
+    if train_config.enable_fsdp:
+        if local_rank==0:
+            print(f" {eval_ppl=} {eval_epoch_loss=}")
+    else:
+        print(f" {eval_ppl=} {eval_epoch_loss=}")
+        
     return eval_ppl, eval_epoch_loss
 
 def freeze_transformer_layers(model, num_layer):
@@ -262,7 +293,10 @@ def setup_environ_flags(rank):
     """Set environment flags for debugging purposes"""
     os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
     os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
-    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # This flag helps with CUDA memory fragmentation that can lead to OOM in some cases.
+    # Note this is only available in PyTorch Nightlies (as of July 30, 2023).
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")
 
@@ -336,3 +370,42 @@ def get_policies(cfg, rank):
             print(f"bFloat16 support not present. Using FP32, and not mixed precision")
     wrapping_policy = get_llama_wrapper()
     return mixed_precision_policy, wrapping_policy
+
+def save_train_params(train_config, fsdp_config, rank):
+    """
+    This function saves the train_config and FSDP config into a train_params.yaml.
+    This will be used by the converter script in the inference folder to fetch the HF model name or path.
+    It is also helpful as a log for future reference.
+    """
+    # Convert the train_config and fsdp_config objects to dictionaries,
+    # converting all values to strings to ensure they can be serialized into a YAML file
+    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
+    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
+    # Merge the two dictionaries into one
+    train_params_dict = {**train_config_dict, **fsdp_config_dict}
+    # Construct the folder name (following the FSDP checkpointing style) using properties of the train_config object
+    folder_name = (
+    train_config.dist_checkpoint_root_folder
+    + "/"
+    + train_config.dist_checkpoint_folder
+    + "-"
+    + train_config.model_name
+    )
+
+    save_dir = Path.cwd() / folder_name
+    # If the directory does not exist, create it
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    # Convert the dictionary to a YAML string
+    config_yaml = yaml.dump(train_params_dict, indent=4)
+    file_name = os.path.join(save_dir,'train_params.yaml')
+
+    # Check if there's a directory with the same name as the file
+    if os.path.isdir(file_name):
+        print(f"Error: {file_name} is a directory, not a file.")
+    else:
+        # Write the YAML string to the file
+        with open(file_name, 'w') as f:
+            f.write(config_yaml)
+        if rank==0:
+            print(f"training params are saved in {file_name}")