| 
					
				 | 
			
			
				@@ -65,9 +65,6 @@ def main(**kwargs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         clear_gpu_cache(local_rank) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         setup_environ_flags(rank) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # Calculate gradient accumulation steps 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # Load the pre-trained model and setup its configuration 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if train_config.enable_fsdp and train_config.low_cpu_fsdp: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -240,7 +237,7 @@ def main(**kwargs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         tokenizer, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         optimizer, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         scheduler, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        gradient_accumulation_steps, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        train_config.gradient_accumulation_steps, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         train_config, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         fsdp_config if train_config.enable_fsdp else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         local_rank if train_config.enable_fsdp else None, 
			 |