Browse source

adding fsdp-qlora support (work in progress)

Hamid Shojanazeri 11 months ago
parent
commit
f3fd43dc62

+ 1 - 1
src/llama_recipes/configs/__init__.py

@@ -1,7 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix_config
+from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix_config, qlora_config
 from llama_recipes.configs.fsdp import fsdp_config
 from llama_recipes.configs.training import train_config
 from llama_recipes.configs.wandb import wandb_config
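
With this export in place, the new config can be imported alongside the existing PEFT configs. A minimal sketch (illustrative only, not part of the commit):

from llama_recipes.configs import lora_config, qlora_config

cfg = qlora_config()        # LoRA hyperparameters plus the new 4-bit (bnb) settings
print(cfg.r, cfg.lora_alpha, cfg.target_modules)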

+ 16 - 1
src/llama_recipes/configs/peft.py

@@ -23,4 +23,19 @@ class llama_adapter_config:
 @dataclass
 class prefix_config:
      num_virtual_tokens: int=30
-     task_type: str= "CAUSAL_LM"    
+     task_type: str= "CAUSAL_LM"    
+
+
+@dataclass
+class qlora_config:
+     r: int=8
+     lora_alpha: int=32
+     target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
+     bias: str = "none"
+     task_type: str= "CAUSAL_LM"
+     lora_dropout: float=0.05
+     inference_mode: bool = False
+     bnb_4bit_quant_type: str = "nf4"
+     bnb_4bit_compute_dtype: str = "bfloat16"
+     bnb_4bit_quant_storage: str = "bfloat16"
+     use_nested_quant: bool = False
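
Worth noting (not part of the diff): qlora_config relies on `field` and `List` already being imported at the top of peft.py, as they are for lora_config, and the bnb_4bit_* strings only work downstream if they name a real bitsandbytes quant type ("nf4"/"fp4") and real torch dtypes. A hypothetical sanity check, purely illustrative:

import torch
from llama_recipes.configs import qlora_config

cfg = qlora_config()
assert cfg.bnb_4bit_quant_type in ("nf4", "fp4")     # valid bitsandbytes 4-bit quant types
assert hasattr(torch, cfg.bnb_4bit_compute_dtype)    # resolves to torch.bfloat16
assert hasattr(torch, cfg.bnb_4bit_quant_storage)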

+ 8 - 1
src/llama_recipes/configs/training.py

@@ -26,12 +26,14 @@ class train_config:
     mixed_precision: bool=True
     val_batch_size: int=1
     dataset = "samsum_dataset"
-    peft_method: str = "lora" # None , llama_adapter, prefix
+    peft_method: str = "lora" # None , qlora, llama_adapter, prefix
     use_peft: bool=False
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
     quantization: bool = False
+    use_4bit_quantization: bool = False
+    use_8bit_quantization: bool = False
     one_gpu: bool = False
     save_model: bool = True
     dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
@@ -40,3 +42,8 @@ class train_config:
    use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
    use_wandb: bool = False # Enable wandb for experiment tracking
     save_metrics: bool = False # saves training metrics to a json file for later plotting
+    inference_mode: bool = False
+    bnb_4bit_quant_type: str = "nf4"
+    bnb_4bit_compute_dtype: str = "bfloat16"
+    bnb_4bit_quant_storage: str = "bfloat16"
+    use_nested_quant: bool = True
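
As a rough usage sketch (not in this commit), the new flags can be set through the recipe's kwargs-based override, the same mechanism the CLI uses; the values shown are illustrative:

from llama_recipes.configs import train_config
from llama_recipes.utils.config_utils import update_config

cfg = train_config()
update_config(
    cfg,
    use_peft=True,
    peft_method="qlora",
    quantization=True,
    use_4bit_quantization=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_quant_storage="bfloat16",
)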

+ 11 - 7
src/llama_recipes/finetuning.py

@@ -47,6 +47,7 @@ from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     print_model_size,
     get_policies,
+    set_quantization_settings
 )
 from accelerate.utils import is_xpu_available
 
@@ -101,6 +102,9 @@ def main(**kwargs):
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
+    # Build a BitsAndBytesConfig only when quantization is requested; otherwise pass None.
+    bnb_config = set_quantization_settings(train_config) if train_config.quantization else None
+
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
         """
         for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
@@ -108,15 +112,15 @@ def main(**kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
+        # v = packaging.version.parse(torch.__version__)
+        # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        # if not verify_latest_nightly:
+        #     raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+        #                     "please install latest nightly.")
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
                 train_config.model_name,
-                load_in_8bit=True if train_config.quantization else None,
+                quantization_config=bnb_config,
                 device_map="auto" if train_config.quantization else None,
                 use_cache=use_cache,
                 attn_implementation="sdpa" if train_config.use_fast_kernels else None,
@@ -130,7 +134,7 @@ def main(**kwargs):
     else:
         model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
-            load_in_8bit=True if train_config.quantization else None,
+            quantization_config=bnb_config,
             device_map="auto" if train_config.quantization else None,
             use_cache=use_cache,
             attn_implementation="sdpa" if train_config.use_fast_kernels else None,
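
For context on the FSDP-QLoRA combination this commit works toward, here is a sketch under assumptions (not the recipe's final code): the 4-bit storage dtype should match the dtype of the non-quantized parameters so FSDP can flatten them uniformly, and the LoRA adapters sit on top of the quantized base. The model id and hyperparameters below are illustrative:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,   # keep equal to torch_dtype for FSDP flattening
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",              # illustrative model id
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
model = get_peft_model(
    model,
    LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05, bias="none",
               target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"),
)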

+ 2 - 2
src/llama_recipes/utils/config_utils.py

@@ -14,7 +14,7 @@ from peft import (
 from transformers import default_data_collator
 from transformers.data import DataCollatorForSeq2Seq
 
-from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
+from llama_recipes.configs import datasets, lora_config, qlora_config, llama_adapter_config, prefix_config, train_config
 from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
 from llama_recipes.utils.dataset_utils import DATASET_PREPROC
 
@@ -41,7 +41,7 @@ def update_config(config, **kwargs):
 
 
 def generate_peft_config(train_config, kwargs):
-    configs = (lora_config, llama_adapter_config, prefix_config)
+    configs = (lora_config, qlora_config, llama_adapter_config, prefix_config)
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
 

+ 32 - 0
src/llama_recipes/utils/train_utils.py

@@ -493,3 +493,35 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
     }
     with open(output_filename, "w") as f:
         json.dump(metrics_data, f)
+
+
+def set_quantization_settings(train_config):
+    """
+    Builds and returns a BitsAndBytesConfig based on the training configuration.
+
+    Parameters:
+    - train_config: training config object carrying the "quantization" flag together with
+      "use_4bit_quantization" / "use_8bit_quantization" and the bnb_4bit_* settings.
+
+    Returns:
+    - A BitsAndBytesConfig object configured with the specified settings.
+    """
+    from transformers import BitsAndBytesConfig
+    if train_config.use_4bit_quantization:
+        compute_dtype = getattr(torch, train_config.bnb_4bit_compute_dtype)
+        quant_storage_dtype = getattr(torch, train_config.bnb_4bit_quant_storage)
+        # Initialize BitsAndBytesConfig with 4-bit (QLoRA-style) quantization settings.
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=train_config.use_4bit_quantization,
+            bnb_4bit_quant_type=train_config.bnb_4bit_quant_type,
+            bnb_4bit_compute_dtype=compute_dtype,
+            bnb_4bit_use_double_quant=train_config.use_nested_quant,
+            bnb_4bit_quant_storage=quant_storage_dtype,
+        )
+    # Initialize BitsAndBytesConfig with the 8-bit quantization flag.
+    elif train_config.use_8bit_quantization:
+        bnb_config = BitsAndBytesConfig(load_in_8bit=train_config.use_8bit_quantization)
+    else:
+        # Fall back to 8-bit loading so quantization=True keeps its previous behaviour.
+        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+    return bnb_config
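
Finally, a quick usage sketch for the new helper (illustrative, outside the recipe's main flow):

from llama_recipes.configs import train_config
from llama_recipes.utils.train_utils import set_quantization_settings

cfg = train_config()
cfg.quantization = True
cfg.use_4bit_quantization = True
bnb_config = set_quantization_settings(cfg)   # -> transformers.BitsAndBytesConfig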