
adding optimizer overlap

Hamid Shojanazeri, 1 year ago
commit e87e7edf61
3 changed files with 54 additions and 5 deletions
  1. configs/fsdp.py        +1  -0
  2. llama_finetuning.py    +36 -2
  3. utils/train_utils.py   +17 -3

+ 1 - 0
configs/fsdp.py

@@ -15,6 +15,7 @@ class fsdp_config:
     fsdp_activation_checkpointing: bool=True
     pure_bf16: bool = False
     optimizer: str= "AdamW"
+    optimizer_overlap: bool=False
     
     
     

+ 36 - 2
llama_finetuning.py

@@ -127,15 +127,32 @@ def main(**kwargs):
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-   
-        model = FSDP(
+        if fsdp_config.optimizer_overlap:
+            try:
+                from torch.distributed.optim import _apply_optimizer_in_backward
+            except ImportError:
+                # The overlap path relies on a private PyTorch API; warn clearly if this build does not ship it.
+                print("optimizer_overlap requires torch.distributed.optim._apply_optimizer_in_backward, which is not available in this PyTorch installation.")
+            model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
+            use_orig_params=True,
         )
+
+        else:
+            model = FSDP(
+                model,
+                auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
+                mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
+                sharding_strategy=fsdp_config.sharding_strategy,
+                device_id=torch.cuda.current_device(),
+                limit_all_gathers=True,
+            )
+        
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
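
Note on the wrapping above: the two FSDP calls differ only in use_orig_params=True, which the overlap path needs so that optimizers can later be attached to the original (unflattened) parameters. A minimal sketch of how the duplication could be collapsed into one call; the argument names come from this diff, while the single-call form itself is only an illustration, not part of the commit:

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = FSDP(
        model,
        auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
        mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
        sharding_strategy=fsdp_config.sharding_strategy,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        # use_orig_params is only required when the optimizer step is fused into
        # backward; the default flat-parameter path is fine otherwise.
        use_orig_params=fsdp_config.optimizer_overlap,
    )

Keeping a single call avoids the two branches drifting apart as more FSDP arguments are added.
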
@@ -214,6 +231,23 @@ def main(**kwargs):
             lr=train_config.lr,
             weight_decay=0.0,
         )
+    # if fsdp_config.optimizer_overlap:
+    #     optim_kwargs = {"lr": train_config.lr}
+    #     _apply_optimizer_in_backward(
+    #         optimizer_class=optim.AdamW,
+    #         params=model.parameters(),
+    #         optimizer_kwargs=optim_kwargs,
+    #         register_hook=False,
+    #     )
+    #     for p in model.parameters():
+    #         assert hasattr(p, "_in_backward_optimizers")
+    #     optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
+    #     optimizer = optim.AdamW(
+    #         model.parameters(),
+    #         **optim_kwargs
+    #     )
+        
+        
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
 
     # Start the training process
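
For reference, the commented-out block above follows the usage pattern of the private torch.distributed.optim._apply_optimizer_in_backward helper. A minimal sketch of what enabling the overlap would look like, assuming a PyTorch build that ships this helper (it is private and may change between releases); model and train_config come from the surrounding script:

    import torch.optim as optim
    from torch.distributed.optim import _apply_optimizer_in_backward

    # Attach one AdamW instance per parameter so each update runs during backward,
    # right after that parameter's gradient has been reduced.
    _apply_optimizer_in_backward(
        optimizer_class=optim.AdamW,
        params=model.parameters(),
        optimizer_kwargs={"lr": train_config.lr, "weight_decay": 0.0},
        register_hook=False,
    )

    # Sanity check: every parameter should now carry its in-backward optimizer(s).
    for p in model.parameters():
        assert hasattr(p, "_in_backward_optimizers")

register_hook=False matches the pattern used with FSDP wrapped with use_orig_params=True, where FSDP drives the fused step from its own post-backward logic. The separate optim.AdamW created in the commented block appears to exist only so the LR scheduler and the training loop still have a conventional optimizer object to call.
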

+ 17 - 3
utils/train_utils.py

@@ -10,6 +10,7 @@ import torch
 import transformers
 from datasets import load_dataset
 from tqdm import tqdm
+import time
 """
 Unused imports:
 import torch.nn as nn
@@ -73,7 +74,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     val_loss =[]
     results = {}
     best_val_loss = float("inf")
+    epoch_times=[]
     for epoch in range(train_config.num_epochs):
+        start_epoch = time.perf_counter()
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
@@ -104,10 +107,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.zero_grad()
                         
                 print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+        end_epoch = time.perf_counter()
+        epoch_time = end_epoch - start_epoch
+        print(f"epoch time is {epoch_time}")
+        print("==================================================")
+        epoch_times.append(epoch_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
-        train_epoch_loss = total_loss / data_set_len
+            world_size = int(os.environ["WORLD_SIZE"])
+            # Average the all-reduced sum over the number of ranks; world_size is
+            # only needed (and only defined) on the multi-GPU FSDP path.
+            total_loss = total_loss / world_size
+        train_epoch_loss = total_loss / len(train_dataloader)
         train_perplexity = torch.exp(train_epoch_loss)
         
         train_prep.append(train_perplexity)
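
The loss change in this hunk swaps the dataset-length denominator for the per-rank batch count and then averages across ranks. A small self-contained sketch of the same arithmetic, using dist.get_world_size() instead of the WORLD_SIZE environment variable (illustrative only, not part of the commit):

    import torch
    import torch.distributed as dist

    def global_mean_loss(total_loss: torch.Tensor, num_batches: int) -> torch.Tensor:
        # Per-rank mean over this rank's batches.
        epoch_loss = total_loss / num_batches
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            # SUM-reduce the per-rank means, then divide by the number of ranks
            # so every rank ends up holding the same global mean.
            dist.all_reduce(epoch_loss, op=dist.ReduceOp.SUM)
            epoch_loss = epoch_loss / dist.get_world_size()
        return epoch_loss

Dividing before or after the all_reduce gives the same result when every rank sees the same number of batches, which is the usual case with a DistributedSampler.
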
@@ -160,7 +170,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
         lr_scheduler.step()
-
+    avg_epoch_time = sum(epoch_times) / len(epoch_times)
+    print(f"avg epoch time is {avg_epoch_time}")
+    print("==========================================")
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
@@ -217,9 +229,11 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     # If there's more than one CUDA device, reduce evaluation loss across all devices
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
+        world_size = int(os.environ["WORLD_SIZE"])
+        # Average the all-reduced sum over the number of ranks while world_size is in scope.
+        eval_loss = eval_loss / world_size
     
     # Compute average loss and perplexity
-    eval_epoch_loss = eval_loss / eval_dataset_len
+    eval_epoch_loss = eval_loss / len(eval_dataloader)
     eval_ppl = torch.exp(eval_epoch_loss)
     
     # Print evaluation metrics