@@ -116,7 +116,7 @@ def main(**kwargs):
         if not train_config.enable_fsdp or rank == 0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
-    # setting quantization configs
+     # setting quantization configs
     bnb_config = None
     if train_config.quantization:
         if type(train_config.quantization) == type(True):
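Note on the hunk above: `train_config.quantization` was originally a plain
boolean, so the `type(...) == type(True)` guard (more idiomatically,
`isinstance(train_config.quantization, bool)`) keeps legacy configs working
before the value is treated as a "4bit"/"8bit" string. A minimal sketch of
the normalization this implies; the "8bit" fallback is an assumption, since
this diff does not show the branch body:

    # Hypothetical normalization of the legacy boolean flag (names assumed).
    def normalize_quantization(value):
        if isinstance(value, bool):           # legacy: quantization=True/False
            return "8bit" if value else None  # assumed legacy meaning of True
        return value                          # already "4bit", "8bit", or None
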
@@ -135,6 +135,17 @@ def main(**kwargs):
         update_config(quant_config, **kwargs)
         bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
+
+        if train_config.enable_fsdp:
+            if train_config.quantization == "4bit":
+                bnb_config.bnb_4bit_quant_storage = bnb_config.bnb_4bit_compute_dtype
+                from logging import getLogger
+                logger = getLogger()
+                logger.warning(
+                    "FSDP and 4-bit QLoRA enabled. Setting `bnb_4bit_quant_storage` "
+                    f"to {bnb_config.bnb_4bit_compute_dtype} for compatibility."
+                )
+
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
     config = AutoConfig.from_pretrained(train_config.model_name)
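Note on the hunk above: FSDP flattens each wrapped module's parameters into
flat buffers of a single dtype, while bitsandbytes keeps 4-bit weights in a
separate storage dtype. Forcing `bnb_4bit_quant_storage` to match
`bnb_4bit_compute_dtype` lets FSDP shard the quantized Linear4bit weights
like ordinary parameters. A minimal sketch of the resulting config, with
assumed values (the quant type and dtypes are not shown in this diff):

    # Sketch of an FSDP-compatible QLoRA config; nf4/bfloat16 are assumptions.
    import torch
    from transformers import BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_storage=torch.bfloat16,  # match FSDP flat-param dtype
    )
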
@@ -264,10 +275,9 @@ def main(**kwargs):
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
 
-        if train_config.freeze_LLM_only:
-            use_orig_params = True
-        else:
-            use_orig_params = False
+
+        use_orig_params = train_config.freeze_LLM_only or train_config.use_peft
+
         model = FSDP(
             model,
             auto_wrap_policy=(
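Note on the hunk above: the new one-liner also covers PEFT. With
`use_orig_params=False`, FSDP requires every parameter inside a flat buffer
to share the same `requires_grad`, which breaks as soon as adapters train on
top of a frozen base (PEFT) or the LLM trunk is frozen (`freeze_LLM_only`).
A self-contained toy showing the mixed-flag situation (module names are
illustrative only):

    # Freezing part of a model yields mixed requires_grad flags, which FSDP
    # only tolerates with use_orig_params=True.
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
    for p in model[0].parameters():  # freeze the "base", train the rest
        p.requires_grad = False

    print({p.requires_grad for p in model.parameters()})  # {True, False}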
			 |