
update test env

Hamid Shojanazeri, 1 year ago
commit d824099012
3 changed files with 82 additions and 29 deletions
  1. configs/fsdp.py (+1, -0)
  2. llama_finetuning.py (+27, -27)
  3. utils/train_utils.py (+54, -2)

+ 1 - 0
configs/fsdp.py

@@ -16,6 +16,7 @@ class fsdp_config:
     pure_bf16: bool = False
     optimizer: str= "AdamW"
     optimizer_overlap: bool=False
+    profile_mem: bool=False
     
     
     
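
The new profile_mem flag is only declared in this file; a minimal sketch of how it could gate the CUDA allocator-history recording that train_utils.py turns on below might look like this (the maybe_record_memory_history helper is hypothetical, and _record_memory_history is a private PyTorch API whose signature varies across releases):

    from dataclasses import dataclass
    import torch

    @dataclass
    class fsdp_config:
        pure_bf16: bool = False
        optimizer: str = "AdamW"
        optimizer_overlap: bool = False
        profile_mem: bool = False  # flag added in this commit

    def maybe_record_memory_history(cfg: fsdp_config) -> None:
        # Hypothetical helper: only pay for allocator-history recording when asked to.
        if cfg.profile_mem and torch.cuda.is_available():
            torch.cuda.memory._record_memory_history(
                True,
                trace_alloc_max_entries=100000,   # keep up to 100k alloc/free events
                trace_alloc_record_context=True,  # record a stack trace per event
            )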

+ 27 - 27
llama_finetuning.py

@@ -216,37 +216,37 @@ def main(**kwargs):
             collate_fn=default_data_collator,
         )
         
-    # Initialize the optimizer and learning rate scheduler
-    # if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
-    #     optimizer = AnyPrecisionAdamW(
-    #         model.parameters(),
-    #         lr=train_config.lr,
-    #         momentum_dtype=torch.bfloat16,
-    #         variance_dtype=torch.bfloat16,
-    #         use_kahan_summation=False,
-    #     )
-    # else:
-    #     optimizer = optim.AdamW(
-    #         model.parameters(),
-    #         lr=train_config.lr,
-    #         weight_decay=0.0,
-    #     )
-    if fsdp_config.optimizer_overlap:
-        print("we are hereeeeeeeee**************************************")
-        optim_kwargs = {"lr": train_config.lr}
-        _apply_optimizer_in_backward(
-            optimizer_class=optim.AdamW,
-            params=model.parameters(),
-            optimizer_kwargs=optim_kwargs,
-            register_hook=False,
+    #Initialize the optimizer and learning rate scheduler
+    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
+        optimizer = AnyPrecisionAdamW(
+            model.parameters(),
+            lr=train_config.lr,
+            momentum_dtype=torch.bfloat16,
+            variance_dtype=torch.bfloat16,
+            use_kahan_summation=False,
         )
-        for p in model.parameters():
-            assert hasattr(p, "_in_backward_optimizers")
-        optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
+    else:
         optimizer = optim.AdamW(
             model.parameters(),
-            **optim_kwargs
+            lr=train_config.lr,
+            weight_decay=0.0,
         )
+    # if fsdp_config.optimizer_overlap:
+    #     print("we are hereeeeeeeee**************************************")
+    #     optim_kwargs = {"lr": train_config.lr}
+    #     _apply_optimizer_in_backward(
+    #         optimizer_class=optim.AdamW,
+    #         params=model.parameters(),
+    #         optimizer_kwargs=optim_kwargs,
+    #         register_hook=False,
+    #     )
+    #     for p in model.parameters():
+    #         assert hasattr(p, "_in_backward_optimizers")
+    #     optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
+    #     optimizer = optim.AdamW(
+    #         model.parameters(),
+    #         **optim_kwargs
+    #     )
         
         
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
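
The block commented out above used PyTorch's experimental optimizer-in-backward overlap, where every parameter gets its own AdamW instance that can be stepped during the backward pass; the plain AdamW constructed afterwards mostly exists so that StepLR has an optimizer object to drive. A minimal standalone sketch of that pattern, assuming the private torch.distributed.optim._apply_optimizer_in_backward API behaves as in the removed code (illustration only, not the supported path in this repo):

    import torch.nn as nn
    import torch.optim as optim
    from torch.optim.lr_scheduler import StepLR
    # Private, experimental API; its name and location may differ across PyTorch versions.
    from torch.distributed.optim import _apply_optimizer_in_backward

    model = nn.Linear(16, 16)

    # Attach a per-parameter AdamW so the optimizer step can be fused into backward.
    # With register_hook=False the stepping is left to the wrapping framework (FSDP),
    # which looks for the _in_backward_optimizers attribute set here.
    _apply_optimizer_in_backward(
        optimizer_class=optim.AdamW,
        params=model.parameters(),
        optimizer_kwargs={"lr": 1e-4},
        register_hook=False,
    )
    for p in model.parameters():
        assert hasattr(p, "_in_backward_optimizers")

    # A regular optimizer is still built so the LR scheduler has something to step.
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.0)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.85)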

+ 54 - 2
utils/train_utils.py

@@ -3,7 +3,7 @@
 
 import os
 import sys
-from typing import List
+from typing import List, Optional
 import yaml
 
 import fire
@@ -36,6 +36,9 @@ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from pathlib import Path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
+import torch.autograd.profiler as profiler
+from torch.cuda._memory_viz import profile_plot
+from pickle import dump
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
@@ -64,6 +67,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
+    torch.cuda.memory._record_memory_history(True,
+        # keep 100,000 alloc/free events from before the snapshot
+        trace_alloc_max_entries=100000,
+
+        # record stack information for the trace events
+        trace_alloc_record_context=True)
+    
     if train_config.use_fp16 and train_config.enable_fsdp:
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
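
The _record_memory_history call added above turns on allocator-history recording unconditionally at the start of train(). A self-contained sketch of the record / snapshot / dump cycle it enables, using the same private torch.cuda.memory calls as this commit (their argument forms change between PyTorch releases):

    import pickle
    import torch

    # Start recording CUDA allocator events (allocations, frees, and their call stacks).
    torch.cuda.memory._record_memory_history(
        True,
        trace_alloc_max_entries=100000,
        trace_alloc_record_context=True,
    )

    # ... run a few training steps; here just a stand-in workload ...
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x

    # Capture the allocator state plus the recorded trace and persist it for later viewing.
    snapshot = torch.cuda.memory._snapshot()
    with open("snapshot.pickle", "wb") as f:
        pickle.dump(snapshot, f)
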
@@ -82,7 +92,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
+            # if fsdp_config.profile_mem:
+            #     with torch.profiler.profile(
+            #     schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
+            #     activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA],
+            #     on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/llama2-7b'),
+            #     record_shapes=True,
+            #     profile_memory=True,
+            #     with_stack=True,
+            #     ) as prof:
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
+                if step >10:
+                    break
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
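
The commented-out block above sketches wrapping the inner loop in torch.profiler. If it is re-enabled, note that the wait/warmup/active schedule only advances if prof.step() is called once per iteration inside the loop. A minimal self-contained version with a toy model (assumes a CUDA device; the log directory mirrors the one in the comment):

    import torch
    import torch.profiler as profiler

    model = torch.nn.Linear(512, 512).cuda()
    data = [torch.randn(64, 512, device="cuda") for _ in range(8)]

    with profiler.profile(
        schedule=profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
        activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA],
        on_trace_ready=profiler.tensorboard_trace_handler("./log/llama2-7b"),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        for batch in data:
            loss = model(batch).sum()
            loss.backward()
            prof.step()  # advances the wait/warmup/active schedule each iteration
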
@@ -104,8 +125,14 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
+                if step == 4:
+                    if rank==0:
+                        snapshot = torch.cuda.memory._snapshot()
+                        with open('snapshot.pickle', 'wb') as f:
+                            dump(snapshot, f)
                         
                 print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+                    
         end_epoch = time.perf_counter()
         epoch_time = end_epoch- start_epoch
         print(f"epoch time is {epoch_time}")
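
The snapshot.pickle written at step 4 on rank 0 is the raw allocator snapshot. It can be dropped onto the interactive viewer at https://pytorch.org/memory_viz, or, depending on the installed PyTorch version, rendered to HTML with the private torch.cuda._memory_viz module (the trace_plot call below is an assumption about that module's API, not something this commit uses):

    import pickle
    from torch.cuda import _memory_viz  # private module; contents vary by release

    with open("snapshot.pickle", "rb") as f:
        snapshot = pickle.load(f)

    # Assumed API: trace_plot renders the recorded allocator trace as a standalone HTML page.
    html = _memory_viz.trace_plot(snapshot)
    with open("snapshot_trace.html", "w") as f:
        f.write(html)
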
@@ -235,6 +262,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
+            if step>6:
+                break
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
@@ -409,4 +438,27 @@ def save_train_params(train_config, fsdp_config, rank):
         with open(file_name, 'w') as f:
             f.write(config_yaml)
         if rank==0:
-            print(f"training params are saved in {file_name}")
+            print(f"training params are saved in {file_name}")
+
+
+def export_memory_timeline(path: str, device: Optional[str] = None) -> None:
+    try:
+        from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline
+    except ImportError:
+        # Handle the ImportError here, such as providing an alternative implementation or an error message.
+        print("The required module 'MemoryProfileTimeline' is not available.")
+    
+
+def _memory_profile():
+    try:
+        from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline
+    except ImportError:
+        # Handle the ImportError here, such as providing an alternative implementation or an error message.
+        print("The required module 'MemoryProfileTimeline' is not available.")
+    required = ("record_shapes", "profile_memory", "with_stack")
+    missing = [f"{i}=True" for i in required if not getattr(self, i)]
+    if missing:
+        raise ValueError(f"{', '.join(missing)} required for memory profiling.")
+
+    assert self.profiler is not None and self.profiler.kineto_results is not None
+    return MemoryProfile(self.profiler.kineto_results)
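
As written, export_memory_timeline and _memory_profile are module-level functions that still reference self, so they read as placeholders copied from methods of torch.profiler.profile rather than working helpers. In recent PyTorch releases the profiler object already exposes an equivalent export_memory_timeline method, so a hedged, version-dependent sketch of the intended workflow is simply:

    import torch
    import torch.profiler as profiler

    model = torch.nn.Linear(512, 512).cuda()
    batch = torch.randn(64, 512, device="cuda")

    # record_shapes, profile_memory, and with_stack are exactly the options that the
    # _memory_profile helper above checks for before building a MemoryProfile.
    with profiler.profile(
        activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        model(batch).sum().backward()

    # Available on torch.profiler.profile in newer releases; writes an HTML memory
    # timeline for the given device.
    prof.export_memory_timeline("memory_timeline.html", device="cuda:0")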