Pārlūkot izejas kodu

adding quantization config

HamidShojanazeri 10 mēneši atpakaļ
vecāks
revīzija
03c58705bd
1 mainīts fails ar 17 papildinājumiem un 12 dzēšanām
  1. 17 12
      src/llama_recipes/configs/quantization.py

+ 17 - 12
src/llama_recipes/configs/quantization.py

@@ -1,9 +1,10 @@
 from dataclasses import dataclass, field
 from typing import List
-
+import torch
+from transformers import BitsAndBytesConfig
 
 @dataclass
-class QuantizationConfig:
+class quantizatio_config:
     quant_type: str  # "int4" or "int8"
     compute_dtype: torch.dtype
     use_double_quant: bool = False
@@ -11,16 +12,20 @@ class QuantizationConfig:
 
     def create_bnb_config(self):
         if self.quant_type == "int4":
-            quant_type_str = "nf4"
+            return BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=self.compute_dtype,
+                bnb_4bit_use_double_quant=self.use_double_quant,
+                bnb_4bit_quant_storage=self.quant_storage
+            )
         elif self.quant_type == "int8":
-            quant_type_str = "int8"
+            return BitsAndBytesConfig(
+                load_in_8bit=True,
+                bnb_8bit_quant_type="int8",
+                bnb_8bit_compute_dtype=self.compute_dtype,
+                bnb_8bit_use_double_quant=self.use_double_quant,
+                bnb_8bit_quant_storage=self.quant_storage
+            )
         else:
             raise ValueError("quant_type must be either 'int4' or 'int8'")
-
-        return BitsAndBytesConfig(
-            load_in_4bit=(self.quant_type == "int4"),
-            bnb_4bit_quant_type=quant_type_str,
-            bnb_4bit_compute_dtype=self.compute_dtype,
-            bnb_4bit_use_double_quant=self.use_double_quant,
-            bnb_4bit_quant_storage=self.quant_storage
-        )