
adding torch cuda snapshot

Hamid Shojanazeri 1 year ago
parent commit 684487a097

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -34,6 +34,7 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformers memory-efficient kernels
+    profile: bool = False
 
     
     

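The new profile flag is only defined in this file; for reference, below is a minimal sketch of how it could gate the memory-history recording and snapshot dump that the train_utils.py hunk below performs unconditionally. The helper names are hypothetical, and the underscore-prefixed torch.cuda.memory calls are private APIs that may change between PyTorch releases.

import pickle

import torch


def maybe_start_memory_profiling(train_config):
    # Hypothetical helper: only pay the recording overhead when profiling is requested.
    if train_config.profile and torch.cuda.is_available():
        torch.cuda.memory._record_memory_history(
            True, trace_alloc_max_entries=100000, trace_alloc_record_context=True)


def maybe_dump_memory_snapshot(train_config, path="snapshot.pickle"):
    # Hypothetical helper: save the recorded allocator history for later visualization
    # with torch/cuda/_memory_viz.py (see the comment in train_utils.py below).
    if train_config.profile and torch.cuda.is_available():
        snapshot = torch.cuda.memory._snapshot()
        with open(path, "wb") as f:
            pickle.dump(snapshot, f)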
+ 19 - 2
src/llama_recipes/utils/train_utils.py

@@ -6,7 +6,7 @@ import time
 import yaml
 from pathlib import Path
 from pkg_resources import packaging
-
+from pickle import dump
 
 import torch
 import torch.cuda.nccl as nccl
@@ -48,6 +48,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     
     Returns: results dictionary containing average training and validation perplexity and loss
     """
+    # Record CUDA allocator memory history (incl. allocation stack traces) so a snapshot can be saved later
+    torch.cuda.memory._record_memory_history(
+        True, trace_alloc_max_entries=100000, trace_alloc_record_context=True)
+    
     # Create a gradient scaler for fp16
     if train_config.use_fp16 and train_config.enable_fsdp:
         scaler = ShardedGradScaler()
@@ -97,7 +100,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
-                
+        
+        # Dump a snapshot of the recorded CUDA memory history after each epoch.
+        snapshot = torch.cuda.memory._snapshot()
+        '''
+        After saving the snapshot, download the visualizer:
+        
+        wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/cuda/_memory_viz.py
+        
+        Then visualize it with:
+        
+        python _memory_viz.py trace snapshot.pickle
+        '''
+        with open("snapshot.pickle", "wb") as f:
+            dump(snapshot, f)
+        
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)    
         # Reducing total_loss across all devices if there's more than one CUDA device
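
For reference, a rough sketch of inspecting the saved snapshot.pickle programmatically, in addition to the _memory_viz.py trace command described in the comment above. The "segments", "total_size", and "allocated_size" keys are assumptions based on recent PyTorch allocator snapshots; this private snapshot format can change between releases.

import pickle

# Load the snapshot written by train() at the end of each epoch.
with open("snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

# Assumed layout: a dict with a "segments" list (older releases may return the list directly).
segments = snapshot["segments"] if isinstance(snapshot, dict) else snapshot
reserved = sum(seg.get("total_size", 0) for seg in segments)
allocated = sum(seg.get("allocated_size", 0) for seg in segments)
print(f"reserved:  {reserved / 2**20:.1f} MiB")
print(f"allocated: {allocated / 2**20:.1f} MiB")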