| 
					
				 | 
			
			
				@@ -66,6 +66,7 @@ def main(**kwargs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         setup_environ_flags(rank) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # Load the pre-trained model and setup its configuration 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    use_cache = False if train_config.enable_fsdp else None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if train_config.enable_fsdp and train_config.low_cpu_fsdp: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for FSDP, we can save cpu memory by loading pretrained model on rank0 only. 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -83,9 +84,11 @@ def main(**kwargs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 train_config.model_name, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 load_in_8bit=True if train_config.quantization else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 device_map="auto" if train_config.quantization else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                use_cache=use_cache, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             llama_config = LlamaConfig.from_pretrained(train_config.model_name) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            llama_config.use_cache = use_cache 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             with torch.device("meta"): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 model = LlamaForCausalLM(llama_config) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -94,6 +97,7 @@ def main(**kwargs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             train_config.model_name, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             load_in_8bit=True if train_config.quantization else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             device_map="auto" if train_config.quantization else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            use_cache=use_cache, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if train_config.enable_fsdp and train_config.use_fast_kernels: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 |