@@ -57,9 +57,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"])  
+        world_size = int(os.environ["WORLD_SIZE"])
+
 
-     
 
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
 
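For context, a minimal sketch of the fp16 machinery this hunk configures: `GradScaler` handles loss scaling while `autocast` runs the forward pass in reduced precision. The training objects below (`model`, `optimizer`, `batch`) are stand-ins, not the repo's own.

```python
import torch
from contextlib import nullcontext

use_fp16 = True
# GradScaler rescales the loss so small fp16 gradients don't underflow
scaler = torch.cuda.amp.GradScaler() if use_fp16 else None
autocast = torch.cuda.amp.autocast if use_fp16 else nullcontext

def training_step(model, optimizer, batch):
    with autocast():                   # forward pass in reduced precision
        loss = model(**batch).loss
    if scaler is not None:
        scaler.scale(loss).backward()  # backward on the scaled loss
        scaler.step(optimizer)         # unscales grads, then optimizer.step()
        scaler.update()                # adapt the scale factor for next step
    else:
        loss.backward()
        optimizer.step()
    optimizer.zero_grad()
    return loss.detach()
```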
@@ -74,12 +74,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         train_step_loss = []
         val_step_loss = []
         val_step_perplexity = []
-        
+
     epoch_times = []
     checkpoint_times = []
     results = {}
     best_val_loss = float("inf")
+    total_train_steps = 0
+    max_steps_reached = False  # Flag to indicate max training steps reached
+    # Start the training loop
     for epoch in range(train_config.num_epochs):
+        # stop when the maximum number of training steps is reached
+        if max_steps_reached:
+            break
         epoch_start_time = time.perf_counter()
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
@@ -87,6 +93,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             for step, batch in enumerate(train_dataloader):
+                total_train_steps += 1
+                # stop when the maximum number of training steps is reached
+                if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+                    max_steps_reached = True
+                    if not train_config.enable_fsdp or local_rank==0:
+                        print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
+                    break
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                         if is_xpu_available():
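The new `max_train_step` logic spans this hunk and the previous one: a plain `break` only exits the batch loop, so the `max_steps_reached` flag carries the stop signal up to the epoch loop. A runnable toy version of that two-level pattern, with a stand-in dataloader and a hypothetical cap:

```python
num_epochs = 3
train_dataloader = [f"batch{i}" for i in range(10)]  # stand-in for real batches
max_train_step = 12    # hypothetical cap; <= 0 means "no limit", as in the diff
total_train_steps = 0
max_steps_reached = False

for epoch in range(num_epochs):
    if max_steps_reached:
        break                      # epoch loop sees the flag and stops too
    for step, batch in enumerate(train_dataloader):
        total_train_steps += 1
        if max_train_step > 0 and total_train_steps > max_train_step:
            max_steps_reached = True
            print("max training steps reached at", total_train_steps - 1)
            break                  # exits only the batch loop
        # ... forward/backward/optimizer step would go here
```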
@@ -98,7 +111,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         if is_xpu_available():
                             batch[key] = batch[key].to('xpu:0')
                         else:
-                            batch[key] = batch[key].to('cuda:0')               
+                            batch[key] = batch[key].to('cuda:0')
                 with autocast():
                     loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
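Both branches of this hunk do the same job: move every tensor in the batch dict to one device. The same idea as a small helper; the function name is ours, not the repo's.

```python
import torch

def batch_to_device(batch: dict, device: torch.device) -> dict:
    # .to() is a no-op for tensors already on the target device
    return {key: value.to(device) for key, value in batch.items()}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# batch = batch_to_device(batch, device)
```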
@@ -133,7 +146,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.zero_grad()
                         pbar.update(1)
 
-                if wandb_run:  
+                if wandb_run:
                     if not train_config.enable_fsdp or rank==0:
                         wandb_run.log({
                             'train/epoch': epoch + 1,
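A sketch of the rank-guarded logging pattern this hunk touches: under FSDP every rank runs the loop, so metrics should be logged once, from rank 0 only. The helper and the `metrics` payload below are illustrative, not from the repo.

```python
def log_train_metrics(wandb_run, metrics, enable_fsdp, rank):
    if wandb_run:
        # without the guard, N ranks would each log the same step
        if not enable_fsdp or rank == 0:
            wandb_run.log(metrics)

# log_train_metrics(wandb_run, {'train/epoch': 1, 'train/loss': 1.23},
#                   enable_fsdp=True, rank=0)
```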
@@ -158,10 +171,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
-        
+
         train_prep.append(float(train_perplexity))
         train_loss.append(float(train_epoch_loss))
-        
+
         if not train_config.enable_fsdp or rank==0:
             memtrace.print_stats()
 
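Why the division and the exp: under FSDP, `train_epoch_loss` holds the sum of per-rank losses (presumably after an all-reduce elsewhere in the function), so dividing by `world_size` recovers the global mean, and perplexity is just the exponential of that mean cross-entropy. Toy numbers:

```python
import torch

world_size = 4
# stand-in for the value after an all-reduce SUM of a mean loss of 2.0 on 4 ranks
summed_loss = torch.tensor(4 * 2.0)
train_epoch_loss = summed_loss / world_size      # global mean loss: 2.0
train_perplexity = torch.exp(train_epoch_loss)   # exp(2.0) ~ 7.389
print(float(train_perplexity))
```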
@@ -231,7 +244,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
-        
+
         # Saving the results every epoch to plot later
         if train_config.save_metrics:
             save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
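`save_to_json` is defined elsewhere in this file; from the call site alone, one plausible shape is sketched below. The JSON field names are our guess, not the repo's.

```python
import json

def save_to_json(output_filename,
                 train_step_loss, train_epoch_loss,
                 train_step_ppl, train_epoch_ppl,
                 val_step_loss, val_epoch_loss,
                 val_step_ppl, val_epoch_ppl):
    # collect every per-step and per-epoch series into one plottable record
    metrics_data = {
        "train_step_loss": train_step_loss,
        "train_epoch_loss": train_epoch_loss,
        "train_step_perplexity": train_step_ppl,
        "train_epoch_perplexity": train_epoch_ppl,
        "val_step_loss": val_step_loss,
        "val_epoch_loss": val_epoch_loss,
        "val_step_perplexity": val_step_ppl,
        "val_epoch_perplexity": val_epoch_ppl,
    }
    with open(output_filename, "w") as f:
        json.dump(metrics_data, f)
```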
@@ -255,7 +268,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results["metrics_filename"] = metrics_filename
 
     #saving the training params including fsdp setting for reference.
-    if train_config.enable_fsdp and not train_config.use_peft:
+    if train_config.enable_fsdp and not train_config.use_peft and rank==0:
         save_train_params(train_config, fsdp_config, rank)
 
     return results
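This fix makes only rank 0 write the training-params file; without the guard, every FSDP rank would call `save_train_params` and race on the same path. A generic sketch of the rank-0-only write pattern, with an optional barrier (the helper is ours):

```python
import torch.distributed as dist

def save_params_once(save_fn, rank):
    if rank == 0:
        save_fn()          # only one process touches the filesystem
    if dist.is_initialized():
        dist.barrier()     # other ranks wait until the file exists
```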
@@ -279,8 +292,15 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     val_step_loss = []
     val_step_perplexity = []
     eval_loss = 0.0  # Initialize evaluation loss
+    total_eval_steps = 0
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+            total_eval_steps += 1
+            # stop when the maximum number of eval steps is reached
+            if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
+                if not train_config.enable_fsdp or local_rank==0:
+                    print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1)
+                break
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
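The evaluation side mirrors the training cap, but needs no flag since there is only one loop. Note the counter increments before the check, which is why the log prints `total_eval_steps - 1` (the number of batches actually evaluated). A runnable toy version with a stand-in dataloader:

```python
max_eval_step = 5      # hypothetical cap; <= 0 would disable it
total_eval_steps = 0
eval_dataloader = [f"batch{i}" for i in range(20)]   # stand-in

for step, batch in enumerate(eval_dataloader):
    total_eval_steps += 1
    if max_eval_step > 0 and total_eval_steps > max_eval_step:
        print("max eval steps reached, evaluated", total_eval_steps - 1, "batches")
        break
    # ... forward pass under torch.no_grad() would go here
```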
@@ -288,7 +308,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
                     if is_xpu_available():
                         batch[key] = batch[key].to('xpu:0')
                     else:
-                        batch[key] = batch[key].to('cuda:0')   
+                        batch[key] = batch[key].to('cuda:0')
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
@@ -296,7 +316,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
                 loss = outputs.loss
                 if train_config.save_metrics:
                     val_step_loss.append(loss.detach().float().item())
-                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))  
+                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))
 
                 eval_loss += loss.detach().float()
             # Decode predictions and add to evaluation predictions list
@@ -324,12 +344,12 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
 
-    if wandb_run: 
+    if wandb_run:
         wandb_run.log({
                         'eval/perplexity': eval_ppl,
                         'eval/loss': eval_epoch_loss,
                     }, commit=False)
-        
+
     return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
 def freeze_transformer_layers(model, num_layer):
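The `commit=False` in this hunk uses wandb's buffering behavior: the eval metrics are attached to the current step but only flushed when a later `log()` call commits, so they line up with the next batch of training metrics. A minimal demonstration with illustrative values (offline mode so it runs without an account):

```python
import wandb

wandb_run = wandb.init(project="demo", mode="offline")
# buffered: recorded against the current step, not yet written
wandb_run.log({"eval/perplexity": 7.39, "eval/loss": 2.0}, commit=False)
# commit defaults to True: flushes the buffered eval metrics at the same step
wandb_run.log({"train/loss": 1.8})
```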
@@ -410,7 +430,7 @@ def print_model_size(model, config, rank: int = 0) -> None:
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
 
-    
+
     verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
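The `verify_bfloat_support` condition continues past this hunk. For reference, a standalone probe in the same spirit, checking the pieces visible here plus NCCL availability, which bf16 collectives under FSDP also depend on; the exact checks in the repo may differ.

```python
import torch
import torch.distributed as dist

def bf16_ready() -> bool:
    return bool(
        torch.version.cuda                  # CUDA build of PyTorch
        and torch.cuda.is_bf16_supported()  # Ampere-or-newer GPU
        and dist.is_nccl_available()        # NCCL backend present
    )

print(bf16_ready())
```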