
feat: Enable QLoRA and FSDP compatibility (#974)

Igor Kasianenko 1 month ago
commit 86c37e829a
1 changed file with 15 additions and 5 deletions
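The diff below makes 4-bit QLoRA usable under FSDP: when both are enabled, the storage dtype of the packed 4-bit weights (bnb_4bit_quant_storage) is set to the compute dtype, so FSDP can flatten and shard the quantized base weights together with the LoRA parameters. As a minimal standalone sketch of the resulting quantization config, assuming a bfloat16 compute dtype and NF4 quantization (both illustrative choices, not taken from the commit):

import torch
from transformers import BitsAndBytesConfig

# Sketch only: dtype and quant type are illustrative assumptions,
# not the cookbook's defaults.
compute_dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    # The key setting for FSDP + QLoRA: store the packed 4-bit weights in
    # the same dtype as the compute dtype so every parameter FSDP flattens
    # shares one dtype.
    bnb_4bit_quant_storage=compute_dtype,
)

In finetuning.py itself the commit reaches the same state by copying bnb_4bit_compute_dtype into bnb_4bit_quant_storage after create_bnb_config(), as the first hunk shows.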

+ 15 - 5
src/llama_cookbook/finetuning.py

@@ -116,7 +116,7 @@ def main(**kwargs):
         if not train_config.enable_fsdp or rank == 0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
-    # setting quantization configs
+     # setting quantization configs
     bnb_config = None
     if train_config.quantization:
         if type(train_config.quantization) == type(True):
@@ -135,6 +135,17 @@ def main(**kwargs):
         update_config(quant_config, **kwargs)
         bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
+       
+        if train_config.enable_fsdp:
+            if train_config.quantization == "4bit":
+                bnb_config.bnb_4bit_quant_storage = bnb_config.bnb_4bit_compute_dtype
+                from logging import getLogger
+                logger = getLogger()
+                logger.warning(
+                    "FSDP and 4-bit QLoRA enabled. Setting `bnb_4bit_quant_storage` "
+                    f"to {bnb_config.bnb_4bit_compute_dtype} for compatibility."
+                )
+
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
     config = AutoConfig.from_pretrained(train_config.model_name)
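The second hunk (below) changes how use_orig_params is derived for the FSDP wrapper. PyTorch FSDP only allows a mix of frozen and trainable parameters inside one flattened parameter group when use_orig_params=True, so the flag must be set not only for freeze_LLM_only but also whenever PEFT/LoRA leaves most of the base model frozen. A rough sketch of how the flag feeds into the wrapping call, with model, device_id, and train_config assumed from the surrounding main() and most other FSDP arguments omitted:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Simplified sketch; the real call in finetuning.py also passes the
# auto-wrap policy, sharding strategy, mixed-precision policy, and more.
use_orig_params = train_config.freeze_LLM_only or train_config.use_peft
model = FSDP(
    model,
    device_id=device_id,
    # Required when some wrapped parameters are frozen (LoRA/PEFT or a
    # frozen base LLM): without it FSDP expects uniform requires_grad
    # within each FlatParameter.
    use_orig_params=use_orig_params,
)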
@@ -264,10 +275,9 @@ def main(**kwargs):
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
 
-        if train_config.freeze_LLM_only:
-            use_orig_params = True
-        else:
-            use_orig_params = False
+       
+        use_orig_params = train_config.freeze_LLM_only or train_config.use_peft
+        
         model = FSDP(
             model,
             auto_wrap_policy=(