Browse Source

Merge branch 'main' of https://github.com/facebookresearch/llama-recipes into optimizer_overlap

Hamid Shojanazeri 1 year ago
parent
commit
7f23559bdd

+ 1 - 2
README.md

@@ -39,7 +39,7 @@ Llama 2 is a new technology that carries potential risks with use. Testing condu
 To run the examples, make sure to install the requirements using
 
 ```bash
-
+# python 3.9 or higher recommended
 pip install -r requirements.txt
 
 ```
@@ -55,7 +55,6 @@ Given that the original checkpoint resides under models/7B you can install all r
 ## Install HuggingFace Transformers from source
 pip freeze | grep transformers ## verify it is version 4.31.0 or higher
 
-```bash
 git clone git@github.com:huggingface/transformers.git
 cd transformers
 pip install protobuf

File diff suppressed because it is too large
+ 13 - 0
docs/FAQ.md


+ 25 - 0
docs/inference.md

@@ -34,6 +34,31 @@ The inference folder also includes a chat completion example, that adds built-in
 python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json  --quantization --use_auditnlg
 
 ```
+## Loading back FSDP checkpoints
+
+In case you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown [here](../configs/fsdp.py), you can use this converter script to convert the FSDP Sharded checkpoints into HuggingFace checkpoints. This enables you to use the inference script normally as mentioned above.
+**To convert the checkpoint use the following command**:
+
+This is helpful if you have fine-tuned your model using FSDP only, as follows:
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /path_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 
+```
+Then convert your FSDP checkpoint to HuggingFace checkpoints using:
+```bash
+ python inference/checkpoint_converter_fsdp_hf.py --fsdp_checkpoint_path  PATH/to/FSDP/Checkpoints --consolidated_model_path PATH/to/save/checkpoints --HF_model_path_or_name PATH/or/HF/model_name
+
+ # --HF_model_path_or_name specifies the HF Llama model name or path that contains config.json and tokenizer.json
+ ```
+By default, training parameters are saved in `train_params.yaml` in the path where the FSDP checkpoints are saved. The converter script first tries to find the HuggingFace model name used during fine-tuning and loads the model config from there; if it is not found, the user needs to provide it.
+
+Then run inference using:
+
+```bash
+python inference/inference.py --model_name <training_config.output_dir> --prompt_file <test_prompt_file> 
+
+```
+
 
 ## Other Inference Options
 
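The `train_params.yaml` lookup described in this hunk can be reproduced in a few lines of Python. Below is a minimal sketch (not part of the commit) of how the converter recovers the Hugging Face model name, assuming the default file name and the `model_name` key used by the converter script further down; the checkpoint directory is a hypothetical example.

```python
import os
import yaml

# Hypothetical FSDP checkpoint directory produced by a fine-tuning run like the one above
fsdp_checkpoint_path = "model_checkpoints/fine-tuned-7B"

# save_train_params writes train_params.yaml next to the sharded checkpoints
with open(os.path.join(fsdp_checkpoint_path, "train_params.yaml")) as f:
    train_params = yaml.safe_load(f)

# The converter uses the recorded HF model name to rebuild the model from its config
print(train_params.get("model_name"))
```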

+ 1 - 1
inference/README.md

@@ -2,7 +2,7 @@
 
 This folder contains inference examples for Llama 2. So far, we have provided support for three methods of inference:
 
-1. [inference script](inference.py) script provides support for Hugging Face accelerate and PEFT fine tuned models.
+1. [inference script](inference.py) script provides support for Hugging Face accelerate, PEFT and FSDP fine tuned models.
 
 2. [vLLM_inference.py](vLLM_inference.py) script takes advantage of vLLM's paged attention concept for low latency.
 

+ 63 - 0
inference/checkpoint_converter_fsdp_hf.py

@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+
+import fire
+import torch
+import os
+import sys
+import yaml
+from transformers import LlamaTokenizer
+from model_utils import  load_llama_from_config
+# Get the current file's directory
+current_directory = os.path.dirname(os.path.abspath(__file__))
+
+# Get the parent directory
+parent_directory = os.path.dirname(current_directory)
+
+# Append the parent directory to sys.path
+sys.path.append(parent_directory)
+from model_checkpointing import load_sharded_model_single_gpu
+
+def main(
+    fsdp_checkpoint_path="", # Path to FSDP Sharded model checkpoints
+    consolidated_model_path="", # Path to save the HF converted model checkpoints
+    HF_model_path_or_name="" # Path/name of the HF model that includes config.json and tokenizer_config.json (e.g. meta-llama/Llama-2-7b-chat-hf)
+    ):
+    
+    try:
+        file_name = 'train_params.yaml'
+        # Combine the directory and file name to create the full path
+        train_params_path = os.path.join(fsdp_checkpoint_path, file_name)
+        # Open the file
+        with open(train_params_path, 'r') as file:
+            # Load the YAML data
+            data = yaml.safe_load(file)
+
+            # Access the 'model_name' field
+            HF_model_path_or_name = data.get('model_name')
+
+            print(f"Model name: {HF_model_path_or_name}")
+    except FileNotFoundError:
+        print(f"The file {train_params_path} does not exist.")
+        HF_model_path_or_name = input("Please enter the model name: ")
+        print(f"Model name: {HF_model_path_or_name}")
+    except Exception as e:
+        print(f"An error occurred: {e}")
+        
+        
+    #load the HF model definition from config
+    model_def = load_llama_from_config(HF_model_path_or_name)
+    print("model is loaded from config")
+    #load the FSDP sharded checkpoints into the model
+    model = load_sharded_model_single_gpu(model_def, fsdp_checkpoint_path)
+    print("model is loaded from FSDP checkpoints")
+    #loading the tokenizer from the model_path
+    tokenizer = LlamaTokenizer.from_pretrained(HF_model_path_or_name)
+    tokenizer.save_pretrained(consolidated_model_path)
+    #save the FSDP sharded checkpoints in HF format
+    model.save_pretrained(consolidated_model_path)
+    print(f"HuggingFace model checkpoints have been saved in {consolidated_model_path}")
+if __name__ == "__main__":
+    fire.Fire(main)

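Once the converter finishes, `consolidated_model_path` holds an ordinary Hugging Face checkpoint plus tokenizer files. A minimal usage sketch (the path and prompt are placeholders, not part of this commit):

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# Placeholder: wherever --consolidated_model_path pointed during conversion
consolidated_model_path = "PATH/to/save/checkpoints"

tokenizer = LlamaTokenizer.from_pretrained(consolidated_model_path)
model = LlamaForCausalLM.from_pretrained(consolidated_model_path, torch_dtype=torch.bfloat16)

inputs = tokenizer("Summarize the following dialog:\n", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```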
+ 1 - 2
inference/inference.py

@@ -12,8 +12,7 @@ from typing import List
 
 from transformers import LlamaTokenizer
 from safety_utils import get_safety_checker
-from model_utils import load_model, load_peft_model
-
+from model_utils import load_model, load_peft_model, load_llama_from_config
 
 def main(
     model_name,

+ 10 - 2
inference/model_utils.py

@@ -2,7 +2,7 @@
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
 from peft import PeftModel
-from transformers import LlamaForCausalLM
+from transformers import LlamaForCausalLM, LlamaConfig
 
 # Function to load the main model for text generation
 def load_model(model_name, quantization):
@@ -19,4 +19,12 @@ def load_model(model_name, quantization):
 # Function to load the PeftModel for performance optimization
 def load_peft_model(model, peft_model):
     peft_model = PeftModel.from_pretrained(model, peft_model)
-    return peft_model
+    return peft_model
+
+# Loading the model from config to load FSDP checkpoints into that
+def load_llama_from_config(config_path):
+    model_config = LlamaConfig.from_pretrained(config_path) 
+    model = LlamaForCausalLM(config=model_config)
+    return model
+    
+    

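`load_llama_from_config` builds the model from `config.json` alone, so the returned `LlamaForCausalLM` has freshly initialized weights; the real weights only arrive when the FSDP checkpoint is loaded into it. A small sketch of what the helper does (the model name is a placeholder):

```python
from transformers import LlamaConfig, LlamaForCausalLM

# Placeholder HF model name/path; only config.json is needed here
config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

# Constructing from config allocates the architecture with random weights;
# no pretrained weights are downloaded or loaded at this point.
model = LlamaForCausalLM(config=config)
print(sum(p.numel() for p in model.parameters()), "parameters, randomly initialized")
```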
+ 1 - 2
model_checkpointing/__init__.py

@@ -4,10 +4,9 @@
 from .checkpoint_handler import (
     load_model_checkpoint,
     save_model_checkpoint,
-    save_distributed_model_checkpoint,
-    load_distributed_model_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,
     load_model_sharded,
+    load_sharded_model_single_gpu
 )

+ 49 - 88
model_checkpointing/checkpoint_handler.py

@@ -44,7 +44,7 @@ def get_date_of_run():
 fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
 
 
-def load_model_sharded(model, rank, cfg, verbose=True):
+def load_model_sharded(model, rank, cfg):
     # torch.manual_seed(103)
     folder_name = (
         cfg.dist_checkpoint_root_folder
@@ -83,7 +83,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
         print(f"Sharded state checkpoint loaded from {load_dir}")
 
 
-def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
+def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
     """save model and optimizer via sharded_state_dict to save_dir"""
     
     folder_name = (
@@ -142,7 +142,14 @@ def save_model_checkpoint(
     if rank == 0:
         print(f"--> saving model ...")
         # create save path
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+        cfg.dist_checkpoint_root_folder
+        + "/"
+        + cfg.dist_checkpoint_folder
+        + "-"
+        + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
         save_name = cfg.model_name + "-" + str(epoch) + ".pt"
         save_full_path = str(save_dir) + "/" + save_name
@@ -150,12 +157,12 @@ def save_model_checkpoint(
         # save model
         torch.save(cpu_state, save_full_path)
 
-        if cfg.verbose:
-            print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
+        
+        print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
       
 
 
-def load_model_checkpoint(model, rank, cfg, verbose=True):
+def load_model_checkpoint(model, rank, cfg):
     """load local checkpoint to rank0 cpu
     must be called * before * passing to FSDP"""
 
@@ -178,8 +185,8 @@ def load_model_checkpoint(model, rank, cfg, verbose=True):
     # integrate into loaded model
     model.load_state_dict(model_checkpoint)
 
-    if cfg.verbose:
-        print(f"model checkpoint loaded to rank0 cpu")
+    
+    print(f"model checkpoint loaded to rank0 cpu")
 
 
 def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
@@ -192,15 +199,22 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
 
     optim_state = FSDP.full_optim_state_dict(model, optimizer)
 
-    if cfg.verbose:
-        print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
+    
+    print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
 
     if rank == 0:
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+        cfg.dist_checkpoint_root_folder
+        + "/"
+        + cfg.dist_checkpoint_folder
+        + "-"
+        + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
 
         opt_save_name = (
-            cfg.optimizer_name + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
+            "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
         )
         opt_save_full_path = save_dir / opt_save_name
 
@@ -211,96 +225,43 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
         print(f"--> saved {opt_save_full_path} to disk")
 
 
-def load_optimizer_checkpoint(model, optimizer, rank, cfg):
+def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
     """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
     """
 
-    opt_file_path = Path.cwd() / cfg.checkpoint_folder / cfg.optimizer_checkpoint_file
 
-    if not opt_file_path.is_file():
+    if not optimizer_checkpoint_path.is_file():
         print(
-            f"warning - optimizer checkpoint not present {opt_file_path}. Returning. "
+            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
         )
         return
 
     full_osd = None
 
     if rank == 0:
-        full_osd = torch.load(opt_file_path)
-
-        if cfg.verbose:
-            print(f"loaded full osd on rank 0")
+        full_osd = torch.load(optimizer_checkpoint_path)
 
     # called from all ranks, though only rank0 has a valid param for full_osd
     sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
 
-    if cfg.verbose:
-        print(f"optimizer shard loaded on rank {rank}")
-
-
-
-def load_distributed_model_checkpoint(model, rank, cfg):
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        print(f"loading distributed checkpoint, rank {rank}...")
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
+    print(f"optimizer shard loaded on rank {rank}")
 
-        checkdir = Path.cwd() / folder_name
-
-        if not checkdir.exists():
-            if rank == 0:
-                print(f"No checkpoint directory found...skipping")
-            return
-
-
-        reader = FileSystemReader(checkdir)
-
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-            load_state_dict(state_dict, reader)
-            model.load_state_dict(state_dict)
-
-        print(f"--> local state loaded on rank {rank}")
-
-        return
-
-
-def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
-    # distributed checkpoint saving
-
-    # confirm type of checkpoint and save
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        # create writer to current path
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-        save_dir = Path.cwd() / folder_name
-
-        writer = FileSystemWriter(
-            save_dir,
-        )
-
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-       
-
-        # write out distributed checkpoint
-        save_state_dict(state_dict, writer)
-
-        return
+def load_sharded_model_single_gpu(model,model_path):
+    
+    reader = FileSystemReader(model_path)
+    
+    state_dict = {
+        "model": model.state_dict()
+    }
+    
+    dist_cp.load_state_dict(
+                state_dict=state_dict,
+                storage_reader= FileSystemReader(model_path),
+                no_dist=True,
+            )
+    
+    model.load_state_dict(state_dict["model"])
+    
+    print(f"Sharded state checkpoint loaded from {model_path}")
+    return model

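The new `load_sharded_model_single_gpu` relies on `torch.distributed.checkpoint` being able to read a sharded checkpoint without an initialized process group (`no_dist=True`). A self-contained round-trip sketch on a toy module (not part of this commit; the import path assumes PyTorch 2.0+):

```python
import torch
import torch.nn as nn
import torch.distributed.checkpoint as dist_cp

# Toy module standing in for the fine-tuned model
model = nn.Linear(4, 4)

# Save without a process group, mirroring the single-process loading path
dist_cp.save_state_dict(
    state_dict={"model": model.state_dict()},
    storage_writer=dist_cp.FileSystemWriter("toy_ckpt"),
    no_dist=True,
)

# Load back in place, the same way load_sharded_model_single_gpu does
restored = nn.Linear(4, 4)
state_dict = {"model": restored.state_dict()}
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("toy_ckpt"),
    no_dist=True,
)
restored.load_state_dict(state_dict["model"])

assert torch.equal(model.weight, restored.weight)
```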
+ 23 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1067,4 +1067,26 @@ chatGPT
 Llama
 PEFT
 LORA
-FSDP
+FSDP
+AuditNLG
+finetune
+fsdp
+ineference
+lora
+peft
+samsum
+vLLM
+TGI
+vLLM
+vLLM's
+OOM
+RTX
+SKU
+TPUs
+checkpointing
+enviroment
+fragmentations
+intra
+nightlies
+recenly
+uncomment

+ 1 - 0
utils/memory_utils.py

@@ -52,6 +52,7 @@ class MemoryTrace:
         cuda_info = torch.cuda.memory_stats()
         self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
         self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
         self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
         self.used = byte2gb(self.end - self.begin)
         self.peaked = byte2gb(self.peak - self.begin)

+ 106 - 33
utils/train_utils.py

@@ -4,6 +4,7 @@
 import os
 import sys
 from typing import List
+import yaml
 
 import fire
 import torch
@@ -67,7 +68,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler() 
-        
+    if train_config.enable_fsdp:
+        world_size = int(os.environ["WORLD_SIZE"]) 
     train_prep = []
     train_loss = []
     val_prep = []
@@ -80,7 +82,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
-            data_set_len = 0
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -90,8 +91,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
-                first_key = next(iter(batch))
-                data_set_len += len(batch[first_key])
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
@@ -122,12 +121,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
-        
-        print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-        print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-        print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-        print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-        print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        if train_config.enable_fsdp:
+            if rank==0:
+                print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
         
         # Update the learning rate as needed
         lr_scheduler.step()
@@ -135,35 +141,53 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
             if train_config.save_model and eval_epoch_loss < best_val_loss:
-                
-                if  train_config.use_peft:
-                    
-                    print(f"we are in the saving the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)   
-                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                    
+                if train_config.enable_fsdp:
+                    dist.barrier()
+                if train_config.use_peft:
+                    if train_config.enable_fsdp:
+                        if rank==0:
+                            print(f"we are about to save the PEFT modules")
+                    else:
+                        print(f"we are about to save the PEFT modules")
+                    model.save_pretrained(train_config.output_dir)  
+                    if train_config.enable_fsdp:
+                        if rank==0: 
+                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
+                    else:
+                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
+                        
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                         
                         model_checkpointing.save_model_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
+                            model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-                        print(" we are about to save the models *******")
+                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                        print("=====================================================")
                         
                         model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
+                            print("=====================================================")
 
                     if not train_config.use_peft and  train_config.save_optimizer:
                         model_checkpointing.save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
-                        )   
-                                
+                            model, optimizer, rank, train_config, epoch=epoch
+                        )
+                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
+                        print("=====================================================")                     
+                if train_config.enable_fsdp:
+                    dist.barrier()
             
-            if local_rank == 0 and eval_epoch_loss < best_val_loss:
+            if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
-                print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                if train_config.enable_fsdp:
+                    if rank==0:
+                        print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                else:
+                    print(f"best eval loss on epoch {epoch} is {best_val_loss}")
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
         
@@ -171,7 +195,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
         lr_scheduler.step()
     avg_epoch_time = sum(epoch_times)/len(epoch_times)
-    print("avg epoch time is {avg_epoch_time}")
+    print(f"avg epoch time is {avg_epoch_time}")
     print("==========================================")
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
@@ -185,7 +209,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
         
-
+    #saving the training params including fsdp setting for reference.
+    if train_config.enable_fsdp and not train_config.use_peft:
+        save_train_params(train_config, fsdp_config, rank)
+        
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
@@ -200,10 +227,11 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     
     Returns: eval_ppl, eval_epoch_loss
     """
+    if train_config.enable_fsdp:
+        world_size = int(os.environ["WORLD_SIZE"]) 
     model.eval()
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
-    eval_dataset_len = 0
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
             for key in batch.keys():
@@ -217,9 +245,6 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 outputs = model(**batch)
                 loss = outputs.loss
                 eval_loss += loss.detach().float()
-                first_key = next(iter(batch))
-                eval_dataset_len+= len(batch[first_key])
-                
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
             eval_preds.extend(
@@ -233,11 +258,17 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
-    eval_epoch_loss = eval_epoch_loss/world_size
+    if train_config.enable_fsdp:
+        eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
     
     # Print evaluation metrics
-    print(f" {eval_ppl=} {eval_epoch_loss=}")
+    if train_config.enable_fsdp:
+        if local_rank==0:
+            print(f" {eval_ppl=} {eval_epoch_loss=}")
+    else:
+        print(f" {eval_ppl=} {eval_epoch_loss=}")
+        
     return eval_ppl, eval_epoch_loss
 
 def freeze_transformer_layers(model, num_layer):
@@ -262,7 +293,10 @@ def setup_environ_flags(rank):
     """Set environment flags for debugging purposes"""
     os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
     os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
-    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # This flag will help with CUDA memory fragmentation that can lead to OOM in some cases.
+    # Note this is only available in PyTorch Nightlies (as of July 30 2023)
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' 
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")
 
@@ -336,3 +370,42 @@ def get_policies(cfg, rank):
             print(f"bFloat16 support not present. Using FP32, and not mixed precision")
     wrapping_policy = get_llama_wrapper()
     return mixed_precision_policy, wrapping_policy
+
+def save_train_params(train_config, fsdp_config, rank):
+    """
+    This function saves the train_config and FSDP config into a train_params.yaml.
+    This will be used by the converter script in the inference folder to fetch the HF model name or path.
+    It would also be helpful as a log for future reference.
+    """
+    # Convert the train_config and fsdp_config objects to dictionaries, 
+    # converting all values to strings to ensure they can be serialized into a YAML file
+    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
+    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
+    # Merge the two dictionaries into one
+    train_params_dict = {**train_config_dict, **fsdp_config_dict}
+    # Construct the folder name (following the FSDP checkpointing style) using properties of the train_config object
+    folder_name = (
+    train_config.dist_checkpoint_root_folder
+    + "/"
+    + train_config.dist_checkpoint_folder
+    + "-"
+    + train_config.model_name
+    )
+
+    save_dir = Path.cwd() / folder_name
+    # If the directory does not exist, create it
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    # Convert the dictionary to a YAML string
+    config_yaml = yaml.dump(train_params_dict, indent=4)
+    file_name = os.path.join(save_dir,'train_params.yaml')
+
+    # Check if there's a directory with the same name as the file
+    if os.path.isdir(file_name):
+        print(f"Error: {file_name} is a directory, not a file.")
+    else:
+        # Write the YAML string to the file
+        with open(file_name, 'w') as f:
+            f.write(config_yaml)
+        if rank==0:
+            print(f"training params are saved in {file_name}")
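One detail about `save_train_params` worth noting: every value is stringified before being dumped, so whatever reads `train_params.yaml` back (including the checkpoint converter) receives strings rather than the original Python types. A tiny illustration with made-up values:

```python
import yaml

# Mirrors the {k: str(v) ...} stringification in save_train_params
train_params_dict = {k: str(v) for k, v in {"enable_fsdp": True, "lr": 1e-4, "model_name": "7B"}.items()}
restored = yaml.safe_load(yaml.dump(train_params_dict, indent=4))

print(repr(restored["enable_fsdp"]))   # 'True' -- a string, not a bool
print(type(restored["lr"]))            # <class 'str'>
```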