@@ -137,6 +137,7 @@ def main(**kwargs):
         processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
         processor.tokenizer.padding_side='right'
     else:
+        is_vision = False
         model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
@@ -188,23 +189,20 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [CLIPEncoderLayer])
+        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
+        # if is_vision:
+        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
+        # else:
+        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         print("FSDP is enabled",my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
-        if train_config.use_peft:
-            wrapping_policy = my_auto_wrapping_policy
-        else:
-            if is_vision:
-                wrapping_policy = ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer])
-            else:
-                wrapping_policy = ModuleWrapPolicy([LlamaDecoderLayer])
         model = FSDP(
             model,
-            auto_wrap_policy= wrapping_policy,
+            auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
             cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
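
Note on the wrapping-policy hunk above: after this change my_auto_wrapping_policy always covers both LlamaDecoderLayer and CLIPEncoderLayer, the per-case is_vision selection is left commented out, and the PEFT vs. full-fine-tune choice moves into the FSDP(...) call itself as a ternary. The sketch below is illustrative only, not code from this PR: pick_wrap_policy is a hypothetical helper, the internals of the repo's fsdp_auto_wrap_policy and get_policies are not shown in this diff, and the sketch assumes torch's stock wrap policies plus the usual transformers import paths for the two layer classes.

import functools

from torch.distributed.fsdp.wrap import ModuleWrapPolicy, transformer_auto_wrap_policy
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def pick_wrap_policy(use_peft: bool, is_vision: bool):
    # Hypothetical helper mirroring the intent of the commented-out is_vision
    # branch plus the use_peft ternary in the hunk above.
    layer_classes = {LlamaDecoderLayer, CLIPEncoderLayer} if is_vision else {LlamaDecoderLayer}
    if use_peft:
        # Callable policy that wraps each transformer block, standing in for
        # the my_auto_wrapping_policy built via fsdp_auto_wrap_policy.
        return functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=layer_classes)
    # Stand-in for the wrapping_policy returned by get_policies (defined
    # outside this diff): wrap the same block classes as whole modules.
    return ModuleWrapPolicy(layer_classes)

Under those assumptions, the result would simply be passed as auto_wrap_policy when constructing FSDP, in place of the ternary shown in the hunk.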