Procházet zdrojové kódy

added the copyright

HamidShojanazeri před 10 měsíci
rodič
revize
ae9fd5eef4
1 změnil soubory, kde provedl 18 přidání a 19 odebrání
  1. 18 19
      src/llama_recipes/configs/quantization.py

+ 18 - 19
src/llama_recipes/configs/quantization.py

@@ -1,5 +1,8 @@
-from dataclasses import dataclass, field
-from typing import List
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from dataclasses import dataclass
+from typing import Optional
 import torch
 from transformers import BitsAndBytesConfig
 
@@ -10,22 +13,18 @@ class quantizatio_config:
     use_double_quant: bool = False
     quant_storage: torch.dtype = torch.bfloat16
 
-    def create_bnb_config(self):
+    def create_bnb_config(self) -> BitsAndBytesConfig:
+        if self.quant_type not in {"int4", "int8"}:
+            raise ValueError("quant_type must be either 'int4' or 'int8'")
+
+        config_params = {
+            "bnb_4bit_quant_type" if self.quant_type == "int4" else "bnb_8bit_quant_type": self.quant_type,
+            "bnb_4bit_compute_dtype" if self.quant_type == "int4" else "bnb_8bit_compute_dtype": self.compute_dtype,
+            "bnb_4bit_use_double_quant" if self.quant_type == "int4" else "bnb_8bit_use_double_quant": self.use_double_quant,
+            "bnb_4bit_quant_storage" if self.quant_type == "int4" else "bnb_8bit_quant_storage": self.quant_storage,
+        }
+
         if self.quant_type == "int4":
-            return BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=self.compute_dtype,
-                bnb_4bit_use_double_quant=self.use_double_quant,
-                bnb_4bit_quant_storage=self.quant_storage
-            )
-        elif self.quant_type == "int8":
-            return BitsAndBytesConfig(
-                load_in_8bit=True,
-                bnb_8bit_quant_type="int8",
-                bnb_8bit_compute_dtype=self.compute_dtype,
-                bnb_8bit_use_double_quant=self.use_double_quant,
-                bnb_8bit_quant_storage=self.quant_storage
-            )
+            return BitsAndBytesConfig(load_in_4bit=True, **config_params)
         else:
-            raise ValueError("quant_type must be either 'int4' or 'int8'")
+            return BitsAndBytesConfig(load_in_8bit=True, **config_params)