@@ -83,7 +83,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
-            for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
+            total_length = len(train_dataloader)//gradient_accumulation_steps
+            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
+            for step, batch in enumerate(train_dataloader):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
@@ -99,17 +101,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
+                        pbar.update(step//gradient_accumulation_steps)
                 else:
                     # regular backpropagation when fp16 is not used
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
-                if train_config.enable_fsdp:
-                    if rank==0:
-                        print(f"\n step {step} is completed and loss is {loss.detach().float()}")
-                else:
-                    print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+                        pbar.update(step//gradient_accumulation_steps)
+
+                pbar.set_description(f"Training Epoch: {epoch}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
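For reference, the pattern the hunks above introduce can be written as a small self-contained sketch: a tqdm bar whose total counts optimizer steps rather than micro-batches when gradient accumulation is in play. Everything below (the train_one_epoch helper, the toy linear model, the loss scaling) is illustrative and not taken from the repo's train_utils.py; note also that stock tqdm advances the bar by the amount passed to update(), so the sketch ticks the bar with pbar.update(1) once per completed optimizer step.

# Illustrative sketch only -- not the repo's train_utils.py.
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm


def train_one_epoch(model, train_dataloader, optimizer, gradient_accumulation_steps, epoch, num_epochs):
    model.train()
    total_loss = 0.0
    # One bar tick per optimizer step, not per micro-batch.
    total_length = len(train_dataloader) // gradient_accumulation_steps
    pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
    for step, (x, y) in enumerate(train_dataloader):
        loss = torch.nn.functional.mse_loss(model(x), y)
        total_loss += loss.detach().float()
        # Scale so the accumulated gradient averages over the micro-batches.
        (loss / gradient_accumulation_steps).backward()
        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
            optimizer.step()
            optimizer.zero_grad()
            pbar.update(1)  # advance by one completed optimizer step
        pbar.set_description(
            f"Training Epoch: {epoch}/{num_epochs}, "
            f"step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})"
        )
    pbar.close()
    return total_loss / len(train_dataloader)


# Toy usage: 64 samples, batch size 8, gradients accumulated over 4 micro-batches.
model = torch.nn.Linear(4, 1)
loader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
train_one_epoch(model, loader, optimizer, gradient_accumulation_steps=4, epoch=1, num_epochs=1)

Creating the bar outside the loop with total=len(train_dataloader)//gradient_accumulation_steps, instead of wrapping the dataloader directly in tqdm, keeps the bar's total aligned with the number of times the optimizer actually steps.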
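The final context line refers to code outside this hunk. As a rough sketch only (the names total_loss and num_batches, and the exact reduction used in train_utils.py, are assumptions), a per-rank running loss is typically averaged across data-parallel workers with torch.distributed.all_reduce, assuming the process group has already been initialized (e.g. by torchrun):

# Sketch only: averaging a per-rank running loss across data-parallel workers.
import torch
import torch.distributed as dist


def reduce_loss_across_ranks(total_loss: torch.Tensor, num_batches: int) -> torch.Tensor:
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # Sum the per-rank losses, then divide by world size to get the mean.
        dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
        total_loss = total_loss / dist.get_world_size()
    # Average over the number of batches seen on each rank this epoch.
    return total_loss / num_batches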