
correct 8bit quantization and make quant parameter configurable from command line

Matthias Reso · 9 months ago
commit 818fe84032
2 changed files with 29 additions and 23 deletions
  1. src/llama_recipes/configs/quantization.py (+14, -14)
  2. src/llama_recipes/finetuning.py (+15, -9)

+ 14 - 14
src/llama_recipes/configs/quantization.py

@@ -8,23 +8,23 @@ from transformers import BitsAndBytesConfig
 
 @dataclass
 class quantization_config:
-    quant_type: str  # "int4" or "int8"
-    compute_dtype: torch.dtype
+    quant_type: str =  "fp4" # "fp4" or "nf4"
+    compute_dtype: torch.dtype = torch.bfloat16
     use_double_quant: bool = False
     quant_storage: torch.dtype = torch.bfloat16
 
-    def create_bnb_config(self) -> BitsAndBytesConfig:
-        if self.quant_type not in {"int4", "int8"}:
-            raise ValueError("quant_type must be either 'int4' or 'int8'")
+    def create_bnb_config(self, quantization: str) -> BitsAndBytesConfig:
+        if quantization not in {"4bit", "8bit"}:
+            raise ValueError("quantization must be either '4bit' or '8bit'")
 
-        config_params = {
-            "bnb_4bit_quant_type" if self.quant_type == "int4" else "bnb_8bit_quant_type": self.quant_type,
-            "bnb_4bit_compute_dtype" if self.quant_type == "int4" else "bnb_8bit_compute_dtype": self.compute_dtype,
-            "bnb_4bit_use_double_quant" if self.quant_type == "int4" else "bnb_8bit_use_double_quant": self.use_double_quant,
-            "bnb_4bit_quant_storage" if self.quant_type == "int4" else "bnb_8bit_quant_storage": self.quant_storage,
-        }
-
-        if self.quant_type == "int4":
+        if quantization == "4bit":
+            config_params = {
+                "bnb_4bit_quant_type": self.quant_type,
+                "bnb_4bit_compute_dtype": self.compute_dtype,
+                "bnb_4bit_use_double_quant": self.use_double_quant,
+                "bnb_4bit_quant_storage": self.quant_storage,
+            }
+            
             return BitsAndBytesConfig(load_in_4bit=True, **config_params)
         else:
-            return BitsAndBytesConfig(load_in_8bit=True, **config_params)
+            return BitsAndBytesConfig(load_in_8bit=True)
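
For reference, a minimal usage sketch of the reworked interface (the "nf4" override and the variable names are illustrative only; the import path mirrors the one used in finetuning.py below):

import torch
from llama_recipes.configs import quantization_config

# 4-bit path: the dataclass fields are forwarded to the bnb_4bit_* arguments
# of BitsAndBytesConfig and load_in_4bit is set.
quant_cfg = quantization_config(quant_type="nf4", compute_dtype=torch.bfloat16)
bnb_4bit = quant_cfg.create_bnb_config("4bit")

# 8-bit path: the dataclass fields are ignored and a plain
# BitsAndBytesConfig(load_in_8bit=True) is returned.
bnb_8bit = quant_cfg.create_bnb_config("8bit")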

+ 15 - 9
src/llama_recipes/finetuning.py

@@ -27,7 +27,7 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
-from llama_recipes.configs.quantization import quantization_config 
+from llama_recipes.configs import quantization_config  as QUANTIZATION_CONFIG
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
@@ -51,6 +51,7 @@ from llama_recipes.utils.train_utils import (
     get_policies,
 )
 from accelerate.utils import is_xpu_available
+from warnings import warn
 
 def setup_wandb(train_config, fsdp_config, **kwargs):
     try:
@@ -100,14 +101,19 @@ def main(**kwargs):
         if not train_config.enable_fsdp or rank==0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
     
-    #setting quantization configs        
-    quant_config = quantization_config(
-        quant_type=train_config.quantization,
-        compute_dtype=torch.bfloat16,
-        use_double_quant=True,
-        quant_storage=torch.bfloat16
-    ) if train_config.quantization else None
-    bnb_config = quant_config.create_bnb_config() if quant_config else None
+    #setting quantization configs
+    bnb_config = None
+    if train_config.quantization:
+        if type(train_config.quantization) == type(True):
+            warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning)
+            train_config.quantization = "8bit"
+
+        if train_config.quantization == "8bit" and train_config.enable_fsdp:
+            raise ValueError("8bit quantization is not supported with FSDP, please use 4bit quantization")
+
+        quant_config = QUANTIZATION_CONFIG()
+        update_config(quant_config, **kwargs)
+        bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
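
The net effect is that --quantization now selects the bit width ("4bit" or "8bit"), while the bitsandbytes parameters come from quantization_config, which update_config fills from the remaining command-line kwargs. A minimal sketch of the resulting flow, with build_bnb_config as a hypothetical stand-in for the new branch in main() (not a helper that exists in the repo):

from warnings import warn
import torch
from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG

def build_bnb_config(quantization, enable_fsdp=False, **overrides):
    # Legacy behaviour: a bare --quantization flag arrives as a boolean.
    if isinstance(quantization, bool):
        warn("Quantization (--quantization) is a boolean, please specify quantization as "
             "'4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
             FutureWarning)
        quantization = "8bit"
    # Per the check added above, 8bit quantization is not supported with FSDP.
    if quantization == "8bit" and enable_fsdp:
        raise ValueError("8bit quantization is not supported with FSDP, please use 4bit quantization")
    # overrides stands in for the kwargs that update_config would apply.
    quant_config = QUANTIZATION_CONFIG(**overrides)
    return quant_config.create_bnb_config(quantization)

bnb_config = build_bnb_config("4bit", quant_type="nf4", compute_dtype=torch.bfloat16)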