@@ -93,10 +93,16 @@ def main(**kwargs):
     gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
 
     # Load the pre-trained model and setup its configuration
-    if train_config.enable_fsdp:
-        # for FSDP, we save cpu memory by loading pretrained model on rank0 only.
+    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
+        # for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
         # this avoids cpu oom when loading large models like llama 70B, in which case
-        # model alone would consume 2+TB cpu mem (70 * 4 * 8)
+        # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
+        # overhead and currently requires latest nightly.
+        v = packaging.version.parse(torch.__version__)
+        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        if not verify_latest_nightly:
+            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+                            "please install latest nightly.")
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
                 train_config.model_name,
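
The hunk is truncated at the rank-0 from_pretrained call, so it only shows rank 0 loading the full checkpoint into CPU memory. A minimal sketch of how the remaining ranks can avoid materializing the weights under low_cpu_fsdp is given below; this is illustrative only and not part of the diff, and it assumes the model is later wrapped with FSDP using sync_module_states=True so rank 0's parameters are broadcast to the other ranks.

    # Sketch, not part of the diff: non-zero ranks construct the model on the
    # meta device so no real weight storage is allocated on CPU; a later
    # FSDP(..., sync_module_states=True) wrap is assumed to broadcast rank 0's
    # weights to the other ranks.
    import torch
    from transformers import LlamaConfig, LlamaForCausalLM

    if rank == 0:
        model = LlamaForCausalLM.from_pretrained(train_config.model_name)
    else:
        llama_config = LlamaConfig.from_pretrained(train_config.model_name)
        with torch.device("meta"):
            model = LlamaForCausalLM(llama_config)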