Explorar o código

adding quantization dataclass

HamidShojanazeri hai 10 meses
pai
achega
d270deea81
Modificáronse 1 ficheiros con 26 adicións e 0 borrados
  1. 26 0
      src/llama_recipes/configs/quantization.py

+ 26 - 0
src/llama_recipes/configs/quantization.py

@@ -0,0 +1,26 @@
+from dataclasses import dataclass, field
+from typing import List
+
+
+@dataclass
+class QuantizationConfig:
+    quant_type: str  # "int4" or "int8"
+    compute_dtype: torch.dtype
+    use_double_quant: bool = False
+    quant_storage: torch.dtype = torch.bfloat16
+
+    def create_bnb_config(self):
+        if self.quant_type == "int4":
+            quant_type_str = "nf4"
+        elif self.quant_type == "int8":
+            quant_type_str = "int8"
+        else:
+            raise ValueError("quant_type must be either 'int4' or 'int8'")
+
+        return BitsAndBytesConfig(
+            load_in_4bit=(self.quant_type == "int4"),
+            bnb_4bit_quant_type=quant_type_str,
+            bnb_4bit_compute_dtype=self.compute_dtype,
+            bnb_4bit_use_double_quant=self.use_double_quant,
+            bnb_4bit_quant_storage=self.quant_storage
+        )