@@ -36,6 +36,7 @@ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from pathlib import Path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from accelerate.utils import is_xpu_available, is_ccl_available
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
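
Note: `is_xpu_available` and `is_ccl_available` come from Hugging Face Accelerate and report whether an Intel XPU (and the oneCCL backend) can be used. A minimal sketch of how the availability check might drive device selection, assuming `accelerate` is installed and the PyTorch build has `torch.xpu` support; the `get_device` helper below is illustrative, not part of this change:

```python
import torch
from accelerate.utils import is_xpu_available

def get_device(local_rank: int) -> torch.device:
    """Pick the accelerator for this rank (illustrative helper, not in the PR)."""
    if is_xpu_available():
        torch.xpu.set_device(local_rank)         # Intel GPU path
        return torch.device(f"xpu:{local_rank}")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)        # NVIDIA GPU path
        return torch.device(f"cuda:{local_rank}")
    return torch.device("cpu")                   # CPU fallback
```
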
@@ -113,7 +114,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
-        if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
+        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+        elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
         if train_config.enable_fsdp:
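
The new XPU branch mirrors the CUDA one and differs only in which `device_count()` is queried. A hedged sketch of how the two branches could be folded into a single helper; the helper name is illustrative, not part of this change:

```python
import torch
import torch.distributed as dist
from accelerate.utils import is_xpu_available

def reduce_loss_across_ranks(loss: torch.Tensor, enable_fsdp: bool) -> None:
    """Sum-reduce the running loss across ranks when FSDP spans several devices."""
    device_count = torch.xpu.device_count() if is_xpu_available() else torch.cuda.device_count()
    if enable_fsdp and device_count > 1:
        dist.all_reduce(loss, op=dist.ReduceOp.SUM)
```
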
@@ -125,17 +128,29 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         if train_config.enable_fsdp:
             if rank==0:
+                if is_xpu_available():
+                    print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                    print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                    print(f"Xpu Malloc retires : {memtrace.cuda_malloc_retires}")
+                else:
+                    print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                    print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                    print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            if is_xpu_available():
+                print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                print(f"Xpu Malloc retires : {memtrace.cuda_malloc_retires}")
+            else:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
                 print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
                 print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                 print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
                 print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        else:
-            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
 
         # Update the learning rate as needed
         lr_scheduler.step()
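
The XPU and CUDA blocks print the same `memtrace` fields. A sketch of a helper that would collapse the duplication, assuming only the attributes the patch already relies on (`peak`, `max_reserved`, `peak_active_gb`, `cuda_malloc_retires`, `cpu_peaked`, `cpu_begin`); the helper itself is illustrative, not part of this change:

```python
from accelerate.utils import is_xpu_available

def print_memory_stats(memtrace, include_cpu: bool = True) -> None:
    """Report peak accelerator (and optionally CPU) memory tracked by the memory-trace context."""
    device = "XPU" if is_xpu_available() else "CUDA"
    print(f"Max {device} memory allocated was {memtrace.peak} GB")
    print(f"Max {device} memory reserved was {memtrace.max_reserved} GB")
    print(f"Peak active {device} memory was {memtrace.peak_active_gb} GB")
    print(f"{device} Malloc retires : {memtrace.cuda_malloc_retires}")
    if include_cpu:
        print(f"CPU Total Peak Memory consumed during the train (max): "
              f"{memtrace.cpu_peaked + memtrace.cpu_begin} GB")
```
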
@@ -259,6 +274,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             )
 
     # If there's more than one CUDA device, reduce evaluation loss across all devices
+    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
 
@@ -292,7 +309,11 @@ def check_frozen_layers_peft_model(model):
 
 def setup():
     """Initialize the process group for distributed training"""
-    dist.init_process_group("nccl")
+    if is_ccl_available():
+        # distributed training on xpus
+        dist.init_process_group("ccl")
+    else:
+        dist.init_process_group("nccl")
 
 
 def setup_environ_flags(rank):
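
`dist.init_process_group("ccl")` only works once Intel's oneCCL bindings have registered the backend with `torch.distributed`. A hedged sketch of the same setup with that import made explicit, assuming the `oneccl_bindings_for_pytorch` package is installed on the XPU hosts:

```python
import torch.distributed as dist
from accelerate.utils import is_ccl_available

def setup():
    """Initialize the process group, preferring oneCCL on Intel XPUs."""
    if is_ccl_available():
        # Importing the bindings registers the "ccl" backend with torch.distributed.
        import oneccl_bindings_for_pytorch  # noqa: F401
        dist.init_process_group("ccl")
    else:
        dist.init_process_group("nccl")
```
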
@@ -316,7 +337,10 @@ def clear_gpu_cache(rank=None):
     """Clear the GPU cache for all ranks"""
     if rank == 0:
         print(f"Clearing GPU cache for all ranks")
-    torch.cuda.empty_cache()
+    if is_xpu_available():
+        torch.xpu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
 
 def get_parameter_dtypes(model):
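
Note that `torch.xpu.empty_cache()` is only present on PyTorch builds with Intel GPU support, so keeping the call behind the `is_xpu_available()` guard avoids an `AttributeError` on CUDA-only installs.
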
@@ -348,13 +372,14 @@ def print_model_size(model, config, rank: int = 0) -> None:
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
 
-    verify_bfloat_support = (
+    verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
     and packaging.version.parse(torch.version.cuda).release >= (11, 0)
     and dist.is_nccl_available()
     and nccl.version() >= (2, 10)
-    )
+    ) or
+    (is_xpu_available()))
 
 
     mixed_precision_policy = None
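
Applied to the file, the patched check reads as below. It is restated standalone for readability only, with imports shown under their standard package names rather than the file's own aliases; on XPU the expression simply assumes bf16 support instead of probing for it:

```python
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from packaging.version import parse as parse_version
from accelerate.utils import is_xpu_available

verify_bfloat_support = (
    (
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and parse_version(torch.version.cuda).release >= (11, 0)
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )
    or is_xpu_available()  # assume bf16 is usable on Intel XPUs
)
```
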